Skip to content

Commit

Permalink
Png compression implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
Edward Grigoryan committed Dec 1, 2020
1 parent 3a53c2d commit 54d096e
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 1 deletion.
28 changes: 28 additions & 0 deletions examples/zarr_compressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from io import BytesIO

import zarr
import numcodecs
from numcodecs.abc import Codec
from numcodecs import MsgPack
import numpy as np
from PIL import Image

from hub.utils import Timer
from hub.numcodecs import PngCodec


def main():
    """Time a write/read round trip of an all-ones uint8 array through PngCodec.

    The array lives in an in-memory zarr store; the shape of the data read
    back is printed at the end.
    """
    numcodecs.register_codec(PngCodec, "png")
    dims = (10, 10, 1920, 1080, 7)
    with Timer("Compress"):
        store = zarr.MemoryStore()
        codec = PngCodec(solo_channel=True)
        arr = zarr.create(
            shape=dims,
            dtype="uint8",
            compressor=codec,
            store=store,
        )
        # Write the whole array in one assignment, then read it all back.
        arr[:] = np.ones(dims, dtype="uint8")
        restored = arr[:]
        print(restored.shape)


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
7 changes: 6 additions & 1 deletion hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from hub.store.metastore import MetaStorage
from hub.client.hub_control import HubControlClient
from hub.features import Audio, BBox, ClassLabel, Image, Sequence, Text, Video
from hub.numcodecs import PngCodec
from collections import defaultdict


Expand Down Expand Up @@ -234,6 +235,8 @@ def _get_compressor(self, compressor: str):
return numcodecs.Zstd(numcodecs.zstd.DEFAULT_CLEVEL)
elif compressor.lower() == "default":
return "default"
elif compressor.lower() == "png":
return PngCodec(solo_channel=True)
else:
raise ValueError(
f"Wrong compressor: {compressor}, only LZ4 and ZSTD are supported"
Expand Down Expand Up @@ -760,7 +763,9 @@ def image_to_hub(tf_dt, max_shape=None):
max_shape = max_shape or tuple(
10000 if dim is None else dim for dim in tf_dt.shape
)
return Image(shape=tf_dt.shape, dtype=dt, max_shape=max_shape)
return Image(
shape=tf_dt.shape, dtype=dt, max_shape=max_shape, compressor="png"
)

def class_label_to_hub(tf_dt, max_shape=None):
if hasattr(tf_dt, "_num_classes"):
Expand Down
87 changes: 87 additions & 0 deletions hub/numcodecs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from io import BytesIO

import zarr
import numcodecs
from numcodecs.abc import Codec
from numcodecs import MsgPack
import numpy as np
from PIL import Image


class PngCodec(Codec):
    """numcodecs codec that stores each image of a chunk as a PNG stream.

    A chunk is split into images along its trailing dimensions, every image
    is PNG-encoded with Pillow, and the PNG blobs plus the metadata needed
    to rebuild the chunk are packed with ``MsgPack``.

    Args:
        solo_channel: When true, a trailing dimension of length 1 is dropped
            before encoding (and restored on decode) and any other input is
            split into 3-D ``image_shape`` slices.  When false, the last two
            dimensions always form one image.
    """

    # Class-level identifier, as numcodecs codecs conventionally declare it.
    # Keeping it off the instance also keeps it out of ``self.__dict__``.
    codec_id = "png"

    def __init__(self, solo_channel=True):
        self.solo_channel = solo_channel
        self._msgpack = MsgPack()

    def encode_single_image(self, image: np.ndarray) -> bytes:
        """Return ``image`` serialized as PNG bytes."""
        with BytesIO() as buffer:
            Image.fromarray(image).save(buffer, format="png")
            return buffer.getvalue()

    def decode_single_image(self, buf) -> np.ndarray:
        """Return the array stored in the PNG bytes ``buf``."""
        with BytesIO(bytearray(buf)) as buffer:
            buffer.seek(0)
            return np.array(Image.open(buffer, mode="r"))

    def encode(self, buf: np.ndarray):
        """Encode an array of one or more images.

        Args:
            buf: Array whose trailing dimensions hold the image(s); any
                leading dimensions enumerate the individual images.

        Returns:
            The MsgPack-encoded payload: a one-element list holding a dict
            with the PNG data under ``"items"`` and the ``"append_one"``
            flag, plus ``"items_shape"``, ``"image_shape"`` and ``"dtype"``
            when ``buf`` contains more than one image.

        Raises:
            ValueError: If ``buf`` has fewer dimensions than one image needs.
        """
        append_one = False
        if self.solo_channel and buf.ndim > 0 and buf.shape[-1] == 1:
            # Drop the singleton channel axis; decode() puts it back.
            buf = np.reshape(buf, buf.shape[:-1])
            append_one = True

        shape_dims = 2 if (not self.solo_channel or append_one) else 3
        if buf.ndim < shape_dims:
            # Explicit error instead of ``assert`` so it survives ``python -O``.
            raise ValueError(
                f"PngCodec expects at least {shape_dims} dimensions, "
                f"got shape {buf.shape}"
            )

        if buf.ndim == shape_dims:
            # Exactly one image: no per-item bookkeeping is stored.
            return self._msgpack.encode(
                [{"items": self.encode_single_image(buf), "append_one": append_one}]
            )

        image_shape = buf.shape[-shape_dims:]
        items_shape = buf.shape[:-shape_dims]
        items = [
            self.encode_single_image(buf[index])
            for index in np.ndindex(items_shape)
        ]
        return self._msgpack.encode(
            [
                {
                    "items": items,
                    "items_shape": items_shape,
                    "image_shape": image_shape,
                    "dtype": str(buf.dtype),
                    "append_one": append_one,
                }
            ]
        )

    def decode(self, buf, out=None):
        """Decode a payload produced by :meth:`encode`.

        Args:
            buf: Encoded payload.
            out: Optional writeable buffer that receives the decoded data.

        Returns:
            The decoded array, or ``out`` (as an ndarray) when it was given.
        """
        data = self._msgpack.decode(buf)[0]
        if "items_shape" not in data:
            images = self.decode_single_image(data["items"])
        else:
            items = data["items"]
            # Shapes are normalized to tuples so concatenation works whether
            # the unpacked values arrive as lists or tuples.
            items_shape = tuple(data["items_shape"])
            image_shape = tuple(data["image_shape"])
            images = np.zeros(items_shape + image_shape, dtype=data["dtype"])
            for i, index in enumerate(np.ndindex(items_shape)):
                images[index] = self.decode_single_image(items[i])

        if data.get("append_one"):
            images = np.reshape(images, images.shape + (1,))

        if out is None:
            return images
        # Callers that pass ``out`` read the result from that buffer, so the
        # decoded data has to be copied into it rather than only returned.
        from numcodecs.compat import ndarray_copy

        return ndarray_copy(images, out)

    def get_config(self):
        """Return the configuration dict that recreates this codec."""
        return {"id": self.codec_id, "solo_channel": self.solo_channel}

    @classmethod
    def from_config(cls, config):
        """Build a codec from a configuration dict.

        ``solo_channel`` falls back to the constructor default when the key
        is absent, and ``cls`` is used so subclasses round-trip to themselves.
        """
        return cls(config.get("solo_channel", True))


# Register PngCodec with numcodecs under the "png" id when this module is imported.
numcodecs.register_codec(PngCodec, "png")
1 change: 1 addition & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ torch>=1,<2
ray>=1.0
transformers>=3.5.1
dask[complete]>=2.30
tensorflow_datasets

0 comments on commit 54d096e

Please sign in to comment.