Skip to content

Commit

Permalink
fix: LVM - Update Video.load_from_file() to support storage.googlea…
Browse files Browse the repository at this point in the history
…pis.com links

PiperOrigin-RevId: 649149724
  • Loading branch information
holtskinner authored and copybara-github committed Jul 3, 2024
1 parent a6f68df commit b63f960
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 2 deletions.
72 changes: 72 additions & 0 deletions tests/unit/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,12 @@ def generate_video_from_gcs_uri(
return ga_vision_models.Video.load_from_file(gcs_uri)


def generate_video_from_storage_url(
gcs_uri: str = "https://storage.googleapis.com/cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4",
) -> ga_vision_models.Video:
return ga_vision_models.Video.load_from_file(gcs_uri)


@pytest.mark.usefixtures("google_auth_mock")
class TestImageGenerationModels:
"""Unit tests for the image generation models."""
Expand Down Expand Up @@ -1215,6 +1221,72 @@ def test_video_embedding_model_with_only_video(self):
assert not embedding_response.text_embedding
assert not embedding_response.image_embedding

def test_video_embedding_model_with_storage_url(self):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)

mock_get_publisher_model.assert_called_once_with(
name="publishers/google/models/multimodalembedding@001",
retry=base._DEFAULT_RETRY,
)

test_video_embeddings = [
ga_vision_models.VideoEmbedding(
start_offset_sec=0,
end_offset_sec=7,
embedding=[0, 7],
)
]

gca_predict_response = gca_prediction_service.PredictResponse()
gca_predict_response.predictions.append(
{
"videoEmbeddings": [
{
"startOffsetSec": test_video_embeddings[0].start_offset_sec,
"endOffsetSec": test_video_embeddings[0].end_offset_sec,
"embedding": test_video_embeddings[0].embedding,
}
]
}
)

video = generate_video_from_storage_url()

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
):
embedding_response = model.get_embeddings(video=video)

assert (
embedding_response.video_embeddings[0].embedding
== test_video_embeddings[0].embedding
)
assert (
embedding_response.video_embeddings[0].start_offset_sec
== test_video_embeddings[0].start_offset_sec
)
assert (
embedding_response.video_embeddings[0].end_offset_sec
== test_video_embeddings[0].end_offset_sec
)
assert not embedding_response.text_embedding
assert not embedding_response.image_embedding

def test_video_embedding_model_with_video_and_text(self):
aiplatform.init(
project=_TEST_PROJECT,
Expand Down
15 changes: 13 additions & 2 deletions vertexai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __init__(
video_bytes: Optional[bytes] = None,
gcs_uri: Optional[str] = None,
):
"""Creates an `Image` object.
"""Creates a `Video` object.
Args:
video_bytes: Video file bytes. Video can be in AVI, FLV, MKV, MOV,
Expand All @@ -211,9 +211,20 @@ def load_from_file(location: str) -> "Video":
Returns:
Loaded video as an `Video` object.
"""
if location.startswith("gs://"):
parsed_url = urllib.parse.urlparse(location)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "storage.googleapis.com"
):
parsed_url = parsed_url._replace(
scheme="gs", netloc="", path=f"/{urllib.parse.unquote(parsed_url.path)}"
)
location = urllib.parse.urlunparse(parsed_url)

if parsed_url.scheme == "gs":
return Video(gcs_uri=location)

# Load video from local path
video_bytes = pathlib.Path(location).read_bytes()
video = Video(video_bytes=video_bytes)
return video
Expand Down

0 comments on commit b63f960

Please sign in to comment.