Skip to content

Commit

Permalink
add pickling support (#34134)
Browse files Browse the repository at this point in the history
* add pickling support

* update

* update

* update

* update

* update

* typing

* black

* update

* update

* updates

* black

* update

* update changelog
  • Loading branch information
xiangyan99 authored Feb 22, 2024
1 parent da82787 commit 8b921c4
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 1 deletion.
6 changes: 6 additions & 0 deletions .vscode/cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,12 @@
"Jwcmlud"
]
},
{
"filename": "sdk/identity/azure-identity/tests/*.py",
"words": [
"infile"
]
},
{
"filename": "sdk/identity/test-resources*",
"words": [
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Added pickling support. ([#34134](https://github.com/Azure/azure-sdk-for-python/pull/34134))

### Breaking Changes

### Bugs Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def __init__(
self._cache = cache
self._cae_cache = cae_cache
self._cache_options = kwargs.pop("cache_persistence_options", None)
if self._cache or self._cae_cache:
self._custom_cache = True
else:
self._custom_cache = False

def _get_cache(self, **kwargs: Any) -> TokenCache:
cache = self._cae_cache if kwargs.get("enable_cae") else self._cache
Expand Down Expand Up @@ -346,6 +350,21 @@ def _post(self, data: Dict, **kwargs: Any) -> HttpRequest:
url = self._get_token_url(**kwargs)
return HttpRequest("POST", url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"})

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
# Remove the non-picklable entries
if not self._custom_cache:
del state["_cache"]
del state["_cae_cache"]
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
# Re-create the unpickable entries
if not self._custom_cache:
self._cache = None
self._cae_cache = None


def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge):
# Represent capabilities as {"access_token": {"xms_cc": {"values": capabilities}}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def __init__(
identity_config: Optional[Dict] = None,
**kwargs: Any
) -> None:
self._cache = kwargs.pop("_cache", None) or TokenCache()
self._custom_cache = False
self._cache = kwargs.pop("_cache", None)
if self._cache:
self._custom_cache = True
else:
self._cache = TokenCache()
self._content_callback = kwargs.pop("_content_callback", None)
self._identity_config = identity_config or {}
if client_id:
Expand Down Expand Up @@ -91,6 +96,19 @@ def request_token(self, *scopes, **kwargs):
def _build_pipeline(self, **kwargs):
pass

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
# Remove the non-picklable entries
if not self._custom_cache:
del state["_cache"]
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
# Re-create the unpickable entries
if not self._custom_cache:
self._cache = TokenCache()


class ManagedIdentityClient(ManagedIdentityClientBase):
def __enter__(self) -> "ManagedIdentityClient":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,14 @@ def _store_auth_error(self, response: PipelineResponse) -> None:
content = response.context.get(ContentDecodePolicy.CONTEXT_NAME)
if content and "error" in content:
self._local.error = (content["error"], response.http_response)

def __getstate__(self) -> Dict[str, Any]: # pylint:disable=client-method-name-no-double-underscore
state = self.__dict__.copy()
# Remove the non-picklable entries
del state["_local"]
return state

def __setstate__(self, state: Dict[str, Any]) -> None: # pylint:disable=client-method-name-no-double-underscore
self.__dict__.update(state)
# Re-create the unpickable entries
self._local = threading.local()
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def __init__(

self._cache = kwargs.pop("_cache", None)
self._cae_cache = kwargs.pop("_cae_cache", None)
if self._cache or self._cae_cache:
self._custom_cache = True
else:
self._custom_cache = False
self._cache_options = kwargs.pop("cache_persistence_options", None)

super(MsalCredential, self).__init__()
Expand Down Expand Up @@ -112,3 +116,22 @@ def _get_app(self, **kwargs: Any) -> msal.ClientApplication:
)

return client_applications_map[tenant_id]

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
# Remove the non-picklable entries
del state["_client_applications"]
del state["_cae_client_applications"]
if not self._custom_cache:
del state["_cache"]
del state["_cae_cache"]
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
# Re-create the unpickable entries
self._client_applications = {}
self._cae_client_applications = {}
if not self._custom_cache:
self._cache = None
self._cae_cache = None
24 changes: 24 additions & 0 deletions sdk/identity/azure-identity/tests/test_pickling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import pickle
from azure.identity import DefaultAzureCredential
from azure.identity._internal.msal_credentials import MsalCredential


def test_pickle_dac():
cred = DefaultAzureCredential()
with open("data.pkl", "wb") as outfile:
pickle.dump(cred, outfile)
with open("data.pkl", "rb") as infile:
data_loaded = pickle.load(infile)


def test_pickle_msal_credential():
cred = MsalCredential(client_id="CLIENT_ID")
app = cred._get_app()
with open("data.pkl", "wb") as outfile:
pickle.dump(cred, outfile)
with open("data.pkl", "rb") as infile:
data_loaded = pickle.load(infile)
14 changes: 14 additions & 0 deletions sdk/identity/azure-identity/tests/test_pickling_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import pickle
from azure.identity.aio import DefaultAzureCredential


def test_pickle_dac():
cred = DefaultAzureCredential()
with open("data_aio.pkl", "wb") as outfile:
pickle.dump(cred, outfile)
with open("data_aio.pkl", "rb") as infile:
data_loaded = pickle.load(infile)

0 comments on commit 8b921c4

Please sign in to comment.