Commit

--update
ArrowLuo committed Jan 29, 2023
1 parent 24bd371 commit d48d825
Showing 46 changed files with 6,418 additions and 0 deletions.
39 changes: 39 additions & 0 deletions dataloaders/data_config.py
@@ -0,0 +1,39 @@
import os.path

DATA_CONFIG_DICT = {}

ROOT_PATH_ = "./"

# coco_train2014.pkl
# coco_train2014_seg_scale224_sigma9.0_min_size224.lmdb
# karpathy
DATA_CONFIG_DICT["coco"] = {
    "train": {
        "features_path": os.path.join(ROOT_PATH_, ""),
        "data_path": os.path.join(ROOT_PATH_, "karpathy"),
    },
    "val": {
        "features_path": os.path.join(ROOT_PATH_, ""),
        "data_path": os.path.join(ROOT_PATH_, "karpathy"),
    },
    "test": None
}

# cc3m_train_desc.pkl
# cc3m_train_lmdb_total
# cc3m_train_lmdb_total_keys.pkl
# cc3m_train_lmdb_total_seg_scale224_sigma9.0_min_size224.lmdb
# Train_GCC-training.tsv
# Validation_GCC-1.1.0-Validation.tsv
DATA_CONFIG_DICT["cc"] = {
    "train": {
        "features_path": os.path.join(ROOT_PATH_, ""),
        "data_path": os.path.join(ROOT_PATH_, ""),
    },
    "val": {
        "features_path": os.path.join(ROOT_PATH_, ""),
        "data_path": os.path.join(ROOT_PATH_, ""),
    },
    "test": None
}
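
The features_path entries above ship empty; the comments name the pkl/lmdb/tsv artifacts each split expects. A minimal sketch of how a local setup might point the config at its data follows — the directory names are assumptions for illustration, not paths provided by this commit:

# Hypothetical local layout; replace the directory names with wherever the
# files listed in the comments above were downloaded or preprocessed to.
DATA_CONFIG_DICT["coco"]["train"]["features_path"] = os.path.join(ROOT_PATH_, "coco_features")
DATA_CONFIG_DICT["cc"]["train"]["features_path"] = os.path.join(ROOT_PATH_, "cc3m_features")
DATA_CONFIG_DICT["cc"]["train"]["data_path"] = os.path.join(ROOT_PATH_, "cc3m_annotations")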

116 changes: 116 additions & 0 deletions dataloaders/data_dataloaders.py
@@ -0,0 +1,116 @@
import torch
from torch.utils.data import ConcatDataset, DataLoader
from dataloaders.dataloader_coco_retrieval import COCO_DataLoader
from dataloaders.dataloader_cc_retrieval import GCC_DataLoader
from dataloaders.data_config import DATA_CONFIG_DICT
# pip install prefetch_generator
from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):
    def __iter__(self):
        # Transforms the generator into a background-thread generator.
        return BackgroundGenerator(super().__iter__(), max_prefetch=1)

DATALOADER_FCT_DICT_ = {}
DATALOADER_FCT_DICT_["coco"] = COCO_DataLoader
DATALOADER_FCT_DICT_["cc"] = GCC_DataLoader

def _get_dataset(args, tokenizer, dataloader_fct, data_path, features_path, subset="train"):
    dataset = dataloader_fct(
        subset=subset,
        data_path=data_path,
        features_path=features_path,
        max_words=args.max_words,
        tokenizer=tokenizer,
        max_frames=1,
        image_resolution=224,
        vit_version=args.pretrained_clip_name,
        use_felzenszwalb=args.use_seglabel,
    )
    return dataset

def _train_sampler_dataloader(args, dataset):
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    train_dataloader = DataLoaderX(
        dataset,
        batch_size=args.batch_size // args.n_gpu,
        num_workers=args.num_thread_reader,
        pin_memory=True if args.use_pin_memory else False,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        drop_last=True,
    )
    return train_dataloader, train_sampler

def _test_sampler_dataloader(args, dataset):
    test_dataloader = DataLoaderX(
        dataset,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        pin_memory=True if args.use_pin_memory else False,
        shuffle=False,
        drop_last=False,
    )
    return test_dataloader

def _get_train_dataloader_fct(data_name):
    def _train_dataloader_fct(args, tokenizer):
        assert data_name in DATALOADER_FCT_DICT_, "{} not in DATALOADER_FCT_DICT".format(data_name)
        assert data_name in DATA_CONFIG_DICT, "{} not in DATA_CONFIG_DICT".format(data_name)
        dataloader_fct = DATALOADER_FCT_DICT_[data_name]
        data_path = DATA_CONFIG_DICT[data_name]["train"]["data_path"]
        features_path = DATA_CONFIG_DICT[data_name]["train"]["features_path"]
        dataset = _get_dataset(args, tokenizer, dataloader_fct, data_path, features_path, subset="train")
        train_dataloader, train_sampler = _train_sampler_dataloader(args, dataset)
        return train_dataloader, len(dataset), train_sampler
    return _train_dataloader_fct

def _get_test_dataloader_fct(data_name):
    def _test_dataloader_fct(args, tokenizer, subset="test"):
        assert data_name in DATALOADER_FCT_DICT_, "{} not in DATALOADER_FCT_DICT".format(data_name)
        assert data_name in DATA_CONFIG_DICT, "{} not in DATA_CONFIG_DICT".format(data_name)
        dataloader_fct = DATALOADER_FCT_DICT_[data_name]
        data_path = DATA_CONFIG_DICT[data_name][subset]["data_path"]
        features_path = DATA_CONFIG_DICT[data_name][subset]["features_path"]
        testset = _get_dataset(args, tokenizer, dataloader_fct, data_path, features_path, subset=subset)
        test_dataloader = _test_sampler_dataloader(args, testset)
        return test_dataloader, len(testset)
    return _test_dataloader_fct

def _get_train_multi_dataloader_fct(data_name):
    def _train_dataloader_fct(args, tokenizer):
        data_name_list = data_name.split(",")
        dataset_list = []
        for data_name_ in data_name_list:
            if len(data_name_) == 0: continue
            assert data_name_ in DATALOADER_FCT_DICT_, "{} not in DATALOADER_FCT_DICT".format(data_name_)
            assert data_name_ in DATA_CONFIG_DICT, "{} not in DATA_CONFIG_DICT".format(data_name_)
            dataloader_fct = DATALOADER_FCT_DICT_[data_name_]
            data_path = DATA_CONFIG_DICT[data_name_]["train"]["data_path"]
            features_path = DATA_CONFIG_DICT[data_name_]["train"]["features_path"]
            dataset_ = _get_dataset(args, tokenizer, dataloader_fct, data_path, features_path, subset="train")
            dataset_list.append(dataset_)

        dataset = ConcatDataset(dataset_list)
        train_dataloader, train_sampler = _train_sampler_dataloader(args, dataset)
        return train_dataloader, len(dataset), train_sampler
    return _train_dataloader_fct

dataloader_cc_train = _get_train_dataloader_fct("cc")
dataloader_cc_test = _get_test_dataloader_fct("cc")

dataloader_coco_train = _get_train_dataloader_fct("coco")
dataloader_coco_test = _get_test_dataloader_fct("coco")

class DataloaderDictClass(dict):
    def __getitem__(self, item):
        if item not in self and item.find(",") > -1:
            train_loader_ = _get_train_multi_dataloader_fct(item)
            v = {"train": train_loader_, "val": None, "test": None}
        else:
            v = super(DataloaderDictClass, self).__getitem__(item)
        return v

DATALOADER_DICT = DataloaderDictClass()
DATALOADER_DICT["cc"] = {"train":dataloader_cc_train, "val":dataloader_cc_test, "test":None}
DATALOADER_DICT["coco"] = {"train":dataloader_coco_train, "val":dataloader_coco_test, "test":None}
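
A brief usage sketch of the registry above: each entry of DATALOADER_DICT maps a split name to a factory that takes (args, tokenizer) and returns a dataloader plus the dataset length (and, for training, the DistributedSampler), while a comma-joined key such as "coco,cc" is resolved on the fly by DataloaderDictClass into a ConcatDataset-backed training loader. The snippet assumes a tokenizer object, an args namespace carrying the attributes read above (max_words, pretrained_clip_name, use_seglabel, batch_size, batch_size_val, n_gpu, num_thread_reader, use_pin_memory), and an initialized distributed process group, since the training path builds a DistributedSampler.

# Illustrative only; `args` and `tokenizer` are stand-ins for the caller's objects.
train_loader, n_train, train_sampler = DATALOADER_DICT["coco"]["train"](args, tokenizer)
val_loader, n_val = DATALOADER_DICT["coco"]["val"](args, tokenizer, subset="val")
# Joint training over both corpora via the comma key handled by DataloaderDictClass:
joint_loader, n_joint, joint_sampler = DATALOADER_DICT["coco,cc"]["train"](args, tokenizer)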
57 changes: 57 additions & 0 deletions dataloaders/dataloader_base.py
@@ -0,0 +1,57 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import numpy as np
from torch.utils.data import Dataset
import torch.distributed as dist
from util import get_logger

class DatasetBase(Dataset):
    def __init__(self, tokenizer, max_words=30, **kwargs):
        self.tokenizer = tokenizer
        self.max_words = max_words
        self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
                              "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}

    def _get_text(self, image_id, caption):
        k = 1
        choice_image_ids = [image_id]
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)

        for i, image_id in enumerate(choice_image_ids):
            words = self.tokenizer.tokenize(caption)

            words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
            total_length_with_CLS = self.max_words - 1
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            while len(input_ids) < self.max_words:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            assert len(input_ids) == self.max_words
            assert len(input_mask) == self.max_words
            assert len(segment_ids) == self.max_words

            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)

        return pairs_text, pairs_mask, pairs_segment, choice_image_ids

    def print_dist(self, info):
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        if rank == 0:
            get_logger().info(info)
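
For orientation, a sketch of what _get_text returns: three (1, max_words) integer arrays — CLIP-style token ids wrapped in <|startoftext|>/<|endoftext|> and zero-padded, the matching attention mask, and an all-zero segment array — plus the list of chosen image ids. The tokenizer below is a placeholder; whatever tokenizer the repository actually passes in is assumed to expose tokenize and convert_tokens_to_ids.

# Illustrative shapes only; `clip_tokenizer` is a stand-in for the real tokenizer object.
ds = DatasetBase(tokenizer=clip_tokenizer, max_words=30)
text, mask, segment, ids = ds._get_text("img_0001", "a dog runs on the beach")
assert text.shape == (1, 30) and mask.shape == (1, 30) and segment.shape == (1, 30)
assert ids == ["img_0001"]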
174 changes: 174 additions & 0 deletions dataloaders/dataloader_cc_retrieval.py
@@ -0,0 +1,174 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import os
import io
import zlib
import numpy as np
import pickle
import json
from collections import defaultdict
import lmdb
import base64
from dataloaders.rawimage_util import RawImageExtractor
from dataloaders.rawimage_util import get_felzenszwalb_from_cache
from dataloaders.dataloader_base import DatasetBase

class GCC_DataLoader(DatasetBase):
    """GCC dataset loader."""
    def __init__(
            self,
            subset,
            data_path,
            features_path,
            tokenizer,
            max_words=30,
            max_frames=1,
            image_resolution=224,
            vit_version="ViT-B/32",
            use_felzenszwalb=False,
    ):
        super(GCC_DataLoader, self).__init__(tokenizer, max_words)
        assert max_frames == 1, "GCC dataset is an image dataset."
        self.data_path = data_path
        self.features_path = features_path
        self.max_words = max_words
        self.max_frames = max_frames
        self.tokenizer = tokenizer
        self.image_resolution = image_resolution

        self.subset = subset
        assert self.subset in ["train", "val", "test"]

        self.use_felzenszwalb = use_felzenszwalb and self.subset == "train"

        csv_map = {"train": "cc3m_train_desc.pkl", "val": "cc3m_val_desc.pkl", "test": None}
        assert csv_map[self.subset] is not None, "The caption file of {} is unavailable.".format(self.subset)

        data_csv = os.path.join(data_path, csv_map[self.subset])
        assert os.path.exists(data_csv), "Missing csv file, download from {}".\
            format("https://ai.google.com/research/ConceptualCaptions/download")

        features_map = {"train": "cc3m_train_lmdb_total", "val": "cc3m_val.pkl", "test": None}
        assert features_map[self.subset] is not None, "The feature of {} is unavailable.".format(self.subset)

        scale, sigma, min_size = 224, 0.9, 224
        seg_path_ = "cc3m_train_lmdb_total_seg_scale{}_sigma{}_min_size{}.lmdb".format(scale, sigma * 10, min_size)
        seg_map = {"train": seg_path_, "val": None, "test": None}
        if self.use_felzenszwalb:
            assert seg_map[self.subset] is not None, "The feature of {} is unavailable.".format(self.subset)

        with open(data_csv, 'rb') as f:
            captions_dict_ = pickle.load(f)
        self.captions_dict = captions_dict_

        self.seg_lmdb_path = None
        self.seg_env = None
        self.seg_txn = None
        if self.use_felzenszwalb:
            seg_lmdb_path = os.path.join(features_path, seg_map[self.subset])
            self.seg_lmdb_path = seg_lmdb_path

        if self.subset == "train":
            lmdb_path = os.path.join(features_path, features_map[self.subset])
            lmdb_keys_path = os.path.join(features_path, features_map[self.subset] + "_keys.pkl")
            # env and txn are lazily initialized so that each DDP/worker process opens its own handle.
            self.lmdb_path = lmdb_path
            self.env = None
            self.txn = None
            with open(lmdb_keys_path, 'rb') as f:
                lmdb_keys = pickle.load(f)
            self.img_keys = lmdb_keys['key']
        else:
            features_path = os.path.join(self.features_path, features_map[self.subset])
            with open(features_path, 'rb') as f:
                img_data = pickle.load(f)
            self.img_data = img_data
            self.img_keys = list(self.img_data.keys())

        self.print_dist("Total Pair: {}".format(len(self.img_keys)))

        self.sample_len = len(self.img_keys)
        self.rawImageExtractor = RawImageExtractor(is_train=True, size=self.image_resolution)

    def __len__(self):
        return self.sample_len

    def _init_env(self):
        self.env = lmdb.open(self.lmdb_path, map_size=96 * 1024 * 1024 * 1024, subdir=True,
                             readonly=True, readahead=False, meminit=False, max_spare_txns=1, lock=False)
        self.txn = self.env.begin(write=False, buffers=True)

    def _init_seg_env(self):
        self.seg_env = lmdb.open(self.seg_lmdb_path, map_size=96 * 1024 * 1024 * 1024, subdir=True,
                                 readonly=True, readahead=False, meminit=False, max_spare_txns=1, lock=False)
        self.seg_txn = self.seg_env.begin(write=False, buffers=True)


    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.txn is not None:
            self.txn.__exit__(exc_type, exc_val, exc_tb)
        if self.env is not None:
            self.env.close()
        if self.seg_txn is not None:
            self.seg_txn.__exit__(exc_type, exc_val, exc_tb)
        if self.seg_env is not None:
            self.seg_env.close()

    def _get_rawimage(self, image_id, aug_images=False):
        # Pair x 3 x H x W; Pair is 3 when two extra augmented views of the image are used.
        image = np.zeros((1 if aug_images is False else 3, 3, self.image_resolution, self.image_resolution), dtype=np.float64)
        coord = np.zeros((1 if aug_images is False else 3, 4), dtype=np.float64)

        if self.subset == "train":
            image_bytes = self.txn.get(image_id.encode('ascii'))
        else:
            image_bytes = self.img_data[image_id]

        get_image_bool = True
        try:
            raw_image_data = self.rawImageExtractor.get_image_data_from_bytes(image_bytes, paired_aug=aug_images)
            for id_, (k_, image_data_) in enumerate(raw_image_data.items()):
                image_data_, coord_ = image_data_
                image[id_] = image_data_  # 3 x H x W
                coord[id_] = coord_  # 4
        except Exception as excep:
            self.print_dist("Raw Image reading Error in CC3M!")
            get_image_bool = False

        return image, coord, get_image_bool

    def __getitem__(self, idx):
        if self.subset == "train" and self.env is None:
            self._init_env()
        if self.subset == "train" and self.seg_lmdb_path is not None \
                and self.seg_env is None:
            self._init_seg_env()

        get_image_bool = False
        retry_num = 0
        while get_image_bool is False:
            image_id = self.img_keys[idx]
            caption = self.captions_dict[image_id]["caption"]

            pairs_text, pairs_mask, pairs_segment, choice_image_ids = self._get_text(image_id, caption)
            image, coord, get_image_bool = self._get_rawimage(image_id)

            if get_image_bool is False:
                idx = (idx + 1) % self.sample_len
                retry_num += 1
                if retry_num > 50:
                    raise ValueError("Retry Limited: {}".format(retry_num))

        if self.use_felzenszwalb:
            seg4image_ = np.array(json.loads(zlib.decompress(self.seg_txn.get(image_id.encode('ascii')))), dtype=np.int64)
            seg4image_ = seg4image_[2:].reshape(seg4image_[0], seg4image_[1])
            image_seg = get_felzenszwalb_from_cache(seg4image_, coord, img_size=self.image_resolution, patch_size=16)

        return_tuple = (pairs_text, pairs_mask, pairs_segment, image, coord)

        if self.use_felzenszwalb:
            return_tuple = return_tuple + (image_seg,)

        return return_tuple
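
To summarize the access pattern above in a short, hedged sketch: the LMDB environments are opened lazily on the first __getitem__ call so that every DataLoader worker holds its own handle, unreadable images are skipped by advancing the index (up to 50 retries), and the Felzenszwalb segmentation labels are appended to the returned tuple only when use_felzenszwalb is active. The paths and clip_tokenizer below are placeholders, not values shipped with this commit.

# Illustrative only; assumes the CC3M pkl/lmdb files named in __init__ live under /data/cc3m.
cc_train = GCC_DataLoader(subset="train",
                          data_path="/data/cc3m",
                          features_path="/data/cc3m",
                          tokenizer=clip_tokenizer,
                          max_words=30,
                          use_felzenszwalb=False)
sample = cc_train[0]  # LMDB env is opened lazily here, inside the calling worker process
pairs_text, pairs_mask, pairs_segment, image, coord = sample  # a 6th element (image_seg) is appended when use_felzenszwalb=True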