Skip to content

Commit

Permalink
add --kaggle functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
thisiseshan committed Aug 9, 2021
1 parent e39c663 commit 9ba1a5e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 4 deletions.
3 changes: 3 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def pytest_addoption(parser):
help="All storage providers/datasets will have their pytest data wiped. \
Use this option to keep the data after the test run. Note: does not keep memory tests storage.",
)
parser.addoption(
KAGGLE_OPT, action="store_true", help="Kaggle tests will run if enabled."
)


def print_session_id():
Expand Down
18 changes: 14 additions & 4 deletions hub/auto/tests/_test_kaggle.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,44 @@
from hub.auto.unstructured import kaggle
from hub.api.dataset import Dataset
from hub.util.exceptions import (
KaggleDatasetAlreadyDownloadedError,
SamePathException,
KaggleMissingCredentialsError,
ExternalCommandError,
)

from hub.tests.common import get_dummy_data_path
import pytest
import os
import hub


def test_ingestion_simple(local_ds: Dataset):
def test_ingestion_simple(local_ds: Dataset, hub_kaggle_credentials):
    """Ingest a Kaggle dataset into ``local_ds`` and check the tensors it produces.

    The raw Kaggle download is placed in a sub-directory of the dataset path,
    and the credentials come from the ``hub_kaggle_credentials`` fixture.
    """
    download_dir = os.path.join(local_ds.path, "unstructured_kaggle_data_simple")
    kaggle_username, kaggle_key = hub_kaggle_credentials

    ingest_kwargs = dict(
        tag="andradaolteanu/birdcall-recognition-data",
        src=download_dir,
        dest=local_ds.path,
        images_compression="jpeg",
        kaggle_credentials={"username": kaggle_username, "key": kaggle_key},
        overwrite=False,
    )
    ds = hub.ingest_kaggle(**ingest_kwargs)

    # Ingestion is expected to yield exactly these two tensors, in this order.
    tensor_names = list(ds.tensors.keys())
    assert tensor_names == ["images", "labels"]

    label_shape = ds["labels"].numpy().shape
    assert label_shape == (10, 1)


def test_ingestion_sets(local_ds: Dataset):
def test_ingestion_sets(local_ds: Dataset, hub_kaggle_credentials):
kaggle_path = os.path.join(local_ds.path, "unstructured_kaggle_data_sets")
username, key = hub_kaggle_credentials

ds = hub.ingest_kaggle(
tag="thisiseshan/bird-classes",
src=kaggle_path,
dest=local_ds.path,
images_compression="jpeg",
kaggle_credentials={"username": username, "key": key},
overwrite=False,
)

Expand All @@ -52,16 +57,18 @@ def test_ingestion_sets(local_ds: Dataset):
assert ds["train/labels"].info.class_names == ("class0", "class1", "class2")


def test_kaggle_exception(local_ds: Dataset):
def test_kaggle_exception(local_ds: Dataset, hub_kaggle_credentials):
kaggle_path = os.path.join(local_ds.path, "unstructured_kaggle_data")
dummy_path = get_dummy_data_path("tests_auto/image_classification")
username, key = hub_kaggle_credentials

with pytest.raises(SamePathException):
hub.ingest_kaggle(
tag="thisiseshan/bird-classes",
src=dummy_path,
dest=dummy_path,
images_compression="jpeg",
kaggle_credentials={"username": username, "key": key},
overwrite=False,
)

Expand Down Expand Up @@ -91,6 +98,7 @@ def test_kaggle_exception(local_ds: Dataset):
src=kaggle_path,
dest=local_ds.path,
images_compression="jpeg",
kaggle_credentials={"username": username, "key": key},
overwrite=False,
)

Expand All @@ -99,6 +107,7 @@ def test_kaggle_exception(local_ds: Dataset):
src=kaggle_path,
dest=local_ds.path,
images_compression="jpeg",
kaggle_credentials={"username": username, "key": key},
overwrite=False,
)

Expand All @@ -108,5 +117,6 @@ def test_kaggle_exception(local_ds: Dataset):
src=kaggle_path,
dest=local_ds.path,
images_compression="jpeg",
kaggle_credentials={"username": username, "key": key},
overwrite=False,
)
1 change: 1 addition & 0 deletions hub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,4 @@
HUB_CLOUD_OPT = "--hub-cloud"
S3_PATH_OPT = "--s3-path"
KEEP_STORAGE_OPT = "--keep-storage"
KAGGLE_OPT = "--kaggle"
18 changes: 18 additions & 0 deletions hub/tests/client_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
HUB_CLOUD_DEV_USERNAME,
HUB_CLOUD_OPT,
ENV_HUB_DEV_PASSWORD,
ENV_KAGGLE_USERNAME,
ENV_KAGGLE_KEY,
KAGGLE_OPT,
)
from hub.tests.common import is_opt_true
import os
Expand Down Expand Up @@ -32,3 +35,18 @@ def hub_cloud_dev_token(hub_cloud_dev_credentials):
client = HubBackendClient()
token = client.request_auth_token(username, password)
return token


@pytest.fixture(scope="session")
def hub_kaggle_credentials(request):
    """Provide Kaggle API credentials to tests that request this fixture.

    The requesting test is skipped unless ``is_opt_true(request, KAGGLE_OPT)``
    holds. Otherwise the username and key are read from the environment
    variables named by ``ENV_KAGGLE_USERNAME`` and ``ENV_KAGGLE_KEY``.

    Returns:
        tuple: ``(username, key)`` as read from the environment.

    Raises:
        AssertionError: If either environment variable is not set.
    """
    if not is_opt_true(request, KAGGLE_OPT):
        pytest.skip(f"Kaggle tests only run when {KAGGLE_OPT} is enabled.")

    username = os.getenv(ENV_KAGGLE_USERNAME)
    key = os.getenv(ENV_KAGGLE_KEY)

    # Validate both values: a missing username would otherwise only surface
    # later, inside the ingestion call, with a less helpful error.
    missing = [
        env_name
        for env_name, value in ((ENV_KAGGLE_USERNAME, username), (ENV_KAGGLE_KEY, key))
        if value is None
    ]
    assert not missing, (
        "Kaggle credentials were not found in environment variable(s): "
        f"{', '.join(missing)}. "
        "This is necessary for testing kaggle ingestion datasets."
    )

    return username, key

0 comments on commit 9ba1a5e

Please sign in to comment.