-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathesposallesData.py
152 lines (133 loc) · 5.5 KB
/
esposallesData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import cv2
import numpy as np
import string
# Sample counts of the official Esposalles splits:
#   textline-based: train 2759, validation 311, test 757
#   word-based:     train 28346, validation 3155, test 8026

# Every image is rescaled to this height (pixels), keeping its aspect ratio.
IMG_HEIGHT = 80
# Fixed output widths (pixels): narrower images are zero-padded on the right,
# wider ones are cropped.
IMG_WIDTH_TEXTLINE = 1400
IMG_WIDTH_WORD = 460
# Dataset variant: True -> text-line images, False -> single-word images.
TEXTLINE = False

# Download the Esposalles datasets from http://rrc.cvc.uab.es/?ch=10&com=downloads
baseDir = '/home/lkang/datasets/OfficialEsposalles/'
train, validation, test = [baseDir + sub for sub in ('train/', 'validation/', 'test/')]
def labelDictionary():
    """Build the character vocabulary used to encode transcriptions.

    The alphabet is lowercase letters (prefixed by a space in TEXTLINE mode),
    then uppercase letters, then the digits and '#', 'ç'. The letters
    k, K, w, W and Z are excluded. Indices follow that order, starting at 0.

    Returns:
        (number_of_labels, {character: index}).
    """
    excluded = set('kKwWZ')
    head = ' ' + string.ascii_lowercase if TEXTLINE else string.ascii_lowercase
    alphabet = head + string.ascii_uppercase + '0123456789#ç'
    # Each excluded character occurs exactly once in `alphabet`, so filtering
    # matches removing its first occurrence.
    symbols = [ch for ch in alphabet if ch not in excluded]
    return len(symbols), dict(zip(symbols, range(len(symbols))))
def init():
    """Read the ground-truth transcriptions of the train/validation/test splits.

    Each split directory must contain ``line_groundtruth.txt`` (TEXTLINE mode)
    or ``groundtruth.txt`` (word mode), with one ``<imageId>:<transcription>``
    entry per line. Blank lines are skipped. Every transcription character is
    mapped to its integer index from ``labelDictionary()``.

    Returns:
        ``(labelNum, (trainImage, trainLabel), (validationImage, validationLabel),
        (testImage, testLabel))``. Each ``*Image`` is a list of image-id strings,
        and each ``*Label`` is a parallel list of label-index lists.

    Raises:
        ValueError: if a non-blank line has no ':' separator.
        KeyError: if a transcription contains a character that is not in the
            label dictionary. The message names the file and line.
    """
    # Do remember to put the ground-truth file into each split directory.
    groundtruth = 'line_groundtruth.txt' if TEXTLINE else 'groundtruth.txt'
    labelNum, labelDict = labelDictionary()

    splits = []
    for base in (train, validation, test):
        path = base + groundtruth
        imageIds, labels = [], []
        # NOTE(review): UTF-8 is assumed because the label set contains 'ç'.
        # The original relied on the platform default encoding; confirm this
        # against the dataset files.
        with open(path, 'r', encoding='utf-8') as gt:
            for lineNo, line in enumerate(gt, 1):
                # Use rstrip instead of line[:-1]: the last line may lack a
                # trailing newline, and slicing would drop a real character.
                line = line.rstrip('\r\n')
                if not line:
                    continue
                # Split on the first ':' only. Extra colons stay in the
                # transcription (and are reported below) instead of being
                # silently truncated.
                imageId, sep, transcription = line.partition(':')
                if not sep:
                    raise ValueError(
                        '{}:{}: expected "<imageId>:<transcription>", got {!r}'.format(
                            path, lineNo, line))
                try:
                    labels.append([labelDict[ch] for ch in transcription])
                except KeyError as err:
                    raise KeyError(
                        '{}:{}: character {!r} is not in the label dictionary'.format(
                            path, lineNo, err.args[0])) from err
                imageIds.append(imageId)
        splits.append((imageIds, labels))

    # Same layout as before: labelNum followed by one (images, labels) pair
    # per split, in train/validation/test order.
    return (labelNum,) + tuple(splits)
