Showing 46 changed files with 6,418 additions and 0 deletions.
@@ -0,0 +1,39 @@
import os.path

DATA_CONFIG_DICT = {}

ROOT_PATH_ = "./"

# coco_train2014.pkl
# coco_train2014_seg_scale224_sigma9.0_min_size224.lmdb
# karpathy
DATA_CONFIG_DICT["coco"] = {
    "train": {
        "features_path": os.path.join(ROOT_PATH_, ""),
        "data_path": os.path.join(ROOT_PATH_, "karpathy"),
    },
    "val": {
        "features_path": os.path.join(ROOT_PATH_, ""),
        "data_path": os.path.join(ROOT_PATH_, "karpathy"),
    },
    "test": None
}

# cc3m_train_desc.pkl
# cc3m_train_lmdb_total
# cc3m_train_lmdb_total_keys.pkl
# cc3m_train_lmdb_total_seg_scale224_sigma9.0_min_size224.lmdb
# Train_GCC-training.tsv
# Validation_GCC-1.1.0-Validation.tsv
DATA_CONFIG_DICT["cc"] = {
    "train": {
        "features_path": os.path.join(ROOT_PATH_, ""),
        "data_path": os.path.join(ROOT_PATH_, ""),
    },
    "val": {
        "features_path": os.path.join(ROOT_PATH_, ""),
        "data_path": os.path.join(ROOT_PATH_, ""),
    },
    "test": None
}
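The paths above are deliberately left as placeholders under ROOT_PATH_. A minimal sketch of how an entry is read; the loader factories in the next file index this dict by dataset name and split:

# Illustrative only: how the factory functions below look paths up in this config.
coco_train_cfg = DATA_CONFIG_DICT["coco"]["train"]
print(coco_train_cfg["data_path"])      # "./karpathy"
print(coco_train_cfg["features_path"])  # "./" until the real feature root is filled in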
@@ -0,0 +1,116 @@
import torch
from torch.utils.data import ConcatDataset, DataLoader
from dataloaders.dataloader_coco_retrieval import COCO_DataLoader
from dataloaders.dataloader_cc_retrieval import GCC_DataLoader
from dataloaders.data_config import DATA_CONFIG_DICT
# pip install prefetch_generator
from prefetch_generator import BackgroundGenerator


class DataLoaderX(DataLoader):
    def __iter__(self):
        # Turns the iterator into a background-thread generator.
        return BackgroundGenerator(super().__iter__(), max_prefetch=1)


DATALOADER_FCT_DICT_ = {}
DATALOADER_FCT_DICT_["coco"] = COCO_DataLoader
DATALOADER_FCT_DICT_["cc"] = GCC_DataLoader


def _get_dataset(args, tokenizer, dataloader_fct, data_path, features_path, subset="train"):
    dataset = dataloader_fct(
        subset=subset,
        data_path=data_path,
        features_path=features_path,
        max_words=args.max_words,
        tokenizer=tokenizer,
        max_frames=1,
        image_resolution=224,
        vit_version=args.pretrained_clip_name,
        use_felzenszwalb=args.use_seglabel,
    )
    return dataset


def _train_sampler_dataloader(args, dataset):
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    train_dataloader = DataLoaderX(
        dataset,
        batch_size=args.batch_size // args.n_gpu,
        num_workers=args.num_thread_reader,
        pin_memory=True if args.use_pin_memory else False,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        drop_last=True,
    )
    return train_dataloader, train_sampler


def _test_sampler_dataloader(args, dataset):
    test_dataloader = DataLoaderX(
        dataset,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        pin_memory=True if args.use_pin_memory else False,
        shuffle=False,
        drop_last=False,
    )
    return test_dataloader


def _get_train_dataloader_fct(data_name):
    def _train_dataloader_fct(args, tokenizer):
        assert data_name in DATALOADER_FCT_DICT_, "{} not in DATALOADER_FCT_DICT".format(data_name)
        assert data_name in DATA_CONFIG_DICT, "{} not in DATA_CONFIG_DICT".format(data_name)
        dataloader_fct = DATALOADER_FCT_DICT_[data_name]
        data_path = DATA_CONFIG_DICT[data_name]["train"]["data_path"]
        features_path = DATA_CONFIG_DICT[data_name]["train"]["features_path"]
        dataset = _get_dataset(args, tokenizer, dataloader_fct, data_path, features_path, subset="train")
        train_dataloader, train_sampler = _train_sampler_dataloader(args, dataset)
        return train_dataloader, len(dataset), train_sampler
    return _train_dataloader_fct


def _get_test_dataloader_fct(data_name):
    def _test_dataloader_fct(args, tokenizer, subset="test"):
        assert data_name in DATALOADER_FCT_DICT_, "{} not in DATALOADER_FCT_DICT".format(data_name)
        assert data_name in DATA_CONFIG_DICT, "{} not in DATA_CONFIG_DICT".format(data_name)
        dataloader_fct = DATALOADER_FCT_DICT_[data_name]
        data_path = DATA_CONFIG_DICT[data_name][subset]["data_path"]
        features_path = DATA_CONFIG_DICT[data_name][subset]["features_path"]
        testset = _get_dataset(args, tokenizer, dataloader_fct, data_path, features_path, subset=subset)
        test_dataloader = _test_sampler_dataloader(args, testset)
        return test_dataloader, len(testset)
    return _test_dataloader_fct


def _get_train_multi_dataloader_fct(data_name):
    def _train_dataloader_fct(args, tokenizer):
        data_name_list = data_name.split(",")
        dataset_list = []
        for data_name_ in data_name_list:
            if len(data_name_) == 0: continue
            assert data_name_ in DATALOADER_FCT_DICT_, "{} not in DATALOADER_FCT_DICT".format(data_name_)
            assert data_name_ in DATA_CONFIG_DICT, "{} not in DATA_CONFIG_DICT".format(data_name_)
            dataloader_fct = DATALOADER_FCT_DICT_[data_name_]
            data_path = DATA_CONFIG_DICT[data_name_]["train"]["data_path"]
            features_path = DATA_CONFIG_DICT[data_name_]["train"]["features_path"]
            dataset_ = _get_dataset(args, tokenizer, dataloader_fct, data_path, features_path, subset="train")
            dataset_list.append(dataset_)

        dataset = ConcatDataset(dataset_list)
        train_dataloader, train_sampler = _train_sampler_dataloader(args, dataset)
        return train_dataloader, len(dataset), train_sampler
    return _train_dataloader_fct


dataloader_cc_train = _get_train_dataloader_fct("cc")
dataloader_cc_test = _get_test_dataloader_fct("cc")

dataloader_coco_train = _get_train_dataloader_fct("coco")
dataloader_coco_test = _get_test_dataloader_fct("coco")


class DataloaderDictClass(dict):
    def __getitem__(self, item):
        if item not in self and item.find(",") > -1:
            train_loader_ = _get_train_multi_dataloader_fct(item)
            v = {"train": train_loader_, "val": None, "test": None}
        else:
            v = super(DataloaderDictClass, self).__getitem__(item)
        return v


DATALOADER_DICT = DataloaderDictClass()
DATALOADER_DICT["cc"] = {"train": dataloader_cc_train, "val": dataloader_cc_test, "test": None}
DATALOADER_DICT["coco"] = {"train": dataloader_coco_train, "val": dataloader_coco_test, "test": None}
@@ -0,0 +1,57 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import numpy as np
from torch.utils.data import Dataset
import torch.distributed as dist
from util import get_logger


class DatasetBase(Dataset):
    def __init__(self, tokenizer, max_words=30, **kwargs):
        self.tokenizer = tokenizer
        self.max_words = max_words
        self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
                              "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}

    def _get_text(self, image_id, caption):
        k = 1
        choice_image_ids = [image_id]
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)

        for i, image_id in enumerate(choice_image_ids):
            words = self.tokenizer.tokenize(caption)

            words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
            total_length_with_CLS = self.max_words - 1
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            while len(input_ids) < self.max_words:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            assert len(input_ids) == self.max_words
            assert len(input_mask) == self.max_words
            assert len(segment_ids) == self.max_words

            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)

        return pairs_text, pairs_mask, pairs_segment, choice_image_ids

    def print_dist(self, info):
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        if rank == 0:
            get_logger().info(info)
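To make the truncation and padding behaviour of _get_text concrete, here is a minimal sketch with a stub tokenizer. The stub is purely illustrative; the special tokens above suggest a CLIP-style tokenizer, but any object exposing tokenize and convert_tokens_to_ids works:

# Illustrative only: shapes and padding produced by DatasetBase._get_text.
class _StubTokenizer:
    def tokenize(self, text):
        return text.lower().split()
    def convert_tokens_to_ids(self, tokens):
        # Hypothetical ids; a real tokenizer maps tokens into its own vocabulary.
        return [hash(t) % 1000 + 1 for t in tokens]

ds = DatasetBase(tokenizer=_StubTokenizer(), max_words=8)
text, mask, segment, ids = ds._get_text("img_0", "a dog on a skateboard")
print(text.shape, mask.shape)  # (1, 8) (1, 8)
print(mask[0])                 # [1 1 1 1 1 1 1 0] -> CLS + 5 words + SEP, then 1 pad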
@@ -0,0 +1,174 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import os
import io
import zlib
import numpy as np
import pickle
import json
from collections import defaultdict
import lmdb
import base64
from dataloaders.rawimage_util import RawImageExtractor
from dataloaders.rawimage_util import get_felzenszwalb_from_cache
from dataloaders.dataloader_base import DatasetBase


class GCC_DataLoader(DatasetBase):
    """GCC dataset loader."""
    def __init__(
            self,
            subset,
            data_path,
            features_path,
            tokenizer,
            max_words=30,
            max_frames=1,
            image_resolution=224,
            vit_version="ViT-B/32",
            use_felzenszwalb=False,
    ):
        super(GCC_DataLoader, self).__init__(tokenizer, max_words)
        assert max_frames == 1, "GCC dataset is an image dataset."
        self.data_path = data_path
        self.features_path = features_path
        self.max_words = max_words
        self.max_frames = max_frames
        self.tokenizer = tokenizer
        self.image_resolution = image_resolution

        self.subset = subset
        assert self.subset in ["train", "val", "test"]

        self.use_felzenszwalb = use_felzenszwalb and self.subset == "train"

        csv_map = {"train": "cc3m_train_desc.pkl", "val": "cc3m_val_desc.pkl", "test": None}
        assert csv_map[self.subset] is not None, "The caption file of {} is unavailable.".format(self.subset)

        data_csv = os.path.join(data_path, csv_map[self.subset])
        assert os.path.exists(data_csv), "Missing csv file, download from {}".\
            format("https://ai.google.com/research/ConceptualCaptions/download")

        features_map = {"train": "cc3m_train_lmdb_total", "val": "cc3m_val.pkl", "test": None}
        assert features_map[self.subset] is not None, "The feature of {} is unavailable.".format(self.subset)

        scale, sigma, min_size = 224, 0.9, 224
        seg_path_ = "cc3m_train_lmdb_total_seg_scale{}_sigma{}_min_size{}.lmdb".format(scale, sigma * 10, min_size)
        seg_map = {"train": seg_path_, "val": None, "test": None}
        assert seg_map[self.subset] is not None, "The segmentation cache of {} is unavailable.".format(self.subset)

        with open(data_csv, 'rb') as f:
            captions_dict_ = pickle.load(f)
        self.captions_dict = captions_dict_

        self.seg_lmdb_path = None
        self.seg_env = None
        self.seg_txn = None
        if self.use_felzenszwalb:
            seg_lmdb_path = os.path.join(features_path, seg_map[self.subset])
            self.seg_lmdb_path = seg_lmdb_path

        if self.subset == "train":
            lmdb_path = os.path.join(features_path, features_map[self.subset])
            lmdb_keys_path = os.path.join(features_path, features_map[self.subset] + "_keys.pkl")
            # env and txn are lazily initialized per worker (needed under DDP / multi-worker loading).
            self.lmdb_path = lmdb_path
            self.env = None
            self.txn = None
            with open(lmdb_keys_path, 'rb') as f:
                lmdb_keys = pickle.load(f)
            self.img_keys = lmdb_keys['key']
        else:
            features_path = os.path.join(self.features_path, features_map[self.subset])
            with open(features_path, 'rb') as f:
                img_data = pickle.load(f)
            self.img_data = img_data
            self.img_keys = list(self.img_data.keys())

        self.print_dist("Total Pair: {}".format(len(self.img_keys)))

        self.sample_len = len(self.img_keys)
        self.rawImageExtractor = RawImageExtractor(is_train=True, size=self.image_resolution)

    def __len__(self):
        return self.sample_len

    def _init_env(self):
        self.env = lmdb.open(self.lmdb_path, map_size=96 * 1024 * 1024 * 1024, subdir=True,
                             readonly=True, readahead=False, meminit=False, max_spare_txns=1, lock=False)
        self.txn = self.env.begin(write=False, buffers=True)

    def _init_seg_env(self):
        self.seg_env = lmdb.open(self.seg_lmdb_path, map_size=96 * 1024 * 1024 * 1024, subdir=True,
                                 readonly=True, readahead=False, meminit=False, max_spare_txns=1, lock=False)
        self.seg_txn = self.seg_env.begin(write=False, buffers=True)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.txn is not None:
            self.txn.__exit__(exc_type, exc_val, exc_tb)
        if self.env is not None:
            self.env.close()
        if self.seg_txn is not None:
            self.seg_txn.__exit__(exc_type, exc_val, exc_tb)
        if self.seg_env is not None:
            self.seg_env.close()

    def _get_rawimage(self, image_id, aug_images=False):
        # Pair x 3 x H x W; Pair is 3 when two extra augmented views of the image are used.
        image = np.zeros((1 if aug_images is False else 3, 3, self.image_resolution, self.image_resolution), dtype=np.float64)
        coord = np.zeros((1 if aug_images is False else 3, 4), dtype=np.float64)

        if self.subset == "train":
            image_bytes = self.txn.get(image_id.encode('ascii'))
        else:
            image_bytes = self.img_data[image_id]

        get_image_bool = True
        try:
            raw_image_data = self.rawImageExtractor.get_image_data_from_bytes(image_bytes, paired_aug=aug_images)
            for id_, (k_, image_data_) in enumerate(raw_image_data.items()):
                image_data_, coord_ = image_data_
                image[id_] = image_data_  # 3 x H x W
                coord[id_] = coord_  # 4
        except Exception as excep:
            self.print_dist("Raw image reading error in CC3M!")
            get_image_bool = False

        return image, coord, get_image_bool

    def __getitem__(self, idx):
        if self.subset == "train" and self.env is None:
            self._init_env()
        if self.subset == "train" and self.seg_lmdb_path is not None \
                and self.seg_env is None:
            self._init_seg_env()

        get_image_bool = False
        retry_num = 0
        while get_image_bool is False:
            image_id = self.img_keys[idx]
            caption = self.captions_dict[image_id]["caption"]

            pairs_text, pairs_mask, pairs_segment, choice_image_ids = self._get_text(image_id, caption)
            image, coord, get_image_bool = self._get_rawimage(image_id)

            if get_image_bool is False:
                idx = (idx + 1) % self.sample_len
                retry_num += 1
                if retry_num > 50:
                    raise ValueError("Retry Limited: {}".format(retry_num))

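        # The cached Felzenszwalb segmentation is zlib-compressed JSON: the first two
        # entries hold the map's height and width, and the remainder is the flattened
        # label map, reshaped below and mapped onto patches via get_felzenszwalb_from_cache.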
        if self.use_felzenszwalb:
            seg4image_ = np.array(json.loads(zlib.decompress(self.seg_txn.get(image_id.encode('ascii')))), dtype=np.int64)
            seg4image_ = seg4image_[2:].reshape(seg4image_[0], seg4image_[1])
            image_seg = get_felzenszwalb_from_cache(seg4image_, coord, img_size=self.image_resolution, patch_size=16)

        return_tuple = (pairs_text, pairs_mask, pairs_segment, image, coord)

        if self.use_felzenszwalb:
            return_tuple = return_tuple + (image_seg,)

        return return_tuple
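One detail worth noting above: the LMDB environment is opened lazily in __getitem__ (via _init_env) rather than in __init__, because an open lmdb handle is not safe to share across the fork/spawn that DataLoader workers and DDP processes go through. A stripped-down sketch of the same pattern, with placeholder paths and keys rather than the real CC3M layout:

# Minimal sketch of per-worker lazy LMDB initialization (paths/keys are placeholders).
import lmdb
from torch.utils.data import Dataset

class LazyLMDBDataset(Dataset):
    """Each worker process opens its own read-only LMDB handle on first use."""
    def __init__(self, lmdb_path, keys):
        self.lmdb_path = lmdb_path  # placeholder path to an existing LMDB directory
        self.keys = keys
        self.env = None             # opened lazily, once per process
        self.txn = None

    def _init_env(self):
        self.env = lmdb.open(self.lmdb_path, readonly=True, lock=False,
                             readahead=False, meminit=False)
        self.txn = self.env.begin(write=False, buffers=True)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        if self.env is None:        # first access inside this worker
            self._init_env()
        return bytes(self.txn.get(self.keys[idx].encode("ascii")))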