forked from kuangliu/pytorch-retinanet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
transform.py
142 lines (118 loc) · 4.09 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
'''Perform transforms on both PIL image and object boxes.'''
import math
import random
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
def resize(img, boxes, size, max_size=1000):
    '''Resize a PIL image and scale its boxes by the same factors.

    Args:
      img: (PIL.Image) image to be resized.
      boxes: (tensor) object boxes, sized [#obj,4].
      size: (tuple or int)
        - tuple (w,h): resize the image to exactly this size.
        - int: resize the shorter side to this length, keeping the aspect ratio.
      max_size: (int) when size is an int, the longer side is capped at
        max_size. This is essential to limit the usage of GPU memory.

    Returns:
      img: (PIL.Image) resized image.
      boxes: (tensor) resized boxes (a new tensor; the input is not modified).
    '''
    w, h = img.size
    if not isinstance(size, int):
        # Explicit target size: independent x / y scale factors.
        ow, oh = size
        sw, sh = float(ow) / w, float(oh) / h
    else:
        scale = float(size) / min(w, h)
        longer = max(w, h)
        # Shrink the scale if the longer side would otherwise exceed max_size.
        if scale * longer > max_size:
            scale = float(max_size) / longer
        sw = sh = scale
        ow = int(w * scale + 0.5)
        oh = int(h * scale + 0.5)
    resized = img.resize((ow, oh), Image.BILINEAR)
    scaled_boxes = boxes * torch.Tensor([sw, sh, sw, sh])
    return resized, scaled_boxes
def random_crop(img, boxes, scale=(0.56, 1.0), ratio=(3. / 4, 4. / 3)):
    '''Crop the given PIL image to a random size and aspect ratio.

    Up to 10 attempts are made to sample a crop that
      - covers a random fraction of the original area drawn from `scale`
        (default 0.56 to 1.0), and
      - has a random absolute aspect ratio (w/h) drawn from `ratio`
        (default 3/4 to 4/3), with width and height swapped half the time.
    The first sampled crop that fits inside the image is placed at a random
    position. If none fits, the largest centered square is used instead.

    Boxes are shifted into the crop's coordinates and clamped to
    [0, w-1] x [0, h-1] of the crop. Boxes lying entirely outside the crop
    are kept (they become degenerate after clamping); filtering them out is
    left to the caller.

    Args:
      img: (PIL.Image) image to be cropped.
      boxes: (tensor) object boxes, sized [#obj,4]. Not modified.
      scale: (tuple) (min, max) crop area as a fraction of the image area.
      ratio: (tuple) (min, max) crop aspect ratio (width / height).

    Returns:
      img: (PIL.Image) randomly cropped image.
      boxes: (tensor) boxes in the cropped image's coordinates (new tensor).
    '''
    imw, imh = img.size
    area = imw * imh
    for _ in range(10):
        target_area = random.uniform(*scale) * area
        aspect_ratio = random.uniform(*ratio)
        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        # Swap half the time so both orientations are sampled.
        if random.random() < 0.5:
            w, h = h, w
        if w <= imw and h <= imh:
            x = random.randint(0, imw - w)
            y = random.randint(0, imh - h)
            break
    else:
        # Runs only when no attempt fit: fall back to the largest centered square.
        w = h = min(imw, imh)
        x = (imw - w) // 2
        y = (imh - h) // 2

    img = img.crop((x, y, x + w, y + h))
    # Out-of-place subtraction: the caller's tensor is left untouched, and the
    # offset takes the boxes' own dtype so integer boxes work as well.
    boxes = boxes - boxes.new_tensor([x, y, x, y])
    boxes[:, 0::2] = boxes[:, 0::2].clamp(min=0, max=w - 1)
    boxes[:, 1::2] = boxes[:, 1::2].clamp(min=0, max=h - 1)
    return img, boxes
def center_crop(img, boxes, size):
    '''Crop the given PIL image at its center.

    Args:
      img: (PIL.Image) image to be cropped.
      boxes: (tensor) object boxes, sized [#obj,4]. Not modified.
      size: (tuple) desired output size of (w,h).

    Returns:
      img: (PIL.Image) center cropped image.
      boxes: (tensor) boxes shifted into the crop's coordinates and clamped to
        [0, w-1] x [0, h-1] of the output size (new tensor).
    '''
    w, h = img.size
    ow, oh = size
    i = int(round((h - oh) / 2.))  # top offset of the crop
    j = int(round((w - ow) / 2.))  # left offset of the crop
    img = img.crop((j, i, j + ow, i + oh))
    # Out-of-place subtraction so the caller's tensor is left untouched.
    boxes = boxes - boxes.new_tensor([j, i, j, i])
    boxes[:, 0::2] = boxes[:, 0::2].clamp(min=0, max=ow - 1)
    boxes[:, 1::2] = boxes[:, 1::2].clamp(min=0, max=oh - 1)
    return img, boxes
def random_flip(img, boxes):
    '''Randomly flip the given PIL image left-right with probability 0.5.

    Args:
      img: (PIL Image) image to be flipped.
      boxes: (tensor) object boxes, sized [#obj,4]. Not modified.

    Returns:
      img: (PIL.Image) randomly flipped image.
      boxes: (tensor) boxes matching the returned image. When no flip happens,
        the input tensor itself is returned.
    '''
    if random.random() < 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        w = img.width
        # Mirror x-coordinates as x' = w - x; xmin and xmax swap roles.
        # NOTE(review): the crop functions in this file clamp to w-1, which may
        # imply an inclusive-pixel convention (x' = w - 1 - x) -- confirm.
        xmin = w - boxes[:, 2]
        xmax = w - boxes[:, 0]
        # Copy before writing so the caller's tensor is left untouched.
        boxes = boxes.clone()
        boxes[:, 0] = xmin
        boxes[:, 2] = xmax
    return img, boxes
def draw(img, boxes):
    '''Draw boxes on img in red and display it (debug helper).

    The image is modified in place, then shown with img.show().

    Args:
      img: (PIL.Image) image to draw on.
      boxes: (tensor or sequence) boxes, each given as 4 coordinates
        [x0, y0, x1, y1].
    '''
    # Named `canvas` so it does not shadow this module-level `draw` function.
    canvas = ImageDraw.Draw(img)
    for box in boxes:
        # Hand PIL plain Python floats rather than 0-d tensors.
        canvas.rectangle([float(v) for v in box], outline='red')
    img.show()
def test():
    '''Manual smoke test: randomly crop a sample image and show it with its boxes.

    Reads './image/000001.jpg' relative to the current working directory,
    prints the cropped size, and opens an image viewer via draw().
    '''
    sample = Image.open('./image/000001.jpg')
    gt_boxes = torch.Tensor([[48, 240, 195, 371], [8, 12, 352, 498]])
    cropped, cropped_boxes = random_crop(sample, gt_boxes)
    print(cropped.size)
    draw(cropped, cropped_boxes)
# test()