Skip to content

Commit

Permalink
Merge branch 'fy_tfm_cp' into DeepLakeVectorStore
Browse files Browse the repository at this point in the history
  • Loading branch information
adolkhan committed May 12, 2023
2 parents c9d1189 + 19cbd3f commit 5524f7f
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 6 deletions.
12 changes: 11 additions & 1 deletion deeplake/core/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,9 @@ def _commit(
message: Optional[str] = None,
hash: Optional[str] = None,
flush_version_control_info: bool = True,
*,
is_checkpoint: bool = False,
total_samples_processed: int = 0,
) -> str:
if self._is_filtered_view:
raise Exception(
Expand All @@ -1536,7 +1539,14 @@ def _commit(
self.storage.autoflush = False
try:
self._unlock()
commit(self, message, hash, flush_version_control_info)
commit(
self,
message,
hash,
flush_version_control_info,
is_checkpoint=is_checkpoint,
total_samples_processed=total_samples_processed,
)
if not flush_version_control_info:
self.__dict__["_vc_info_updated"] = True
self._lock()
Expand Down
26 changes: 26 additions & 0 deletions deeplake/core/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,32 @@ def double(data_in, ds):
assert ds.abc.numpy(aslist=True) == data_in


def test_transform_checkpoint_store_data(local_ds_generator):
@deeplake.compute
def upload(i, ds):
ds.abc.append(i)

samples = list(range(100))
samples.insert(50, "bad sample")

with pytest.raises(TransformError):
with local_ds_generator() as ds:
ds.create_tensor("abc")
upload().eval(
samples,
ds,
num_workers=TRANSFORM_TEST_NUM_WORKERS,
checkpoint_interval=20,
)

ds = local_ds_generator()

assert len(ds.abc) == 40
last_checkpoint = ds.version_state["commit_node"].parent
assert last_checkpoint.is_checkpoint == True
assert last_checkpoint.total_samples_processed == 40


def create_test_ds(path):
ds = deeplake.empty(path, overwrite=True)
ds.create_tensor("images", htype="image", sample_compression="jpg")
Expand Down
8 changes: 6 additions & 2 deletions deeplake/core/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,12 @@ def my_fn(sample_in: Any, samples_out, my_arg0, my_arg1=0):
progress = 0.0
for data_in in datas_in:
if checkpointing_enabled and progress > 0:
target_ds.commit(
f"Auto-commit during deeplake.compute of {desc} after {progress}% progress"
target_ds._commit(
f"Auto-commit during deeplake.compute of {desc} after {progress}% progress",
None,
False,
is_checkpoint=True,
total_samples_processed=samples_processed,
)
progress = round(
(samples_processed + len(data_in)) / total_samples * 100, 2
Expand Down
13 changes: 12 additions & 1 deletion deeplake/core/version_control/commit_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def __init__(self, branch: str, commit_id: str):
self.commit_user_name: Optional[str] = None
self.merge_parent: Optional["CommitNode"] = None
self._info_updated: bool = False
self.is_checkpoint: bool = False
self.total_samples_processed: int = 0

def add_child(self, node: "CommitNode"):
"""Adds a child to the node, used for branching."""
Expand All @@ -27,6 +29,8 @@ def copy(self):
node.commit_message = self.commit_message
node.commit_user_name = self.commit_user_name
node.commit_time = self.commit_time
node.is_checkpoint = self.is_checkpoint
node.total_samples_processed = self.total_samples_processed
return node

def add_successor(self, node: "CommitNode", message: Optional[str] = None):
Expand All @@ -47,7 +51,12 @@ def is_merge_node(self):
return self.merge_parent is not None

def __repr__(self) -> str:
return f"Commit : {self.commit_id} ({self.branch}) \nAuthor : {self.commit_user_name}\nTime : {str(self.commit_time)[:-7]}\nMessage: {self.commit_message}"
return (
f"Commit : {self.commit_id} ({self.branch}) \nAuthor : {self.commit_user_name}\nTime : {str(self.commit_time)[:-7]}\nMessage: {self.commit_message}"
+ f"\nTotal samples processed in transform: {self.total_samples_processed}"
if self.is_checkpoint
else ""
)

@property
def is_head_node(self) -> bool:
Expand All @@ -64,4 +73,6 @@ def to_json(self):
"commit_message": self.commit_message,
"commit_time": self.commit_time.timestamp() if self.commit_time else None,
"commit_user_name": self.commit_user_name,
"is_checkpoint": self.is_checkpoint,
"total_samples_processed": self.total_samples_processed,
}
18 changes: 16 additions & 2 deletions deeplake/util/version_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def _version_info_to_json(info):
"commit_message": node.commit_message,
"commit_time": node.commit_time.timestamp() if node.commit_time else None,
"commit_user_name": node.commit_user_name,
"is_checkpoint": node.is_checkpoint,
"total_samples_processed": node.total_samples_processed,
}
return {
"commits": commits,
Expand All @@ -82,6 +84,8 @@ def _version_info_from_json(info):
None if commit_time is None else datetime.fromtimestamp(commit_time)
)
node.commit_user_name = commit_data["commit_user_name"]
node.is_checkpoint = commit_data.get("is_checkpoint", False)
node.total_samples_processed = commit_data.get("total_samples_processed", 0)
parent = commit_data["parent"]
if parent:
commit_node_map[parent].add_child(node)
Expand Down Expand Up @@ -140,6 +144,8 @@ def commit(
hash: Optional[str] = None,
flush_version_control_info: bool = True,
reload_meta: bool = True,
is_checkpoint: bool = False,
total_samples_processed: int = 0,
) -> None:
"""Modifies the version state to reflect the commit and also copies required data to the new commit directory."""
storage = dataset.storage
Expand All @@ -157,7 +163,9 @@ def commit(
hash = generate_hash()
version_state["commit_id"] = hash
new_node = CommitNode(version_state["branch"], hash)
version_state["commit_node"].add_successor(new_node, message)
stored_commit_node.add_successor(new_node, message)
stored_commit_node.is_checkpoint = is_checkpoint
stored_commit_node.total_samples_processed = total_samples_processed
version_state["commit_node"] = new_node
version_state["branch_commit_map"][version_state["branch"]] = version_state[
"commit_id"
Expand Down Expand Up @@ -443,7 +451,13 @@ def _merge_node(commit_id):
node2 = map2[commit_id]
merged_node = CommitNode(node1.branch, node2.commit_id)

for attr in ("commit_message", "commit_user_name", "commit_time"):
for attr in (
"commit_message",
"commit_user_name",
"commit_time",
"is_checkpoint",
"total_samples_processed",
):
setattr(merged_node, attr, getattr(node1, attr) or getattr(node2, attr))
for child in set(
[node.commit_id for node in node1.children]
Expand Down

0 comments on commit 5524f7f

Please sign in to comment.