Skip to content

Commit

Permalink
Set token to be expired if response comes back as unauthorized, withi…
Browse files Browse the repository at this point in the history
…n BeareTokenAuthenticationPolicy. (Azure#6151)

* Set token to be expired if response comes back as unauthorized.

* Add CL entry.

* Update CL.

* Use trc MinimumExpiration to invalidate the credential's token cache.

* Add test.

* Address PR feedback.

* Remove comment as it is no longer relevant.

* Use initializer list syntax to see if posix compilers are okay with that.

* Keep the bool field as non-atomic.

* Revert "Keep the bool field as non-atomic."

This reverts commit 1b8c762.
  • Loading branch information
ahsonkhan authored Oct 30, 2024
1 parent c168d73 commit 064fcad
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 5 deletions.
1 change: 1 addition & 0 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
### Bugs Fixed

- Fixed warning for an unused function in curl.cpp when building the SDK using a version of libcurl older than 7.77.0.
- Invalidate the token cache within `BearerTokenAuthenticationPolicy` whenever a token request comes back with a 401 response.

### Other Changes

Expand Down
2 changes: 2 additions & 0 deletions sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
mutable Credentials::AccessToken m_accessToken;
mutable std::shared_timed_mutex m_accessTokenMutex;
mutable Credentials::TokenRequestContext m_accessTokenContext;
mutable std::atomic<bool> m_invalidateToken = {false};

public:
/**
Expand Down Expand Up @@ -610,6 +611,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
std::shared_lock<std::shared_timed_mutex> readLock(other.m_accessTokenMutex);
m_accessToken = other.m_accessToken;
m_accessTokenContext = other.m_accessTokenContext;
m_invalidateToken.store(other.m_invalidateToken.load());
}

void operator=(BearerTokenAuthenticationPolicy const&) = delete;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ std::unique_ptr<RawResponse> BearerTokenAuthenticationPolicy::Send(
auto result = AuthorizeAndSendRequest(request, nextPolicy, context);
{
auto const& response = *result;
m_invalidateToken = (response.GetStatusCode() == HttpStatusCode::Unauthorized);
auto const& challenge = AuthorizationChallengeHelper::GetChallenge(response);
if (!challenge.empty() && AuthorizeRequestOnChallenge(challenge, request, context))
{
Expand Down Expand Up @@ -67,9 +68,10 @@ bool TokenNeedsRefresh(
Azure::Core::Credentials::AccessToken const& cachedToken,
Azure::Core::Credentials::TokenRequestContext const& cachedTokenRequestContext,
Azure::DateTime const& currentTime,
Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext)
Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext,
bool forceRefresh)
{
return newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId
return forceRefresh || newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId
|| newTokenRequestContext.Scopes != cachedTokenRequestContext.Scopes
|| currentTime > (cachedToken.ExpiresOn - newTokenRequestContext.MinimumExpiration);
}
Expand All @@ -91,7 +93,12 @@ void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest(

{
std::shared_lock<std::shared_timed_mutex> readLock(m_accessTokenMutex);
if (!TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext))
if (!TokenNeedsRefresh(
m_accessToken,
m_accessTokenContext,
currentTime,
tokenRequestContext,
m_invalidateToken))
{
ApplyBearerToken(request, m_accessToken);
return;
Expand All @@ -100,10 +107,20 @@ void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest(

std::unique_lock<std::shared_timed_mutex> writeLock(m_accessTokenMutex);
// Check if token needs refresh for the second time in case another thread has just updated it.
if (TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext))
if (TokenNeedsRefresh(
m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext, m_invalidateToken))
{
m_accessToken = m_credential->GetToken(tokenRequestContext, context);
TokenRequestContext trcCopy = tokenRequestContext;
if (m_invalidateToken)
{
// Need to set this to invalidate the credential's token cache to ensure we fetch a new token
// on subsequent GetToken calls.
trcCopy.MinimumExpiration = DateTime::duration::max();
}

m_accessToken = m_credential->GetToken(trcCopy, context);
m_accessTokenContext = tokenRequestContext;
m_invalidateToken = false;
}

ApplyBearerToken(request, m_accessToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,37 @@ class TestTransportPolicy final : public HttpPolicy {
return std::make_unique<TestTransportPolicy>(*this);
}
};

class TestTransportPolicyMultipleResponses final : public HttpPolicy {
private:
mutable int m_responsesCount = 0;

public:
std::unique_ptr<RawResponse> Send(Request&, NextHttpPolicy, Context const&) const override
{
if (m_responsesCount == 1)
{
m_responsesCount++;
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Unauthorized, "TestStatus");
}
if (m_responsesCount == 2)
{
m_responsesCount++;
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Ok, "TestStatus");
}
if (m_responsesCount > 2)
{
EXPECT_TRUE(false);
}
m_responsesCount++;
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Ok, "TestStatus");
}

std::unique_ptr<HttpPolicy> Clone() const override
{
return std::make_unique<TestTransportPolicyMultipleResponses>(*this);
}
};
} // namespace

TEST(BearerTokenAuthenticationPolicy, InitialGet)
Expand Down Expand Up @@ -169,6 +200,66 @@ TEST(BearerTokenAuthenticationPolicy, RefreshNearExpiry)
}
}

TEST(BearerTokenAuthenticationPolicy, TokenInvalidatedAfterUnauth)
{
using namespace std::chrono_literals;
auto accessToken = std::make_shared<AccessToken>();

std::vector<std::unique_ptr<HttpPolicy>> policies;

TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};

policies.emplace_back(std::make_unique<BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));

policies.emplace_back(std::make_unique<TestTransportPolicyMultipleResponses>());

HttpPipeline pipeline(policies);

// The first request is successful, the token gets cached in the credential
{
Request request(HttpMethod::Get, Url("https://www.azure.com"));

*accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 1h};

pipeline.Send(request, Context());

{
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1");
}
}

// The second request returns unauthorized, the token should be invalidated
{
Request request(HttpMethod::Get, Url("https://www.azure.com"));

*accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h};

pipeline.Send(request, Context());

{
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1");
}

// We expect the next call to return a new token
pipeline.Send(request, Context());

{
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN2");
}
}
}

TEST(BearerTokenAuthenticationPolicy, RefreshAfterExpiry)
{
using namespace std::chrono_literals;
Expand Down

0 comments on commit 064fcad

Please sign in to comment.