Skip to content

Commit

Permalink
feat: LVM - Added support for Images from GCS uri for multimodal embe…
Browse files Browse the repository at this point in the history
…ddings

PiperOrigin-RevId: 605748060
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Feb 9, 2024
1 parent 716f3e1 commit 90d95d7
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 7 deletions.
22 changes: 22 additions & 0 deletions tests/system/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def _create_blank_image(
return vision_models.Image.load_from_file(image_path)


def _load_image_from_gcs(
gcs_uri: str = "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
) -> vision_models.Image:
return vision_models.Image.load_from_file(gcs_uri)


class VisionModelTestSuite(e2e_base.TestEndToEnd):
"""System tests for vision models."""

Expand Down Expand Up @@ -85,6 +91,22 @@ def test_multi_modal_embedding_model(self):
assert len(embeddings.image_embedding) == 1408
assert len(embeddings.text_embedding) == 1408

def test_multi_modal_embedding_model_with_gcs_uri(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

model = ga_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)
image = _load_image_from_gcs()
embeddings = model.get_embeddings(
image=image,
# Optional:
contextual_text="this is a car",
)
# The service is expected to return the embeddings of size 1408
assert len(embeddings.image_embedding) == 1408
assert len(embeddings.text_embedding) == 1408

def test_image_generation_model_generate_images(self):
"""Tests the image generation model generating images."""
model = vision_models.ImageGenerationModel.from_pretrained(
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ def generate_image_from_file(
return ga_vision_models.Image.load_from_file(image_path)


def generate_image_from_gcs_uri(
gcs_uri: str = "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
) -> ga_vision_models.Image:
return ga_vision_models.Image.load_from_file(gcs_uri)


@pytest.mark.usefixtures("google_auth_mock")
class TestImageGenerationModels:
"""Unit tests for the image generation models."""
Expand Down Expand Up @@ -721,6 +727,42 @@ def test_image_embedding_model_with_lower_dimensions(self):
assert embedding_response.image_embedding == test_embeddings
assert embedding_response.text_embedding == test_embeddings

def test_image_embedding_model_with_gcs_uri(self):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
):
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)

test_embeddings = [0, 0]
gca_predict_response = gca_prediction_service.PredictResponse()
gca_predict_response.predictions.append(
{"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
)

image = generate_image_from_gcs_uri()

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
):
embedding_response = model.get_embeddings(
image=image, contextual_text="hello world"
)

assert embedding_response.image_embedding == test_embeddings
assert embedding_response.text_embedding == test_embeddings


@pytest.mark.usefixtures("google_auth_mock")
class ImageTextModelTests:
Expand Down
49 changes: 42 additions & 7 deletions vertexai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import typing
from typing import Any, Dict, List, Optional, Union

from google.cloud import storage

from google.cloud.aiplatform import initializer as aiplatform_initializer
from vertexai._model_garden import _model_garden_models

# pylint: disable=g-import-not-at-top
Expand All @@ -45,31 +48,60 @@ class Image:

__module__ = "vertexai.vision_models"

_image_bytes: bytes
_loaded_bytes: Optional[bytes] = None
_loaded_image: Optional["PIL_Image.Image"] = None
_gcs_uri: Optional[str] = None

def __init__(self, image_bytes: bytes):
def __init__(
self,
image_bytes: Optional[bytes] = None,
gcs_uri: Optional[str] = None,
):
"""Creates an `Image` object.
Args:
image_bytes: Image file bytes. Image can be in PNG or JPEG format.
gcs_uri: Image URI in Google Cloud Storage.
"""
if bool(image_bytes) == bool(gcs_uri):
raise ValueError("Either image_bytes or gcs_uri must be provided.")

self._image_bytes = image_bytes
self._gcs_uri = gcs_uri

@staticmethod
def load_from_file(location: str) -> "Image":
"""Loads image from file.
"""Loads image from local file or Google Cloud Storage.
Args:
location: Local path from where to load the image.
location: Local path or Google Cloud Storage uri from where to load
the image.
Returns:
Loaded image as an `Image` object.
"""
if location.startswith("gs://"):
return Image(gcs_uri=location)

image_bytes = pathlib.Path(location).read_bytes()
image = Image(image_bytes=image_bytes)
return image

@property
def _image_bytes(self) -> bytes:
if self._loaded_bytes is None:
storage_client = storage.Client(
credentials=aiplatform_initializer.global_config.credentials
)
self._loaded_bytes = storage.Blob.from_string(
uri=self._gcs_uri, client=storage_client
).download_as_bytes()
return self._loaded_bytes

@_image_bytes.setter
def _image_bytes(self, value: bytes):
self._loaded_bytes = value

@property
def _pil_image(self) -> "PIL_Image.Image":
if self._loaded_image is None:
Expand Down Expand Up @@ -664,7 +696,7 @@ def get_embeddings(
values: `128`, `256`, `512`, and `1408` (default).
Returns:
ImageEmbeddingResponse:
MultiModalEmbeddingResponse:
The image and text embedding vectors.
"""

Expand All @@ -674,7 +706,10 @@ def get_embeddings(
instance = {}

if image:
instance["image"] = {"bytesBase64Encoded": image._as_base64_string()}
if image._gcs_uri:
instance["image"] = {"gcsUri": image._gcs_uri}
else:
instance["image"] = {"bytesBase64Encoded": image._as_base64_string()}

if contextual_text:
instance["text"] = contextual_text
Expand Down Expand Up @@ -702,7 +737,7 @@ def get_embeddings(

@dataclasses.dataclass
class MultiModalEmbeddingResponse:
"""The image embedding response.
"""The multimodal embedding response.
Attributes:
image_embedding (List[float]):
Expand Down

0 comments on commit 90d95d7

Please sign in to comment.