Skip to content

Commit

Permalink
Add GCS storage provider
Browse files Browse the repository at this point in the history
  • Loading branch information
kristinagrig06 committed Aug 20, 2021
1 parent 67268d0 commit 467ae13
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 5 deletions.
3 changes: 3 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def pytest_addoption(parser):
LOCAL_OPT, action="store_true", help="Local tests will run if enabled."
)
parser.addoption(S3_OPT, action="store_true", help="S3 tests will run if enabled.")
parser.addoption(
GCS_OPT, action="store_true", help="GCS tests will run if enabled."
)
parser.addoption(
HUB_CLOUD_OPT, action="store_true", help="Hub cloud tests will run if enabled."
)
Expand Down
2 changes: 2 additions & 0 deletions hub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
PYTEST_MEMORY_PROVIDER_BASE_ROOT = "mem://hub_pytest"
PYTEST_LOCAL_PROVIDER_BASE_ROOT = "/tmp/hub_pytest/" # TODO: may fail for windows
PYTEST_S3_PROVIDER_BASE_ROOT = "s3://hub-2.0-tests/"
PYTEST_GCS_PROVIDER_BASE_ROOT = "gcs://snark-test/"
PYTEST_HUB_CLOUD_PROVIDER_BASE_ROOT = f"hub://{HUB_CLOUD_DEV_USERNAME}/"

# environment variables
Expand All @@ -99,6 +100,7 @@
MEMORY_OPT = "--memory-skip"
LOCAL_OPT = "--local"
S3_OPT = "--s3"
GCS_OPT = "--gcs"
HUB_CLOUD_OPT = "--hub-cloud"
S3_PATH_OPT = "--s3-path"
KEEP_STORAGE_OPT = "--keep-storage"
Expand Down
76 changes: 76 additions & 0 deletions hub/core/storage/gcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import posixpath
from typing import Optional
from hub.core.storage.provider import StorageProvider
import gcsfs


class GCSProvider(StorageProvider):
    """Provider class for using Google Cloud Storage (GCS)."""

    def __init__(
        self,
        root: str,
        token: Optional[str] = None,
    ):
        """Initializes the GCSProvider.

        Example:
            gcs_provider = GCSProvider("snark-test/benchmarks")

        Args:
            root (str): The root of the provider. All read/write request keys will be appended to root.
            token (str, optional): GCP token, used for fetching credentials for storage.
        """
        self.root = root
        self.token: Optional[str] = token
        # Exceptions raised on a read that are reported to callers as a missing key.
        self.missing_exceptions = (
            FileNotFoundError,
            IsADirectoryError,
            NotADirectoryError,
        )
        self.initialize_provider()

    def initialize_provider(self):
        """Derives ``bucket``/``path`` from ``root`` and creates the gcsfs filesystem."""
        self._set_bucket_and_path()
        self.fs = gcsfs.GCSFileSystem(token=self.token)

    def _set_bucket_and_path(self):
        """Strips the ``gcp://``/``gcs://`` scheme from ``root`` and stores bucket and path.

        ``self.path`` always ends with "/" so keys can be joined onto it.
        """
        root = self.root.replace("gcp://", "").replace("gcs://", "")
        self.bucket = root.split("/")[0]
        self.path = root
        if not self.path.endswith("/"):
            self.path += "/"

    def _relative_keys(self):
        """Lists everything found under ``self.path`` as keys relative to it.

        The scheme-free ``self.path`` is used (not ``self.root``) so the listing
        addresses the same location that reads and writes use, and each result is
        made relative so it can be passed straight back to ``__getitem__``.
        """
        return [
            posixpath.relpath(full_path, self.path)
            for full_path in self.fs.find(self.path)
        ]

    def clear(self):
        """Remove all keys below root - empties out mapping"""
        self.fs.delete(self.path, recursive=True)

    def __getitem__(self, key):
        """Retrieve the bytes stored under ``key``.

        Raises:
            KeyError: If reading fails with one of ``self.missing_exceptions``.
        """
        try:
            with self.fs.open(posixpath.join(self.path, key), "rb") as f:
                return f.read()
        except self.missing_exceptions as err:
            raise KeyError(key) from err

    def __setitem__(self, key, value):
        """Store value in key"""
        with self.fs.open(posixpath.join(self.path, key), "wb") as f:
            f.write(value)

    def __iter__(self):
        """Iterates over the keys stored below root (relative to root)."""
        yield from self._relative_keys()

    def __len__(self):
        """Returns the number of keys stored below root."""
        return len(self._relative_keys())

    def __delitem__(self, key):
        """Remove key"""
        self.fs.rm(posixpath.join(self.path, key))

    def __contains__(self, key):
        """Does key exist in mapping?"""
        path = posixpath.join(self.path, key)
        return self.fs.exists(path)
17 changes: 15 additions & 2 deletions hub/tests/dataset_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

enabled_datasets = pytest.mark.parametrize(
"ds",
["memory_ds", "local_ds", "s3_ds"],
["memory_ds", "local_ds", "s3_ds", "gcs_ds"],
indirect=True,
)

enabled_persistent_dataset_generators = pytest.mark.parametrize(
"ds_generator",
["local_ds_generator", "s3_ds_generator"],
["local_ds_generator", "s3_ds_generator", "gcs_ds_generator"],
indirect=True,
)

Expand Down Expand Up @@ -45,6 +45,19 @@ def generate_s3_ds():
return generate_s3_ds


@pytest.fixture
def gcs_ds(gcs_ds_generator):
    """Fixture providing the dataset produced by one call to ``gcs_ds_generator``."""
    dataset = gcs_ds_generator()
    return dataset


