Skip to content

Commit

Permalink
Use os.walk instead of os.listdir
Browse files (browse the repository at this point in the history)
a3magic3pocket committed Mar 10, 2023
1 parent 022a849 commit a148f9f
Showing 1 changed file with 40 additions and 37 deletions.
77 changes: 40 additions & 37 deletions tool/create_dataset.py
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,16 @@
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from multiprocessing import Pool


def checkImageIsValid(imageBin):
if imageBin is None:
return False
imageBuf = np.fromstring(imageBin, dtype=np.uint8)
img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
if img is None:
return False
imgH, imgW = img.shape[0], img.shape[1]
if imgH * imgW == 0:
return False
@@ -24,8 +27,6 @@ def writeCache(env, cache):
else:
txn.put(str(k).encode(), v)



def createDataset(outputPath, imageDirPath, lexiconList=None, checkValid=True, db_volume=100995116):
"""
Create LMDB dataset for CRNN training.
@@ -37,47 +38,49 @@ def createDataset(outputPath, imageDirPath, lexiconList=None, checkValid=True, d
lexiconList : (optional) list of lexicon lists
checkValid : if true, check the validity of every image
"""
imageNames = [
f
for f in os.listdir(imageDirPath)
if os.path.isfile(os.path.join(imageDirPath, f))
]
nSamples = len(imageNames)
nSamples = 0
env = lmdb.open(outputPath, map_size=db_volume)
cache = {}
cnt = 1
for i in range(nSamples):
imagePath = os.path.join(imageDirPath, imageNames[i])
label = imageNames[i].split('_')[0].lstrip('0')
if not os.path.exists(imagePath):
print('%s does not exist' % imagePath)
continue
with open(imagePath, 'rb') as fd:
imageBin = fd.read()
if checkValid:
if not checkImageIsValid(imageBin):
print('%s is not a valid image' % imagePath)
for _, _, imageNames in os.walk(imageDirPath):
imageNames.sort(key=lambda x: len(x.split('_')[0]))
nSamples = len(imageNames)

for i, imageName in enumerate(imageNames):
imagePath = os.path.join(imageDirPath, imageName)
label = imageName.split('_')[0]

if not os.path.exists(imagePath):
print('%s does not exist' % imagePath)
continue
with open(imagePath, 'rb') as fd:
imageBin = fd.read()
if checkValid:
if not checkImageIsValid(imageBin):
print('%s is not a valid image. remove it' % imagePath)
os.remove(imagePath)
continue

imageKey = 'image-%09d' % cnt
labelKey = 'label-%09d' % cnt
cache[imageKey] = imageBin
cache[labelKey] = label
if lexiconList:
lexiconKey = 'lexicon-%09d' % cnt
cache[lexiconKey] = ' '.join(lexiconList[i])
if cnt % 1000 == 0:
writeCache(env, cache)
cache = {}
print('Written %d / %d' % (cnt, nSamples))
cnt += 1
nSamples = cnt-1
cache['num-samples'] = str(nSamples)
writeCache(env, cache)
print('Created dataset with %d samples' % nSamples)
imageKey = 'image-%09d' % cnt
labelKey = 'label-%09d' % cnt
cache[imageKey] = imageBin
cache[labelKey] = label
if lexiconList:
lexiconKey = 'lexicon-%09d' % cnt
cache[lexiconKey] = ' '.join(lexiconList[i])
if cnt % 1000 == 0:
writeCache(env, cache)
cache = {}
print('Written %d / %d' % (cnt, nSamples))
cnt += 1

nSamples = cnt-1
cache['num-samples'] = str(nSamples)
writeCache(env, cache)
print('Created dataset with %d samples' % nSamples)


if __name__ == '__main__':
createDataset('train_crnn_images_lmdb', 'train_crnn_images', db_volume=100995116)
createDataset('val_crnn_images_lmdb', 'val_crnn_images', db_volume=10099511)
createDataset('train_crnn_images_lmdb', 'train_crnn_images', db_volume=1009951160)
createDataset('val_crnn_images_lmdb', 'val_crnn_images', db_volume=100995116)
pass

0 comments on commit a148f9f

Please sign in to comment.