Skip to content

Commit

Permalink
fix: Made the Endpoint prediction client initialization lazy
Browse files Browse the repository at this point in the history
This is mainly to avoid issues with the `PredictionAsyncClient` which is based on `asyncio` and conflicts with other asynchronous solutions.
Fixes #2620

PiperOrigin-RevId: 574978613
  • Loading branch information
Ark-kun authored and copybara-github committed Oct 19, 2023
1 parent 98ab2f9 commit eb6071f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 88 deletions.
1 change: 0 additions & 1 deletion google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,6 @@ def _sync_object_with_future_result(
"credentials",
]
optional_sync_attributes = [
"_prediction_client",
"_authorized_session",
"_raw_predict_request_url",
]
Expand Down
93 changes: 30 additions & 63 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import pathlib
import re
Expand Down Expand Up @@ -227,16 +226,39 @@ def __init__(
# Lazy load the Endpoint gca_resource until needed
self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)

(
self._prediction_client,
self._prediction_async_client,
) = self._instantiate_prediction_clients(
location=self.location,
credentials=credentials,
)
self.authorized_session = None
self.raw_predict_request_url = None

@property
def _prediction_client(self) -> utils.PredictionClientWithOverride:
# The attribute might not exist due to issues in
# `VertexAiResourceNounWithFutureManager._sync_object_with_future_result`
# We should switch to @functools.cached_property once its available.
if not getattr(self, "_prediction_client_value", None):
self._prediction_client_value = initializer.global_config.create_client(
client_class=utils.PredictionClientWithOverride,
credentials=self.credentials,
location_override=self.location,
prediction_client=True,
)
return self._prediction_client_value

@property
def _prediction_async_client(self) -> utils.PredictionAsyncClientWithOverride:
# The attribute might not exist due to issues in
# `VertexAiResourceNounWithFutureManager._sync_object_with_future_result`
# We should switch to @functools.cached_property once its available.
if not getattr(self, "_prediction_async_client_value", None):
self._prediction_async_client_value = (
initializer.global_config.create_client(
client_class=utils.PredictionAsyncClientWithOverride,
credentials=self.credentials,
location_override=self.location,
prediction_client=True,
)
)
return self._prediction_async_client_value

def _skipped_getter_call(self) -> bool:
"""Check if GAPIC resource was populated by call to get/list API methods
Expand Down Expand Up @@ -575,14 +597,6 @@ def _construct_sdk_resource_from_gapic(
location=location,
credentials=credentials,
)

(
endpoint._prediction_client,
endpoint._prediction_async_client,
) = cls._instantiate_prediction_clients(
location=endpoint.location,
credentials=credentials,
)
endpoint.authorized_session = None
endpoint.raw_predict_request_url = None

Expand Down Expand Up @@ -1390,53 +1404,6 @@ def _undeploy(
# update local resource
self._sync_gca_resource()

@staticmethod
def _instantiate_prediction_clients(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Tuple[
utils.PredictionClientWithOverride, utils.PredictionAsyncClientWithOverride
]:
"""Helper method to instantiates prediction client with optional
overrides for this endpoint.
Args:
location (str): The location of this endpoint.
credentials (google.auth.credentials.Credentials):
Optional custom credentials to use when accessing interacting with
the prediction client.
Returns:
prediction_client (prediction_service_client.PredictionServiceClient):
prediction_async_client (PredictionServiceAsyncClient):
Initialized prediction clients with optional overrides.
"""

# Creating an event loop if needed.
# PredictionServiceAsyncClient constructor calls `asyncio.get_event_loop`,
# which fails when there is no event loop (which does not exist by default
# in non-main threads in thread pool used when `sync=False`).
try:
asyncio.get_event_loop()
except RuntimeError:
asyncio.set_event_loop(asyncio.new_event_loop())

async_client = initializer.global_config.create_client(
client_class=utils.PredictionAsyncClientWithOverride,
credentials=credentials,
location_override=location,
prediction_client=True,
)
# We could use `client = async_client._client`, but then client would be
# a concrete `PredictionServiceClient`, not `PredictionClientWithOverride`.
client = initializer.global_config.create_client(
client_class=utils.PredictionClientWithOverride,
credentials=credentials,
location_override=location,
prediction_client=True,
)
return (client, async_client)

def update(
self,
display_name: Optional[str] = None,
Expand Down
24 changes: 0 additions & 24 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,18 +658,6 @@ def test_constructor(self, create_endpoint_client_mock):
location_override=_TEST_LOCATION,
appended_user_agent=None,
),
mock.call(
client_class=utils.PredictionAsyncClientWithOverride,
credentials=None,
location_override=_TEST_LOCATION,
prediction_client=True,
),
mock.call(
client_class=utils.PredictionClientWithOverride,
credentials=None,
location_override=_TEST_LOCATION,
prediction_client=True,
),
]
)

Expand Down Expand Up @@ -754,18 +742,6 @@ def test_constructor_with_custom_credentials(self, create_endpoint_client_mock):
location_override=_TEST_LOCATION,
appended_user_agent=None,
),
mock.call(
client_class=utils.PredictionAsyncClientWithOverride,
credentials=creds,
location_override=_TEST_LOCATION,
prediction_client=True,
),
mock.call(
client_class=utils.PredictionClientWithOverride,
credentials=creds,
location_override=_TEST_LOCATION,
prediction_client=True,
),
]
)

Expand Down

0 comments on commit eb6071f

Please sign in to comment.