From 059500f84b49566fd9b55907fc3a456554e1b32b Mon Sep 17 00:00:00 2001 From: Abhinav Tuli <42538472+AbhinavTuli@users.noreply.github.com> Date: Wed, 31 Mar 2021 10:52:13 +0530 Subject: [PATCH] fixed an issue when overwriting transforms (#724) --- hub/api/dataset.py | 4 +++- hub/compute/tests/test_transform.py | 26 +++++++++++++++++++++++++- hub/store/metastore.py | 10 +++++++--- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/hub/api/dataset.py b/hub/api/dataset.py index 1d72547de8..d8743e3f65 100644 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -237,8 +237,10 @@ def __init__( self._version_node = VersionNode(self._commit_id, self._branch) self._branch_node_map = {self._branch: self._version_node} self._commit_node_map = {self._commit_id: self._version_node} + self._chunk_commit_map = { + path: defaultdict(set) for schema, path in self._flat_tensors + } self._tensors = dict(self._generate_storage_tensors()) - self._chunk_commit_map = {key: defaultdict(set) for key in self.keys} except Exception as e: try: self.close() diff --git a/hub/compute/tests/test_transform.py b/hub/compute/tests/test_transform.py index 21bb4ab7b8..bab6a2a199 100644 --- a/hub/compute/tests/test_transform.py +++ b/hub/compute/tests/test_transform.py @@ -6,7 +6,8 @@ import numpy as np import zarr - +import os +from hub.cli.auth import login_fn import hub from hub.schema import Tensor, Image, Text from hub.utils import Timer @@ -364,6 +365,29 @@ def my_transform(sample): out_ds.store(f"./data/test/test_pipeline_basic_output_{name}") +def test_transform_overwrite(): + password = os.getenv("ACTIVELOOP_HUB_PASSWORD") + login_fn("testingacc", password) + + schema = { + "image": hub.schema.Image( + shape=(None, None, 1), dtype="uint8", max_shape=(1, 1, 1) + ), + } + + @hub.transform(schema=schema) + def create_image(value): + return {"image": np.ones((1, 1, 1), dtype="uint8") * value} + + ds1 = create_image(range(5)) + ds = ds1.store("testingacc/ds_transform", public=False) + for i in range(5): + assert (ds["image", i].compute() == i * np.ones((1, 1, 1))).all() + ds = ds1.store("testingacc/ds_transform", public=False) + for i in range(5): + assert (ds["image", i].compute() == i * np.ones((1, 1, 1))).all() + + if __name__ == "__main__": with Timer("Test Transform"): with Timer("test threaded"): diff --git a/hub/store/metastore.py b/hub/store/metastore.py index 5533aab746..24ff8211de 100644 --- a/hub/store/metastore.py +++ b/hub/store/metastore.py @@ -3,7 +3,6 @@ This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ -from collections import defaultdict import json from collections.abc import MutableMapping import posixpath @@ -125,8 +124,13 @@ def __delitem__(self, k: str): if self._ds._commit_id: k = self.find_chunk(k) or f"{k}:{self._ds._commit_id}" commit_id = k.split(":")[-1] - self._ds._chunk_commit_map[self._path][chunk_key].remove(commit_id) - del self._fs_map[k] + try: + self._ds._chunk_commit_map[self._path][chunk_key].remove(commit_id) + except Exception: + try: + del self._fs_map[k] + except Exception: + pass def flush(self): self._meta.flush()