simplify Feature implementation #5539

Merged · 2 commits · Mar 3, 2022 (changes shown from 1 commit)
torchvision/_utils.py (5 changes: 4 additions & 1 deletion)

@@ -1,10 +1,13 @@
 import enum
+from typing import TypeVar, Type
 
+T = TypeVar("T")
+
 
 class StrEnumMeta(enum.EnumMeta):
     auto = enum.auto
 
-    def from_str(self, member: str):
+    def from_str(self: Type[T], member: str) -> T:
         try:
             return self[member]
         except KeyError:
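The annotated from_str lets static checkers infer the concrete enum subclass instead of Any. A minimal self-contained sketch of the pattern (the enum and the error message below are stand-ins, not the actual torchvision definitions):

    import enum
    from typing import TypeVar, Type

    T = TypeVar("T")

    class StrEnumMeta(enum.EnumMeta):
        # Mirrors the metaclass in the diff above: annotating `self` as Type[T]
        # makes from_str return the calling enum class.
        def from_str(self: Type[T], member: str) -> T:
            try:
                return self[member]
            except KeyError:
                raise ValueError(f"Unknown member {member!r}")  # stand-in for the real error

    class BoundingBoxFormat(enum.Enum, metaclass=StrEnumMeta):  # stand-in enum
        XYXY = enum.auto()
        XYWH = enum.auto()

    fmt = BoundingBoxFormat.from_str("XYXY")  # type checkers now infer BoundingBoxFormat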
torchvision/prototype/features/_bounding_box.py (28 changes: 24 additions & 4 deletions)

@@ -22,20 +22,40 @@ def __new__(
         cls,
         data: Any,
         *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
         format: Union[BoundingBoxFormat, str],
         image_size: Tuple[int, int],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> BoundingBox:
-        bounding_box = super().__new__(cls, data, dtype=dtype, device=device)
+        bounding_box = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
 
         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())
+        bounding_box.format = format
 
-        bounding_box._metadata.update(dict(format=format, image_size=image_size))
+        bounding_box.image_size = image_size
 
         return bounding_box
 
+    @classmethod
+    def new_like(
+        cls,
+        other: BoundingBox,
+        data: Any,
+        *,
+        format: Optional[Union[BoundingBoxFormat, str]] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        **kwargs: Any,
+    ) -> BoundingBox:
+        return super().new_like(
+            other,
+            data,
+            format=format if format is not None else other.format,
+            image_size=image_size if image_size is not None else other.image_size,
+            **kwargs,
+        )
+
     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
         # promote this out of the prototype state
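In use, metadata now lives in plain attributes and new_like copies it from a template instance. A rough sketch, assuming the prototype namespace of this era (torchvision.prototype.features):

    import torch
    from torchvision.prototype import features

    bbox = features.BoundingBox(
        [10, 20, 50, 80],
        format="xyxy",            # strings are upper-cased and parsed via BoundingBoxFormat.from_str
        image_size=(100, 100),
    )
    moved = features.BoundingBox.new_like(bbox, bbox + 5)  # format and image_size are copied from bbox
    assert moved.format is bbox.format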
torchvision/prototype/features/_encoded.py (22 changes: 15 additions & 7 deletions)

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import sys
 from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any
@@ -13,19 +15,25 @@
 
 
 class EncodedData(_Feature):
-    @classmethod
-    def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
+    ) -> EncodedData:
         # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
-        return super()._to_tensor(data, dtype=dtype, device=device)
+        return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
 
     @classmethod
-    def from_file(cls: Type[D], file: BinaryIO) -> D:
-        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder))
+    def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
+        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
 
     @classmethod
-    def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D:
+    def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
         with open(path, "rb") as file:
-            return cls.from_file(file)
+            return cls.from_file(file, **kwargs)
 
 
 class EncodedImage(EncodedData):
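With _to_tensor gone, EncodedData simply stores the raw bytes as a flat uint8 tensor (validation is still a TODO), and the new **kwargs thread dtype/device/requires_grad through the alternative constructors. A rough usage sketch; the byte payload is a stand-in, and whether fromfile accepts an arbitrary BinaryIO such as BytesIO is an assumption here:

    import io
    import torch
    from torchvision.prototype import features

    payload = io.BytesIO(b"\x89PNG\r\n\x1a\n")                    # stand-in bytes, not a full image
    data = features.EncodedData.from_file(payload, device="cpu")  # kwargs are now forwarded to __new__
    assert data.dtype == torch.uint8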
torchvision/prototype/features/_feature.py (64 changes: 16 additions & 48 deletions)

@@ -1,4 +1,4 @@
-from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
+from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
 
 import torch
 from torch._C import _TensorBase, DisableTorchFunction
@@ -8,59 +8,22 @@
 
 
 class _Feature(torch.Tensor):
-    _META_ATTRS: Set[str] = set()
-    _metadata: Dict[str, Any]
-
-    def __init_subclass__(cls) -> None:
-        """
-        For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes.
-        By adding the metadata attributes as class annotations on subclasses of :class:`Feature`, this method adds
-        properties to have the same convenient access as regular attributes.
-
-        >>> class Foo(_Feature):
-        ...     bar: str
-        ...     baz: Optional[str]
-        >>> foo = Foo()
-        >>> foo.bar
-        >>> foo.baz
-
-        This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata.
-        """
-        meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
-        for super_cls in cls.__mro__[1:]:
-            if super_cls is _Feature:
-                break
-
-            meta_attrs.update(cast(Type[_Feature], super_cls)._META_ATTRS)
-
-        cls._META_ATTRS = meta_attrs
-        for name in meta_attrs:
-            setattr(cls, name, property(cast(Callable[[F], Any], lambda self, name=name: self._metadata[name])))
-
     def __new__(
         cls: Type[F],
         data: Any,
         *,
         dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str]] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> F:
-        if isinstance(device, str):
-            device = torch.device(device)
-        feature = cast(
+        return cast(
             F,
             torch.Tensor._make_subclass(
                 cast(_TensorBase, cls),
-                cls._to_tensor(data, dtype=dtype, device=device),
-                # requires_grad
-                False,
+                torch.as_tensor(data, dtype=dtype, device=device),  # type: ignore[arg-type]
                 requires_grad,
             ),
         )
-        feature._metadata = dict()
-        return feature
-
-    @classmethod
-    def _to_tensor(self, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
-        return torch.as_tensor(data, dtype=dtype, device=device)

[Inline comment by the author on the torch.as_tensor line:] As discussed offline, the annotations of as_tensor are wrong. str and int are valid types for the device.

@@ -69,12 +32,17 @@ def new_like(
         data: Any,
         *,
         dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str]] = None,
-        **metadata: Any,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: Optional[bool] = None,
+        **kwargs: Any,
     ) -> F:
-        _metadata = other._metadata.copy()
-        _metadata.update(metadata)
-        return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata)
+        return cls(
+            data,
+            dtype=dtype if dtype is not None else other.dtype,
+            device=device if device is not None else other.device,
+            requires_grad=requires_grad if requires_grad is not None else other.requires_grad,
+            **kwargs,
+        )
 
     @classmethod
     def __torch_function__(
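The net effect: the _metadata dict and the __init_subclass__ property machinery are gone; each subclass keeps metadata in ordinary attributes and overrides new_like to default them from a template instance. A simplified standalone sketch of that pattern (MiniFeature is illustrative only and omits __torch_function__ and everything else the real _Feature does):

    from typing import Any, Optional, Type, TypeVar, Union, cast

    import torch

    F = TypeVar("F", bound="MiniFeature")

    class MiniFeature(torch.Tensor):
        def __new__(
            cls: Type[F],
            data: Any,
            *,
            dtype: Optional[torch.dtype] = None,
            device: Optional[Union[torch.device, str, int]] = None,
            requires_grad: bool = False,
        ) -> F:
            # str and int devices are accepted by as_tensor at runtime,
            # matching the review comment above.
            tensor = torch.as_tensor(data, dtype=dtype, device=device)
            return cast(F, torch.Tensor._make_subclass(cast(Any, cls), tensor, requires_grad))

        @classmethod
        def new_like(cls: Type[F], other: F, data: Any, **kwargs: Any) -> F:
            # dtype/device/requires_grad default to the values of `other`
            kwargs.setdefault("dtype", other.dtype)
            kwargs.setdefault("device", other.device)
            kwargs.setdefault("requires_grad", other.requires_grad)
            return cls(data, **kwargs)

    x = MiniFeature([1.0, 2.0], device="cpu")   # str device works at runtime
    y = MiniFeature.new_like(x, [3.0, 4.0])     # inherits dtype/device/requires_grad from x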
torchvision/prototype/features/_image.py (30 changes: 18 additions & 12 deletions)

@@ -26,31 +26,37 @@ def __new__(
         cls,
         data: Any,
         *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
         color_space: Optional[Union[ColorSpace, str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> Image:
-        image = super().__new__(cls, data, dtype=dtype, device=device)
+        data = torch.as_tensor(data, dtype=dtype, device=device)  # type: ignore[arg-type]
+        if data.ndim < 2:
+            raise ValueError
+        elif data.ndim == 2:
+            data = data.unsqueeze(0)
+        image = super().__new__(cls, data, requires_grad=requires_grad)
 
         if color_space is None:
             color_space = cls.guess_color_space(image)
             if color_space == ColorSpace.OTHER:
                 warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
         elif isinstance(color_space, str):
             color_space = ColorSpace.from_str(color_space.upper())
-
-        image._metadata.update(dict(color_space=color_space))
+        elif not isinstance(color_space, ColorSpace):
+            raise ValueError
+        image.color_space = color_space
 
         return image
 
     @classmethod
-    def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
-        tensor = super()._to_tensor(data, dtype=dtype, device=device)
-        if tensor.ndim < 2:
-            raise ValueError
-        elif tensor.ndim == 2:
-            tensor = tensor.unsqueeze(0)
-        return tensor
+    def new_like(
+        cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
+    ) -> Image:
+        return super().new_like(
+            other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
+        )
 
     @property
     def image_size(self) -> Tuple[int, int]:
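A quick sketch of the shape handling now done in __new__, under the same prototype-namespace assumption: 2D input is promoted to a single-channel CHW image, and anything below 2D is rejected.

    import torch
    from torchvision.prototype import features

    img = features.Image(torch.rand(16, 16))  # HW input becomes 1xHxW
    assert img.shape == (1, 16, 16)
    rgb = features.Image(torch.rand(3, 16, 16), color_space="rgb")  # parsed via ColorSpace.from_str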
torchvision/prototype/features/_label.py (46 changes: 33 additions & 13 deletions)

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Optional, Sequence, cast
+from typing import Any, Optional, Sequence, cast, Union
 
 import torch
 from torchvision.prototype.utils._internal import apply_recursively
@@ -15,20 +15,32 @@ def __new__(
         cls,
         data: Any,
         *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-        like: Optional[Label] = None,
         categories: Optional[Sequence[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> Label:
-        label = super().__new__(cls, data, dtype=dtype, device=device)
+        label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
 
-        label._metadata.update(dict(categories=categories))
+        label.categories = categories
 
         return label
 
     @classmethod
-    def from_category(cls, category: str, *, categories: Sequence[str]) -> Label:
-        return cls(categories.index(category), categories=categories)
+    def new_like(cls, other: Label, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> Label:
+        return super().new_like(
+            other, data, categories=categories if categories is not None else other.categories, **kwargs
+        )
+
+    @classmethod
+    def from_category(
+        cls,
+        category: str,
+        *,
+        categories: Sequence[str],
+        **kwargs: Any,
+    ) -> Label:
+        return cls(categories.index(category), categories=categories, **kwargs)
 
     def to_categories(self) -> Any:
         if not self.categories:
@@ -44,16 +56,24 @@ def __new__(
         cls,
         data: Any,
         *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-        like: Optional[Label] = None,
         categories: Optional[Sequence[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
    ) -> OneHotLabel:
-        one_hot_label = super().__new__(cls, data, dtype=dtype, device=device)
+        one_hot_label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
 
         if categories is not None and len(categories) != one_hot_label.shape[-1]:
             raise ValueError()
 
-        one_hot_label._metadata.update(dict(categories=categories))
+        one_hot_label.categories = categories
 
         return one_hot_label
+
+    @classmethod
+    def new_like(
+        cls, other: OneHotLabel, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> OneHotLabel:
+        return super().new_like(
+            other, data, categories=categories if categories is not None else other.categories, **kwargs
+        )
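For Label, categories is now a plain attribute propagated by new_like; a rough usage sketch under the same prototype-namespace assumption:

    import torch
    from torchvision.prototype import features

    label = features.Label.from_category("dog", categories=["cat", "dog"])  # tensor(1) with categories attached
    relabeled = features.Label.new_like(label, label + 0)  # categories are copied from `label`
    assert relabeled.categories == ["cat", "dog"]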
torchvision/prototype/transforms/_geometry.py (2 changes: 1 addition & 1 deletion)

@@ -46,7 +46,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return features.SegmentationMask.new_like(input, output)
         elif isinstance(input, features.BoundingBox):
             output = F.resize_bounding_box(input, self.size, image_size=input.image_size)
-            return features.BoundingBox.new_like(input, output, image_size=self.size)
+            return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
         elif isinstance(input, PIL.Image.Image):
             return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
         elif isinstance(input, torch.Tensor):
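The cast is needed because the transform's size attribute is a sequence (presumably a List[int]) while image_size is annotated as Tuple[int, int]; tuple() does the runtime conversion and cast() only narrows the type. A tiny sketch with assumed names:

    from typing import List, Tuple, cast

    size: List[int] = [224, 224]                     # how the transform stores its target size
    image_size = cast(Tuple[int, int], tuple(size))  # runtime tuple(); cast() only informs the checker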