
[datasets] Targets are modified in place #840

Closed
@max-kalganov

Description

Bug description

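Targets change when a dataset is iterated over more than once.
The targets are stored in self.data, and __getitem__ modifies them in place through self._pre_transforms (and possibly other transforms). Because _read_sample returns the stored target object rather than a copy, each pass applies the transformation again to targets that were already transformed.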

# _AbstractDataset
def __getitem__(self, index: int) -> Tuple[Any, Any]:
    # Read image; `target` is the object stored in self.data, not a copy
    img, target = self._read_sample(index)
    # Pre-transforms (format conversion at run-time etc.)
    # An in-place pre-transform therefore modifies the stored target
    if self._pre_transforms is not None:
        img, target = self._pre_transforms(img, target)

    if self.img_transforms is not None:
        # typing issue cf. https://github.com/python/mypy/issues/5485
        img = self.img_transforms(img)  # type: ignore[call-arg]

    if self.sample_transforms is not None:
        img, target = self.sample_transforms(img, target)

    return img, target
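The mechanism can be reproduced without doctr. The following is a minimal sketch, assuming a toy dataset that follows the same pattern: targets live in self.data, _read_sample returns the stored dict itself, and a pre-transform rewrites target['boxes'] in place. ToyDataset and to_relative are hypothetical names, not doctr code, and an image shape stands in for the decoded image.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Target = Dict[str, Any]
Shape = Tuple[int, int]  # (height, width), used in place of a real image


def to_relative(img_shape: Shape, target: Target) -> Tuple[Shape, Target]:
    """Hypothetical pre-transform: divides absolute boxes by the image size IN PLACE."""
    h, w = img_shape
    target["boxes"] = [[x1 / w, y1 / h, x2 / w, y2 / h] for x1, y1, x2, y2 in target["boxes"]]
    return img_shape, target


class ToyDataset:
    """Hypothetical stand-in for the dataset pattern described in this issue."""

    def __init__(self, data: List[Tuple[Shape, Target]],
                 pre_transforms: Optional[Callable[[Shape, Target], Tuple[Shape, Target]]] = None) -> None:
        self.data = data
        self._pre_transforms = pre_transforms

    def _read_sample(self, index: int) -> Tuple[Shape, Target]:
        img_shape, target = self.data[index]
        return img_shape, target  # the stored dict itself, not a copy

    def __getitem__(self, index: int) -> Tuple[Shape, Target]:
        img, target = self._read_sample(index)
        if self._pre_transforms is not None:
            img, target = self._pre_transforms(img, target)
        return img, target


ds = ToyDataset([((100, 200), {"boxes": [[20, 10, 100, 50]]})], pre_transforms=to_relative)
first = ds[0][1]["boxes"]
second = ds[0][1]["boxes"]
print(first)   # [[0.1, 0.1, 0.5, 0.5]]
print(second)  # divided a second time, roughly [[0.0005, 0.001, 0.0025, 0.005]]
```

Reading the same index twice yields different boxes, and the value left in ds.data is the twice-converted one.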

This can be fixed by copying the target in _read_sample, which currently returns the object stored in self.data directly:

# AbstractDataset
def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
    img_name, target = self.data[index]
    # Read image
    img = read_img_as_tensor(os.path.join(self.root, img_name), dtype=tf.float32)

    return img, target
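A minimal sketch of the copy-based fix, again on a hypothetical toy dataset (CopyingDataset and to_relative are illustrative names, not doctr code). The only change from the pattern above is that _read_sample hands out copy.deepcopy(target). In the real _read_sample the analogous change would presumably be returning copy.deepcopy(target) instead of target; that is an assumption, not a confirmed patch.

```python
import copy
from typing import Any, Callable, Dict, List, Optional, Tuple

Target = Dict[str, Any]
Shape = Tuple[int, int]  # (height, width), used in place of a real image


def to_relative(img_shape: Shape, target: Target) -> Tuple[Shape, Target]:
    """Hypothetical pre-transform that still rewrites target['boxes'] in place."""
    h, w = img_shape
    target["boxes"] = [[x1 / w, y1 / h, x2 / w, y2 / h] for x1, y1, x2, y2 in target["boxes"]]
    return img_shape, target


class CopyingDataset:
    """Hypothetical dataset whose _read_sample returns a deep copy of the stored target."""

    def __init__(self, data: List[Tuple[Shape, Target]],
                 pre_transforms: Optional[Callable[[Shape, Target], Tuple[Shape, Target]]] = None) -> None:
        self.data = data
        self._pre_transforms = pre_transforms

    def _read_sample(self, index: int) -> Tuple[Shape, Target]:
        img_shape, target = self.data[index]
        # Deep copy: transforms may reassign keys or mutate nested lists/arrays freely
        return img_shape, copy.deepcopy(target)

    def __getitem__(self, index: int) -> Tuple[Shape, Target]:
        img, target = self._read_sample(index)
        if self._pre_transforms is not None:
            img, target = self._pre_transforms(img, target)
        return img, target


ds = CopyingDataset([((100, 200), {"boxes": [[20, 10, 100, 50]]})], pre_transforms=to_relative)
print(ds[0][1]["boxes"])  # [[0.1, 0.1, 0.5, 0.5]]
print(ds[0][1]["boxes"])  # [[0.1, 0.1, 0.5, 0.5]] again
print(ds.data[0][1])      # {'boxes': [[20, 10, 100, 50]]}, stored target untouched
```

A deep copy is used rather than dict.copy() because a shallow copy only protects against reassigning keys; a transform that edits an array element-wise (for example boxes[:, 0] /= w on a NumPy array) would still reach the shared array. The trade-off is one extra copy per sample read.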

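Alternatively, every transform method could return a copy of the target instead of modifying it. For example, convert_target_to_relative currently overwrites target['boxes'] in the caller's dict: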

def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]:

    target['boxes'] = convert_to_relative_coords(target['boxes'], get_img_shape(img))
    return img, target
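A minimal sketch of a non-mutating variant, assuming boxes are plain lists of absolute (xmin, ymin, xmax, ymax) values and passing the image shape directly instead of an ImageTensor so the example stays dependency-free. to_relative_coords is a hypothetical helper, not doctr's convert_to_relative_coords.

```python
from typing import Any, Dict, List, Tuple

Target = Dict[str, Any]
Box = List[float]
Shape = Tuple[int, int]  # (height, width)


def to_relative_coords(boxes: List[Box], img_shape: Shape) -> List[Box]:
    """Hypothetical helper: scale absolute boxes into [0, 1], returning a new list."""
    h, w = img_shape
    return [[x1 / w, y1 / h, x2 / w, y2 / h] for x1, y1, x2, y2 in boxes]


def convert_target_to_relative(img_shape: Shape, target: Target) -> Tuple[Shape, Target]:
    # Build a new dict instead of assigning into the caller's dict.
    # Other keys are shared (shallow copy) but are not modified here.
    new_target = {**target, "boxes": to_relative_coords(target["boxes"], img_shape)}
    return img_shape, new_target


original = {"boxes": [[20, 10, 100, 50]], "labels": ["total"]}
_, converted = convert_target_to_relative((100, 200), original)
print(original["boxes"])   # [[20, 10, 100, 50]], unchanged
print(converted["boxes"])  # [[0.1, 0.1, 0.5, 0.5]]
```

The catch with this approach is that every transform in the pipeline has to follow the rule; a single in-place transform anywhere reintroduces the bug, whereas copying once in _read_sample covers all of them.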

Code snippet to reproduce the bug

import cv2
from doctr.datasets import SROIE


def process_image(train_example):
    img, target = train_example
    img_numpy = img.numpy() * 255
    for example in target['boxes']:
        print(example)
        # Boxes are relative: scale x by the width (shape[1]) and y by the height (shape[0])
        unnormalized_example = [int(example[0] * img.shape[1]), int(example[1] * img.shape[0]),
                                int(example[2] * img.shape[1]), int(example[3] * img.shape[0])]
        cv2.rectangle(img=img_numpy,
                      pt1=(unnormalized_example[0], unnormalized_example[1]),
                      pt2=(unnormalized_example[2], unnormalized_example[3]),
                      color=(0, 0, 255), thickness=2)
    return img_numpy


train_set = SROIE(train=True, download=True)

# Iterate twice: the boxes printed for the first sample differ between the two passes
for i in range(2):
    for j, example in enumerate(train_set):
        if j == 0:
            print(f"{i} ____")
            img_n = process_image(example)

P.S. Sorry for the code style; this snippet is just an example :)
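A reusable regression check could read the same indices twice and compare the targets. This is a minimal sketch; targets_stable is a hypothetical helper, and dataset only needs a __getitem__ returning (img, target). Because dict equality compares values with ==, targets holding NumPy arrays (as doctr targets do) need a custom eq, for example one built on numpy.array_equal over target['boxes'].

```python
import copy
import operator
from typing import Any, Callable, Sequence


def targets_stable(dataset: Any, indices: Sequence[int],
                   eq: Callable[[Any, Any], bool] = operator.eq) -> bool:
    """Return True if reading each index twice yields equal targets.

    First-pass targets are deep-copied, so an in-place mutation during the
    second pass cannot silently alter the snapshot being compared against.
    """
    first = [copy.deepcopy(dataset[i][1]) for i in indices]
    second = [dataset[i][1] for i in indices]
    return all(eq(a, b) for a, b in zip(first, second))
```

On the dataset from this issue, the check would be expected to return False until one of the fixes above is applied.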

Error traceback

No exception is raised; the target box coordinates silently change from one pass to the next.

Environment

Not provided.
