Showing 10 changed files with 576 additions and 70 deletions.
# -*- coding: UTF-8 -*-
# @Time : 2022/9/11 23:01
# @File : dataset_collection.py
# @Author : jian<jian@mltalker.com>
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from .data_collection import DataCollection, DataFrame
from .entity import Entity
from .image import *
from .common import *
from antgo.pipeline.hparam import HyperParameter as State
from antgo.pipeline.hparam import param_scope
from antgo.pipeline.hparam import dynamic_dispatch
from antgo.pipeline.functional.common.config import *
from antgo.dataflow.dataset.base_coco_style_dataset import BaseCocoStyleDataset
from tfrecord.reader import *
from tfrecord import iterator_utils
from tfrecord import example_pb2
from antgo.dataflow.datasetio import *
import functools  # used for functools.partial when building tfrecord loaders
import numpy as np
import json
import os
import cv2
import yaml


@dynamic_dispatch
def coco_format_dc(dir, ann_file, data_prefix, mode='detect', normalize=False):
    coco_style_dataset = BaseCocoStyleDataset(
        dir=dir,
        ann_file=ann_file,
        data_prefix=data_prefix,
        data_mode='bottomup'
    )

    def inner():
        sample_num = len(coco_style_dataset)
        for sample_i in range(sample_num):
            sample_info = coco_style_dataset[sample_i]

            bboxes = sample_info['bboxes']
            if normalize:
                # Convert boxes from (x0, y0, w, h) to (x0, y0, x1, y1) in place.
                for box_info in bboxes:
                    x0, y0, w, h = box_info
                    box_info[2] = x0 + w
                    box_info[3] = y0 + h

            export_info = {
                'image': sample_info['image'],
                'bboxes': bboxes,
                'labels': sample_info['category_id'],
                'joints2d': sample_info['keypoints'],
                'joints_vis': sample_info['keypoints_visible']
            }

            entity = Entity()(**export_info)
            yield entity

    return DataFrame(inner())


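# A minimal usage sketch of the COCO-style reader above (assumes the
# @dynamic_dispatch wrapper keeps a plain call form here; the paths and the
# data_prefix layout below are hypothetical):
#
#   dc = coco_format_dc(
#       dir='/data/my_coco',
#       ann_file='annotations/train.json',
#       data_prefix={'img': 'images/'},
#       normalize=True)           # emit boxes as (x0, y0, x1, y1)
#   for sample in dc:             # each sample is an Entity with image/bboxes/labels/joints2d
#       ...

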
@dynamic_dispatch
def yolo_format_dc(ann_file, mode='detect', stage='train', normalize=False):
    assert stage in ['train', 'val']
    with open(ann_file, "r", errors="ignore", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    ann_folder = os.path.dirname(ann_file)
    data_folder = data['path']
    image_folder_map = {
        'train': os.path.join(ann_folder, data_folder, data['train']),
        'val': os.path.join(ann_folder, data_folder, data['val'])
    }
    label_folder_map = {
        'train': os.path.join(ann_folder, data_folder, data['train'].replace('images', 'labels')),
        'val': os.path.join(ann_folder, data_folder, data['val'].replace('images', 'labels'))
    }
    file_name_list = os.listdir(image_folder_map[stage])
    file_name_list = [name for name in file_name_list if name[0] != '.']

    category_map = data["names"]

    def inner():
        sample_num = len(file_name_list)
        for sample_i in range(sample_num):
            file_name = file_name_list[sample_i]
            image_path = f'{image_folder_map[stage]}/{file_name}'
            label_path = f'{label_folder_map[stage]}/{file_name.split(".")[0]}.txt'

            image = cv2.imread(image_path)
            image_h, image_w = image.shape[:2]
            export_info = {
                'image': image
            }
            if mode == 'detect':
                with open(label_path, 'r') as fp:
                    content = fp.readline().strip()
                    bboxes = []
                    labels = []
                    # Each label line is "class_id cx cy w h" with coordinates
                    # normalized to [0, 1] relative to the image size.
                    while content:
                        class_id, cx, cy, w, h = content.split(' ')
                        cx = float(cx)
                        cy = float(cy)
                        w = float(w)
                        h = float(h)
                        if normalize:
                            # Convert to absolute (x0, y0, x1, y1) pixel coordinates.
                            x0, y0, x1, y1 = (cx - w/2)*image_w, (cy - h/2)*image_h, (cx + w/2)*image_w, (cy + h/2)*image_h
                            bboxes.append([
                                x0, y0, x1, y1
                            ])
                        else:
                            bboxes.append([
                                cx, cy, w, h
                            ])

                        labels.append(int(class_id))
                        content = fp.readline().strip()

                export_info['bboxes'] = np.array(bboxes)
                export_info['labels'] = np.array(labels)

            entity = Entity()(**export_info)
            yield entity

    return DataFrame(inner())


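# Sketch of the YOLO-style layout this reader expects, following the common
# Ultralytics data.yaml convention (all names and paths below are made up):
#
#   # data.yaml (passed as ann_file)
#   path: my_dataset           # resolved relative to the folder holding data.yaml
#   train: images/train        # label .txt files are read from labels/train
#   val: images/val
#   names:
#     0: person
#     1: car
#
#   dc = yolo_format_dc(ann_file='/data/data.yaml', stage='train', normalize=True)

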
def _order_iterators(iterators):
    # Exhaust each loader in order; once one is empty, drop it and continue
    # with the next remaining loader (always index 0 after the deletion).
    iterators = [iterator() for iterator in iterators]
    while iterators:
        try:
            yield next(iterators[0])
        except StopIteration:
            del iterators[0]


def _transform(description, sample):
    new_sample = {}
    for k in sample.keys():
        if k == 'image':
            # Image is stored as encoded bytes (e.g. JPEG/PNG); decode it here.
            image = cv2.imdecode(np.frombuffer(sample[k], dtype='uint8'), 1)  # BGR mode, but need RGB mode
            new_sample[k] = image
            continue
        if not k.startswith('__'):
            if description[k] == 'numpy':
                # Numpy fields are stored as raw bytes plus two sidecar fields
                # holding the dtype code and the shape.
                dtype = numpy_dtype_map[sample[f'__{k}_type'][0]]
                shape = tuple(sample[f'__{k}_shape'])
                if isinstance(sample[k], bytes):
                    new_sample[k] = np.frombuffer(bytearray(sample[k]), dtype=dtype).reshape(shape).copy()
                else:
                    new_sample[k] = np.frombuffer(bytearray(sample[k].tobytes()), dtype=dtype).reshape(shape).copy()

                if new_sample[k].dtype == np.float64:
                    new_sample[k] = new_sample[k].astype(np.float32)
                if k == 'bboxes' and new_sample[k].dtype != np.float32:
                    new_sample[k] = new_sample[k].astype(np.float32)
            elif description[k] == 'str':
                new_sample[k] = sample[k].tobytes().decode('utf-8')
            elif description[k] == 'dict':
                new_sample[k] = json.loads(sample[k].tobytes().decode('utf-8'))
            else:
                new_sample[k] = sample[k]

    return new_sample


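# Illustration of the raw record layout _transform expects and what it returns
# (values are made up; dtype codes index into numpy_dtype_map from datasetio):
#
#   raw = {
#       'image': <encoded jpeg/png bytes>,
#       'bboxes': <raw float64 bytes>, '__bboxes_type': [<dtype code>], '__bboxes_shape': [2, 4],
#       'labels': <raw int64 bytes>,   '__labels_type': [<dtype code>], '__labels_shape': [2],
#   }
#   sample = _transform({'image': 'byte', 'bboxes': 'numpy', 'labels': 'numpy'}, raw)
#   # sample['image'] -> HxWx3 uint8 BGR array, sample['bboxes'] -> (2, 4) float32 array

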
@dynamic_dispatch
def tfrecord_format_dc(dir, mode='detect'):
    # Walk the folders and discover all tfrecord data files.
    dataset_folders = dir
    if isinstance(dir, str):
        dataset_folders = [dir]
    data_path_list = []
    index_path_list = []

    for dataset_folder in dataset_folders:
        part_path_list = []
        for tfrecord_file in os.listdir(dataset_folder):
            if tfrecord_file.endswith('tfrecord'):
                tfrecord_file = '-'.join(tfrecord_file.split('/')[-1].split('-')[:-1] + ['tfrecord'])
                part_path_list.append(f'{dataset_folder}/{tfrecord_file}')

        # Derive the companion "<prefix>-index" file for every "<prefix>-tfrecord" file.
        part_index_path_list = []
        for i in range(len(part_path_list)):
            tfrecord_file = part_path_list[i]
            folder = os.path.dirname(tfrecord_file)
            if tfrecord_file.endswith('tfrecord') or tfrecord_file.endswith('index'):
                index_file = '-'.join(tfrecord_file.split('/')[-1].split('-')[:-1] + ['index'])
                index_file = f'{folder}/{index_file}'
                tfrecord_file = '-'.join(tfrecord_file.split('/')[-1].split('-')[:-1] + ['tfrecord'])
                part_path_list[i] = f'{folder}/{tfrecord_file}'
            else:
                index_file = tfrecord_file + '-index'
                part_path_list[i] = tfrecord_file + '-tfrecord'
            part_index_path_list.append(index_file)

        data_path_list.extend(part_path_list)
        index_path_list.extend(part_index_path_list)

    # Count samples from the index files (one row per record).
    num_samples = 0
    num_samples_list = []
    for i, index_path in enumerate(index_path_list):
        index = np.loadtxt(index_path, dtype=np.int64)[:, 0]
        num_samples += len(index)
        num_samples_list.append(len(index))

    # Build the parsing description for the requested mode.
    description = {}
    if mode == "detect":
        description = {
            'image': 'byte',
            'bboxes': 'numpy',
            'labels': 'numpy'
        }
    elif mode == "segment":
        pass
    elif mode == "pose":
        pass

    # Numpy/str/dict fields are serialized as raw bytes in the tfrecord, so map
    # them to 'byte' (plus the dtype/shape sidecar fields for numpy).
    inner_description = {}
    for k, v in description.items():
        if v == 'numpy':
            inner_description.update({
                k: 'byte',
                f'__{k}_type': 'int',
                f'__{k}_shape': 'int'
            })
        elif v == 'str':
            inner_description.update({
                k: 'byte'
            })
        elif v == 'dict':
            inner_description.update({
                k: 'byte'
            })
        else:
            inner_description.update({
                k: v
            })

    loaders = [functools.partial(tfrecord_loader, data_path=data_path,
                                 index_path=index_path,
                                 shard=None,
                                 description=inner_description,
                                 sequence_description=None,
                                 compression_type=None,
                                 )
               for data_path, index_path in zip(data_path_list, index_path_list)]

    it = _order_iterators(loaders)
    _transform_func = lambda x: _transform(description, x)
    it = map(_transform_func, it)

    def inner():
        # Iterate until all underlying tfrecord loaders are exhausted.
        for export_info in it:
            entity = Entity()(**export_info)
            yield entity

    return DataFrame(inner())


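# Usage sketch for the tfrecord reader above. The folder is hypothetical and is
# expected to hold "<prefix>-tfrecord" files with matching "<prefix>-index" files:
#
#   dc = tfrecord_format_dc(dir='/data/my_tfrecords', mode='detect')
#   for sample in dc:
#       ...   # sample.image decoded by _transform; sample.bboxes / sample.labels as numpy arrays

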
class _dataset_dc(object):
    # Thin accessor: dataset_dc.coco / dataset_dc.yolo / dataset_dc.tfrecord
    # resolve to the corresponding *_format_dc factory defined above.
    def __getattr__(self, name):
        if name not in ['coco', 'yolo', 'tfrecord']:
            return None

        return globals()[f'{name}_format_dc']


dataset_dc = _dataset_dc()

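# Accessor usage sketch (paths hypothetical; assumes a plain call form through
# @dynamic_dispatch):
#
#   dc = dataset_dc.yolo(ann_file='/data/data.yaml', stage='train')
#   dc = dataset_dc.coco(dir='/data/my_coco', ann_file='annotations/train.json', data_prefix={'img': 'images/'})
#   dc = dataset_dc.tfrecord(dir='/data/my_tfrecords')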