Skip to content

Commit

Permalink
json tests
Browse files Browse the repository at this point in the history
  • Loading branch information
farizrahman4u committed Oct 13, 2021
1 parent 23ea203 commit 5f4fd3f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 3 deletions.
41 changes: 41 additions & 0 deletions hub/api/tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hub
import pytest
from hub.util.json import JsonValidationError
from hub.tests.dataset_fixtures import enabled_non_gcs_datasets
from typing import Any, Optional, Union, List, Dict


Expand Down Expand Up @@ -150,3 +151,43 @@ def test_json_with_schema(memory_ds):
ds.json2.extend(items)
for i in range(len(items)):
assert ds.json2[i].data() == ds.json2.data()[i] == items[i]


@enabled_non_gcs_datasets
def test_json_transform(ds, scheduler="threaded"):
ds.create_tensor("json", htype="json")

items = [
{"x": [1, 2, 3], "y": [4, [5, 6]]},
{"x": [1, 2, 3], "y": [4, {"z": [0.1, 0.2, []]}]},
["a", ["b", "c"], {"d": 1.0}],
[1.0, 2.0, 3.0, 4.0],
["a", "b", "c", "d"],
] * 5

@hub.compute
def upload(stuff, ds):
ds.json.append(stuff)
return ds

upload().eval(items, ds, num_workers=2, scheduler=scheduler)
assert ds.json.data() == items


@enabled_non_gcs_datasets
def test_list_transform(ds, scheduler="threaded"):
ds.create_tensor("list", htype="list")

items = [
["a", ["b", "c"], {"d": 1.0}],
[1.0, 2.0, 3.0, 4.0],
["a", "b", "c", "d"],
] * 5

@hub.compute
def upload(stuff, ds):
ds.list.append(stuff)
return ds

upload().eval(items, ds, num_workers=2, scheduler=scheduler)
assert ds.list.data() == items
8 changes: 5 additions & 3 deletions hub/api/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ def test_text(memory_ds):


@enabled_non_gcs_datasets
def test_text(ds, scheduler="threaded"):
def test_text_transform(ds, scheduler="threaded"):
ds.create_tensor("text", htype="text")

@hub.compute
def upload(some_str, ds):
ds.text.append(some_str)
return ds

upload().eval(["hi", "if ur reading this ur a nerd"], ds, num_workers=2, scheduler=scheduler)
upload().eval(
["hi", "if ur reading this ur a nerd"], ds, num_workers=2, scheduler=scheduler
)

assert len(ds) == 2
assert ds.text.numpy() == ["hi", "if ur reading this ur a nerd"]
assert ds.text.data() == ["hi", "if ur reading this ur a nerd"]
15 changes: 15 additions & 0 deletions hub/core/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,28 @@ def _serialize_input_sample(
"""Converts the incoming sample into a buffer with the proper dtype and compression."""

if htype in ("json", "list"):
if isinstance(sample, np.ndarray):
if htype == "list":
if sample.dtype == object:
sample = list(sample)
else:
sample = sample.tolist()
elif htype == "json":
if sample.ndim == 0:
sample = sample.tolist() # actually returns dict
elif sample.dtype == object:
sample = list(sample)
else:
sample = sample.tolist()
validate_json_object(sample, expected_dtype)
byts = json.dumps(sample, cls=HubJsonEncoder).encode()
if sample_compression:
byts = compress_bytes(byts, compression=sample_compression)
shape = (len(sample),) if htype == "list" else (1,)
return byts, shape
elif htype == "text":
if isinstance(sample, np.ndarray):
sample = sample.tolist()
if not isinstance(sample, str):
raise TypeError("Expected str, received: " + str(sample))
byts = sample.encode()
Expand Down

0 comments on commit 5f4fd3f

Please sign in to comment.