Skip to content

Commit

Permalink
chore: add an util function download_from_gcs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 543926933
  • Loading branch information
jaycee-li authored and copybara-github committed Jun 28, 2023
1 parent d84f687 commit f3b0e65
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 2 deletions.
49 changes: 48 additions & 1 deletion google/cloud/aiplatform/utils/gcs_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -300,6 +300,53 @@ def download_file_from_gcs(
source_blob.download_to_filename(filename=destination_file_path)


def download_from_gcs(
    source_uri: str,
    destination_path: str,
    project: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
):
    """Downloads GCS files to local path.

    Every blob listed under the prefix taken from `source_uri` is written
    below `destination_path`, keeping its path relative to that prefix.
    Blobs whose name ends with "/" are skipped.

    Args:
        source_uri (str):
            Required. GCS URI (or prefix) of the file(s) to download.
            A bucket-only URI such as `gs://my-bucket` is treated as an
            empty prefix.
        destination_path (str):
            Required. Local path where the data should be downloaded.
            If provided a file path, then `source_uri` must refer to a file.
            If provided a directory path, then `source_uri` must refer to a prefix.
        project (str):
            Optional. Google Cloud Project that contains the staging bucket.
            Defaults to `initializer.global_config.project`.
        credentials (auth_credentials.Credentials):
            Optional. The custom credentials to use when making API calls.
            If not provided, default credentials will be used.
    Raises:
        ValueError: If `validate_gcs_path` rejects `source_uri`.
        GoogleCloudError: When the download process fails.
    """
    project = project or initializer.global_config.project
    credentials = credentials or initializer.global_config.credentials

    storage_client = storage.Client(project=project, credentials=credentials)

    validate_gcs_path(source_uri)
    # `partition` (unlike a two-way `split` unpack) also copes with a
    # bucket-only URI that has no "/" after the bucket name: the prefix is
    # then simply empty instead of raising on unpacking.
    bucket_name, _, prefix = source_uri.replace("gs://", "").partition("/")

    blobs = storage_client.list_blobs(bucket_or_name=bucket_name, prefix=prefix)
    for blob in blobs:
        # In SDK 2.0 remote training, we'll create some empty files.
        # These files end with '/', and we skip them.
        if blob.name.endswith("/"):
            continue

        rel_path = os.path.relpath(blob.name, prefix)
        # A relative path of "." means the blob name equals the prefix, i.e.
        # `source_uri` points at a single file, so `destination_path` is the
        # target file itself rather than a directory to download into.
        filename = (
            destination_path
            if rel_path == "."
            else os.path.join(destination_path, rel_path)
        )

        # `os.path.dirname` returns "" for a bare file name (a file in the
        # current working directory) and `os.makedirs("")` raises, so only
        # create the parent directory when there is one.
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        blob.download_to_filename(filename=filename)


def _upload_pandas_df_to_gcs(
df: "pandas.DataFrame", upload_gcs_path: str, file_format: str = "jsonl"
) -> None:
Expand Down
87 changes: 86 additions & 1 deletion tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,7 @@
import json
import os
import re
import tempfile
import textwrap
from typing import Callable, Dict, Optional, Tuple
from unittest import mock
Expand Down Expand Up @@ -97,6 +98,36 @@ def mock_storage_blob_upload_from_filename():
yield mock_blob_upload_from_filename


@pytest.fixture
def mock_storage_client_list_blobs():
    """Patches `storage.Client.list_blobs` to return a small directory listing.

    The listing holds one blob whose name ends with "/", one file directly
    under `GCS_PREFIX` and one file inside a nested `fake-dir` folder.
    """
    blob_names = (
        f"{GCS_PREFIX}/",
        f"{GCS_PREFIX}/{FAKE_FILENAME}-1",
        f"{GCS_PREFIX}/fake-dir/{FAKE_FILENAME}-2",
    )
    fake_listing = [storage.Blob(name=name, bucket=GCS_BUCKET) for name in blob_names]

    with patch(
        "google.cloud.storage.Client.list_blobs", return_value=fake_listing
    ) as list_blobs_mock:
        yield list_blobs_mock


@pytest.fixture
def mock_storage_client_list_blob():
    """Patches `storage.Client.list_blobs` to return exactly one file blob."""
    single_file = storage.Blob(name=f"{GCS_PREFIX}/{FAKE_FILENAME}", bucket=GCS_BUCKET)

    with patch(
        "google.cloud.storage.Client.list_blobs", return_value=[single_file]
    ) as list_blobs_mock:
        yield list_blobs_mock


@pytest.fixture
def mock_storage_blob_download_to_filename():
    """Patches `storage.Blob.download_to_filename` and yields the mock."""
    target = "google.cloud.storage.Blob.download_to_filename"
    with patch(target) as download_mock:
        yield download_mock


@pytest.fixture()
def mock_bucket_not_exist():
with patch("google.cloud.storage.Blob.from_string") as mock_bucket_not_exist, patch(
Expand Down Expand Up @@ -570,6 +601,60 @@ def test_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
output == "gs://test-project-vertex-pipelines-us-central1/output_artifacts/"
)

def test_download_from_gcs_dir(
    self, mock_storage_client_list_blobs, mock_storage_blob_download_to_filename
):
    """Downloading a prefix fetches each file blob into the destination dir.

    The mocked listing has three blobs; the one ending with "/" must not be
    downloaded, so exactly two downloads are expected.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        destination_path = f"{temp_dir}/test-dir"

        gcs_utils.download_from_gcs(
            f"gs://{GCS_BUCKET}/{GCS_PREFIX}", destination_path
        )

        mock_storage_client_list_blobs.assert_called_once_with(
            bucket_or_name=GCS_BUCKET,
            prefix=GCS_PREFIX,
        )

        expected_filenames = (
            f"{destination_path}/{FAKE_FILENAME}-1",
            f"{destination_path}/fake-dir/{FAKE_FILENAME}-2",
        )
        assert mock_storage_blob_download_to_filename.call_count == 2
        for expected in expected_filenames:
            mock_storage_blob_download_to_filename.assert_any_call(filename=expected)

def test_download_from_gcs_file(
    self, mock_storage_client_list_blob, mock_storage_blob_download_to_filename
):
    """Downloading a single-file URI writes straight to the destination path."""
    blob_path = f"{GCS_PREFIX}/{FAKE_FILENAME}"

    with tempfile.TemporaryDirectory() as temp_dir:
        destination_path = f"{temp_dir}/test-file"

        gcs_utils.download_from_gcs(
            f"gs://{GCS_BUCKET}/{GCS_PREFIX}/{FAKE_FILENAME}", destination_path
        )

        # The whole object path is used as the listing prefix.
        mock_storage_client_list_blob.assert_called_once_with(
            bucket_or_name=GCS_BUCKET,
            prefix=blob_path,
        )

        # The blob name equals the prefix, so the file lands at
        # `destination_path` itself rather than inside it.
        mock_storage_blob_download_to_filename.assert_called_once_with(
            filename=destination_path
        )

def test_download_from_gcs_invalid_source_uri(self):
    """A source URI without the `gs://` scheme is rejected with ValueError."""
    with tempfile.TemporaryDirectory() as temp_dir:
        source_uri = f"{GCS_BUCKET}/{GCS_PREFIX}"
        destination_path = f"{temp_dir}/test-dir"

        expected_message = (
            f"Invalid GCS path {source_uri}. "
            "Please provide a valid GCS path starting with 'gs://'"
        )
        # `match` is applied as a regular expression search, so the literal
        # message is escaped; otherwise "." and any special characters in
        # the interpolated URI would be treated as regex syntax.
        with pytest.raises(ValueError, match=re.escape(expected_message)):
            gcs_utils.download_from_gcs(source_uri, destination_path)

def test_validate_gcs_path(self):
test_valid_path = "gs://test_valid_path"
gcs_utils.validate_gcs_path(test_valid_path)
Expand Down

0 comments on commit f3b0e65

Please sign in to comment.