def readImage(base, imageId):
    """Load one Esposalles image, binarize it and pad or crop it to a fixed size.

    The file path is derived from ``imageId``. In TEXTLINE mode it is
    ``<base><id minus last '_' field>/lines/<imageId>.png``; in word mode it is
    ``<base><first two '_' fields>/words/<imageId>.png``.

    The grayscale image is resized to IMG_HEIGHT rows (aspect ratio kept) and
    inverse-thresholded at 127. The result is right-padded with zeros, or
    cropped, to IMG_WIDTH_TEXTLINE / IMG_WIDTH_WORD columns.

    Returns:
        (outImg, img_width). ``outImg`` is a float32 array of shape
        (IMG_HEIGHT, IMG_WIDTH) with 0. for background (light pixels) and 1.
        for dark (ink) pixels. ``img_width`` is the number of leading columns
        of ``outImg`` filled from the image, so it never exceeds IMG_WIDTH.

    Raises:
        IOError: if OpenCV cannot read the image file.
    """
    info = imageId.split('_')
    if TEXTLINE:
        fileName = base + '_'.join(info[:-1]) + '/lines/' + imageId + '.png'
        IMG_WIDTH = IMG_WIDTH_TEXTLINE
    else:
        fileName = base + '_'.join(info[:2]) + '/words/' + imageId + '.png'
        IMG_WIDTH = IMG_WIDTH_WORD

    img = cv2.imread(fileName, 0)  # flag 0: load as single-channel grayscale
    if img is None:
        # cv2.imread returns None instead of raising on missing/unreadable files.
        raise IOError('Cannot read image: ' + fileName)

    # Rescale to IMG_HEIGHT rows, keeping the aspect ratio. Use at least one
    # column so a very thin image does not produce an invalid zero-width resize.
    rate = float(IMG_HEIGHT) / img.shape[0]
    newWidth = max(1, int(img.shape[1] * rate))
    img = cv2.resize(img, (newWidth, IMG_HEIGHT), interpolation=cv2.INTER_AREA)

    # Inverse binary threshold: pixels > 127 become 0 and the rest become 255.
    # Then scale to {0, 1}, keeping float32 on every path.
    _, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
    binary = thresh.astype('float32') / 255.

    # Copy into a zero canvas. Wider images are cropped, and the reported width
    # is clamped so it always describes the columns actually present in outImg.
    img_width = min(binary.shape[-1], IMG_WIDTH)
    outImg = np.zeros((IMG_HEIGHT, IMG_WIDTH), dtype='float32')
    outImg[:, :img_width] = binary[:, :img_width]
    return outImg, img_width
def getData(train_data_size=None, validation_data_size=None, test_data_size=None):
    """Load images, widths and encoded labels for the three dataset splits.

    Each ``*_data_size`` keeps only the first N samples of its split; None
    keeps the whole split. Splits are read in train, validation, test order.

    Returns:
        ``(labelNum, (trainImg, seqLen_train, trainLabel),
        (validationImg, seqLen_validation, validationLabel),
        (testImg, seqLen_test, testLabel))``. Images and widths are the pairs
        produced by ``readImage``; labels come from ``init``.
    """
    labelNum, trainSplit, validationSplit, testSplit = init()

    def loadSplit(base, split, limit):
        # Read the first `limit` images of one split -> (images, widths, labels).
        imageIds, labels = split
        images, widths = [], []
        for imageId in imageIds[:limit]:
            image, width = readImage(base, imageId)
            images.append(image)
            widths.append(width)
        return images, widths, labels[:limit]

    # Tuple elements are evaluated left to right, so images load split by split.
    return (labelNum,
            loadSplit(train, trainSplit, train_data_size),
            loadSplit(validation, validationSplit, validation_data_size),
            loadSplit(test, testSplit, test_data_size))
if __name__ == '__main__':
    # Smoke run that loads every split in full. For a quick check, pass small
    # sizes instead, e.g. getData(50, 10, 20), and inspect max(seqLen_*) or
    # len(*Label).
    (labelNum,
     (trainImg, seqLen_train, trainLabel),
     (validationImg, seqLen_validation, validationLabel),
     (testImg, seqLen_test, testLabel)) = getData(
        train_data_size=None,
        validation_data_size=None,
        test_data_size=None,
    )