forked from facebookresearch/vissl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_transforms.py
30 lines (22 loc) · 860 Bytes
/
test_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import unittest
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from vissl.data.ssl_transforms.img_pil_to_tensor import ImgToTensor
# Shared random image fixture used by the tests below, provided as a torch
# tensor, a PIL image and a numpy array holding the same pixels.
#
# A fixed-seed generator makes the fixture (and therefore any test failure)
# reproducible across runs, without touching torch's global RNG state.
_RNG = torch.Generator().manual_seed(0)
# HWC uint8 image. randint's upper bound is exclusive, so values span the full
# [0, 255] range. The previous `rand() * 255` truncation could never produce
# 255 and so never exercised the 255 -> 1.0 edge of the conversion.
RAND_TENSOR = torch.randint(
    0, 256, (224, 224, 3), generator=_RNG, dtype=torch.uint8
)
RAND_PIL = Image.fromarray(RAND_TENSOR.numpy())
# np.array (not np.asarray) returns a writable copy. np.asarray on a PIL image
# can be read-only depending on the Pillow version, which makes
# torch.from_numpy emit a non-writable-array warning inside ToTensor.
RAND_NUMPY = np.array(RAND_PIL)
class TestTransform(unittest.TestCase):
    """Compare vissl's ImgToTensor against torchvision's ToTensor."""

    def _assert_tensors_match(self, actual, expected):
        """Assert ``actual`` and ``expected`` agree in shape, dtype and value.

        ``torch.allclose`` is computed via an elementwise ``isclose`` that
        broadcasts its arguments. An explicit shape check is therefore needed so
        that a layout mismatch cannot pass just because the shapes broadcast.
        Checking dtype first turns a dtype mismatch into a readable assertion
        failure rather than a RuntimeError from ``allclose``.
        """
        self.assertEqual(actual.shape, expected.shape)
        self.assertEqual(actual.dtype, expected.dtype)
        # Safe to subtract here: shape and dtype have already been checked.
        max_diff = (actual - expected).abs().max().item()
        self.assertTrue(
            torch.allclose(actual, expected),
            msg=f"tensors differ; max abs diff = {max_diff}",
        )

    def test_to_tensor(self):
        # Check that ImgToTensor and torchvision.transforms.ToTensor are
        # mostly equivalent (allclose, not bit-exact) for every supported
        # input type. subTest reports each input type's failure separately.
        inputs = (
            ("PIL.Image", RAND_PIL),
            ("numpy.ndarray", RAND_NUMPY),
        )
        for label, img in inputs:
            with self.subTest(input_type=label):
                self._assert_tensors_match(ImgToTensor()(img), ToTensor()(img))