@pytest.fixture
def gcs_ds_generator(gcs_path):
    """Fixture returning a zero-argument callable that calls ``hub.dataset(gcs_path)``."""

    def _open_gcs_dataset():
        # Every call goes through hub.dataset again with the same path.
        dataset = hub.dataset(gcs_path)
        return dataset

    return _open_gcs_dataset


@pytest.fixture
def hub_cloud_ds(hub_cloud_ds_generator):
return hub_cloud_ds_generator()
Expand Down
26 changes: 26 additions & 0 deletions hub/tests/path_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from hub.core.storage.gcs import GCSProvider
from hub.util.storage import storage_provider_from_hub_path
from hub.core.storage.s3 import S3Provider
from hub.core.storage.local import LocalProvider
Expand All @@ -8,10 +9,12 @@
KEEP_STORAGE_OPT,
LOCAL_OPT,
MEMORY_OPT,
PYTEST_GCS_PROVIDER_BASE_ROOT,
PYTEST_HUB_CLOUD_PROVIDER_BASE_ROOT,
PYTEST_LOCAL_PROVIDER_BASE_ROOT,
PYTEST_MEMORY_PROVIDER_BASE_ROOT,
S3_OPT,
GCS_OPT,
)
import posixpath
from hub.tests.common import (
Expand All @@ -25,6 +28,7 @@
MEMORY = "memory"
LOCAL = "local"
S3 = "s3"
GCS = "gcs"
HUB_CLOUD = "hub_cloud"


Expand All @@ -49,6 +53,12 @@ def _get_path_composition_configs(request):
"is_id_prefix": True,
"use_underscores": False,
},
GCS: {
"base_root": PYTEST_GCS_PROVIDER_BASE_ROOT,
"use_id": True,
"is_id_prefix": True,
"use_underscores": False,
},
HUB_CLOUD: {
"base_root": PYTEST_HUB_CLOUD_PROVIDER_BASE_ROOT,
"use_id": True,
Expand Down Expand Up @@ -125,6 +135,22 @@ def s3_path(request):
S3Provider(path).clear()


@pytest.fixture
def gcs_path(request):
    """Yields a GCS storage path for a test, clearing it before and after use.

    The test is skipped unless ``GCS_OPT`` is enabled on ``request``. The storage
    is cleared before the path is yielded, and cleared again afterwards unless
    ``KEEP_STORAGE_OPT`` is enabled.
    """
    if not is_opt_true(request, GCS_OPT):
        # pytest.skip() raises, so nothing below runs when GCS tests are disabled.
        pytest.skip()

    path = _get_storage_path(request, GCS)
    GCSProvider(path).clear()

    yield path

    # clear storage unless flagged otherwise
    if not is_opt_true(request, KEEP_STORAGE_OPT):
        # The path lives on GCS, so it must be cleared through the GCS provider
        # (clearing through the S3 provider would leave the test data behind).
        GCSProvider(path).clear()


@pytest.fixture
def hub_cloud_path(request, hub_cloud_dev_token):
if not is_opt_true(request, HUB_CLOUD_OPT):
Expand Down
10 changes: 8 additions & 2 deletions hub/tests/storage_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from hub.core.storage.gcs import GCSProvider
from hub.util.storage import storage_provider_from_hub_path
from hub.core.storage.s3 import S3Provider
from hub.core.storage.local import LocalProvider
Expand All @@ -7,13 +8,13 @@

enabled_storages = pytest.mark.parametrize(
"storage",
["memory_storage", "local_storage", "s3_storage"],
["memory_storage", "local_storage", "s3_storage", "gcs_storage"],
indirect=True,
)

enabled_persistent_storages = pytest.mark.parametrize(
"storage",
["local_storage", "s3_storage"],
["local_storage", "s3_storage", "gcs_storage"],
indirect=True,
)

Expand All @@ -33,6 +34,11 @@ def s3_storage(s3_path):
return S3Provider(s3_path)


@pytest.fixture
def gcs_storage(gcs_path):
    """Fixture providing a ``GCSProvider`` rooted at ``gcs_path``."""
    provider = GCSProvider(gcs_path)
    return provider


@pytest.fixture
def hub_cloud_storage(hub_cloud_path, hub_cloud_dev_token):
return storage_provider_from_hub_path(hub_cloud_path, token=hub_cloud_dev_token)
Expand Down
2 changes: 1 addition & 1 deletion hub/util/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def dataset_exists(storage: StorageProvider) -> bool:
try:
storage[get_dataset_meta_key()]
return True
except KeyError:
except (KeyError, FileNotFoundError):
return False


Expand Down
4 changes: 4 additions & 0 deletions hub/util/storage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from hub.core.storage.gcs import GCSProvider
from hub.util.cache_chain import generate_chain
from hub.constants import MB
from hub.util.tag import check_hub_path
Expand Down Expand Up @@ -25,6 +26,7 @@ def storage_provider_from_path(
Returns:
If given a path starting with s3:// returns the S3Provider.
If given a path starting with gcp:// or gcs:// returns the GCSProvider.
If given a path starting with mem:// returns the MemoryProvider.
If given a path starting with hub:// returns the underlying cloud Provider.
If given a valid local path, returns the LocalProvider.
Expand All @@ -43,6 +45,8 @@ def storage_provider_from_path(
storage: StorageProvider = S3Provider(
path, key, secret, session_token, endpoint_url, region, token=token
)
elif path.startswith("gcp://") or path.startswith("gcs://"):
return GCSProvider(path, creds)
elif path.startswith("mem://"):
storage = MemoryProvider(path)
elif path.startswith("hub://"):
Expand Down

0 comments on commit 467ae13

Please sign in to comment.