supports training with airstore datasets (facebookresearch#296)
Summary:
I created a new type of dataset, "airstore", on top of the new AIRStore ioPath tabular interface. To make checkpoint resuming work, I replicated the set_epoch and set_start_iter logic used for the sampler and applied the same calls to the dataset objects.
I have tested it on faircluster:
* doc for setup the integration:  https://docs.google.com/document/d/1RNSjJcFTGl4Or-9SRLYbB-ffuOMidZy62SwBA9L0Qm0/edit?usp=sharing
* example output (imagenet with config quick_simclr_2node): /checkpoint/wpc/vissl/2021-04-14-17-03-36/checkpoints/log.txt
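
The gist of the resume change, as a minimal standalone sketch (hypothetical helper name, not part of this diff; the actual wiring is in the hook and train_task changes below): samplers already receive the epoch and start iteration, and the same calls are now forwarded to dataset objects such as AirstoreDataset that handle shuffling and sample skipping internally.

def propagate_resume_state(dataloader, epoch: int, start_iter: int) -> None:
    # samplers were already handled before this change
    sampler = getattr(dataloader, "sampler", None)
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)
    if hasattr(sampler, "set_start_iter"):
        sampler.set_start_iter(start_iter)
    # new: forward the same state to dataset objects that shuffle / skip internally
    dataset = getattr(dataloader, "dataset", None)
    for data_obj in getattr(dataset, "data_objs", []):
        if hasattr(data_obj, "set_epoch"):
            data_obj.set_epoch(epoch)
        if hasattr(data_obj, "set_start_iter"):
            data_obj.set_start_iter(start_iter)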

Pull Request resolved: facebookresearch#296

Reviewed By: prigoyal

Differential Revision: D27802354

Pulled By: wpc

fbshipit-source-id: 607e8623f8c9d3c48f5f09fe759bd41f6910b0f1
wpc authored and facebook-github-bot committed Apr 19, 2021
1 parent 9a3e0ac commit 6f29d5f
Showing 6 changed files with 216 additions and 13 deletions.
4 changes: 4 additions & 0 deletions configs/config/dataset_catalog.json
@@ -1,4 +1,8 @@
{
"airstore_imagenet": {
"train": ["airstore://flashblade_imagenet_train", "<unused>"],
"val": ["airstore://flashblade_imagenet_val", "<unused>"]
},
"imagenet1k_folder": {
"train": ["<img_path>", "<lbl_path>"],
"val": ["<img_path>", "<lbl_path>"]
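
For context, a small illustration (assuming the usual [data_path, label_path] catalog format used by the other entries) of how the new entry is consumed: the airstore:// URI becomes the path argument of AirstoreDataset, and the label slot is unused.

import json

with open("configs/config/dataset_catalog.json") as f:
    catalog = json.load(f)

data_path, label_path = catalog["airstore_imagenet"]["train"]
print(data_path)   # airstore://flashblade_imagenet_train
print(label_path)  # <unused> -- no label path is read for the airstore source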
2 changes: 1 addition & 1 deletion dev/launch_slurm.sh
@@ -25,7 +25,7 @@ CFG=( "$@" )
# create a temporary experiment folder to run the SLURM job in isolation
RUN_ID=$(date +'%Y-%m-%d-%H-%M-%S')
EXP_ROOT_DIR="/checkpoint/$USER/vissl/$RUN_ID"
CHECKPOINT_DIR="$EXP_ROOT_DIR/checkpoints/"
CHECKPOINT_DIR=${CHECKPOINT_DIR:-"$EXP_ROOT_DIR/checkpoints/"}

echo "EXP_ROOT_DIR: $EXP_ROOT_DIR"
echo "CHECKPOINT_DIR: $CHECKPOINT_DIR"
3 changes: 3 additions & 0 deletions vissl/data/__init__.py
@@ -7,6 +7,7 @@
import torch
from classy_vision.dataset import DataloaderAsyncGPUWrapper
from torch.utils.data import DataLoader
from vissl.data.airstore_dataset import AirstoreDataset
from vissl.data.collators import get_collator
from vissl.data.data_helper import (
DeterministicDistributedSampler,
@@ -26,13 +27,15 @@


__all__ = [
"AirstoreDataset",
"GenericSSLDataset",
"get_data_files",
"register_datasets",
"VisslDatasetCatalog",
]

DATASET_SOURCE_MAP = {
"airstore": AirstoreDataset,
"disk_filelist": DiskImageDataset,
"disk_folder": DiskImageDataset,
"torchvision_dataset": TorchvisionDataset,
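
Roughly speaking, the source map is what ties the config's DATA_SOURCES string to a dataset class, so "airstore" now resolves to AirstoreDataset. A simplified sketch of that dispatch (not the actual builder code):

from vissl.data import DATASET_SOURCE_MAP

def dataset_cls_for_source(data_source: str):
    # e.g. "airstore" -> AirstoreDataset, "disk_folder" -> DiskImageDataset
    if data_source not in DATASET_SOURCE_MAP:
        raise ValueError(f"unknown data source: {data_source}")
    return DATASET_SOURCE_MAP[data_source]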
163 changes: 163 additions & 0 deletions vissl/data/airstore_dataset.py
@@ -0,0 +1,163 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import io
import logging
from typing import (
Any,
Iterable,
Tuple,
)

import torch
from classy_vision.generic.distributed_util import get_rank, get_world_size
from iopath.common.file_io import PathManager
from PIL import Image, ImageFile
from vissl.config import AttrDict
from vissl.data.data_helper import QueueDataset, get_mean_image


def create_path_manager() -> PathManager:
# TODO: move this inline import out once AIRStore OSS is publicly released
from airstore.client.airstore_tabular import AIRStorePathHandler

pathmanager = PathManager()
pathmanager.register_handler(AIRStorePathHandler())
pathmanager.set_strict_kwargs_checking(False)
return pathmanager


class AirstoreDataset(QueueDataset):
def __init__(
self, cfg: AttrDict, data_source: str, path: str, split: str, dataset_name: str
):
super(AirstoreDataset, self).__init__(
queue_size=cfg["DATA"][split]["BATCHSIZE_PER_REPLICA"]
)
self.pathmanager = create_path_manager()
self.cfg = cfg
self.batch_size = cfg["DATA"][split]["BATCHSIZE_PER_REPLICA"]
self.airstore_uri = path
self.split = split
self.epoch = 0
self.start_iter = 0
self.enable_queue_dataset = cfg["DATA"][self.split]["ENABLE_QUEUE_DATASET"]
self.global_rank = get_rank()
self.global_world_size = get_world_size()
self._iterator = None

def set_epoch(self, epoch: int):
# set by the trainer when training on a new epoch or restoring from a checkpoint
logging.info(f"set epoch to {epoch} in airstore dataset")
self.epoch = epoch

def set_start_iter(self, start_iter: int):
# set by the trainer when resuming training from a checkpoint
logging.info(f"set start_iter to {start_iter} in airstore dataset")
if start_iter < 0:
raise Exception(f"{start_iter} is not a valid iteration value")
self.start_iter = start_iter

def _calculate_skip_samples(self, num_workers: int, worker_id: int) -> int:
# calculate how many samples each worker should skip when resuming
# from a checkpoint
per_replica_skip = self.start_iter * self.batch_size
per_worker_skip = per_replica_skip // num_workers
# the dataloader fetches from workers in round-robin order, so we can
# determine exactly which workers need to skip one extra sample when
# per_replica_skip is not evenly divisible by num_workers
if worker_id < per_replica_skip % num_workers:
per_worker_skip += 1
return per_worker_skip

def _open_iterator(self) -> Iterable[Any]:
# data iterator from airstore for the current data split.
# data is sharded across the global number of workers after shuffling

split_cfg = self.cfg["DATA"][self.split]

# extract the number of dataloading workers and the current worker id
# (ranging from 0 to num_workers - 1) from torch.utils.data. If worker_info
# is unavailable, assume the current process is the only dataloading worker.
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
num_workers = 1
worker_id = 0
else:
num_workers = worker_info.num_workers
worker_id = worker_info.id

# split the dataset for each worker
airstore_world_size = self.global_world_size * num_workers
# each worker takes its split based on its parent process rank and worker id
airstore_rank = self.global_rank * num_workers + worker_id

return self.pathmanager.opent(
self.airstore_uri,
"r",
skip_samples=self._calculate_skip_samples(num_workers, worker_id),
enable_shuffle=getattr(split_cfg, "AIRSTORE_ENABLE_SHUFFLE", True),
shuffle_window=getattr(split_cfg, "AIRSTORE_SHUFFLE_WINDOW", 128),
seed=self.epoch,
world_size=airstore_world_size,
rank=airstore_rank,
limit=getattr(split_cfg, "DATA_LIMIT", -1),
offset=getattr(split_cfg, "DATA_OFFSET", 0),
num_of_threads=getattr(split_cfg, "AIRSTORE_NUM_OF_THREADS", 2),
prefetch=getattr(split_cfg, "AIRSTORE_PREFETCH", 1),
max_holding_bundles=getattr(split_cfg, "AIRSTORE_MAX_HOLDING_BUNDLES", 5),
bundle_download_timeout_ms=getattr(
split_cfg, "AIRSTORE_BUNDLE_DOWNLOAD_TIMEOUT_MS", 30000
),
max_retries=getattr(split_cfg, "AIRSTORE_MAX_RETRIES", 5),
dataset_catalog_path=getattr(
split_cfg, "AIRSTORE_DS_CATALOG_PATH", None
), # temporarily needed during airstore development
env=getattr(
split_cfg, "AIRSTORE_ENV", "OSS"
), # env must be set to "fb" when running inside FB, otherwise "OSS"
)

def num_samples(self) -> int:
return self._open_iterator().total_size

def __len__(self) -> int:
return self.num_samples()

def __getitem__(self, index) -> Tuple[Image.Image, bool]:
if self._iterator is None:
self._iterator = self._open_iterator()

if not self.queue_init and self.enable_queue_dataset:
self._init_queues()

try:
# TODO (wpc, prigoyal): we should check images are good when we are
# uploading them to airstore.
ImageFile.LOAD_TRUNCATED_IMAGES = True

image_bytes = next(self._iterator)["image"]
img = Image.open(io.BytesIO(image_bytes))

if img.mode != "RGB":
img = img.convert("RGB")

if self.enable_queue_dataset:
self.on_sucess(img)
is_success = True
except Exception as e:
# TODO: airstore should have no failed images
# because they are filtered at prepare time.
# Then, this should be removed.
logging.warning(e)
is_success = False
# if the queue dataset is enabled, try to use it to return a previously
# seen valid image
if self.enable_queue_dataset:
img, is_success = self.on_failure()
if img is None:
raise RuntimeError(
"Encountered invalid image and couldn't load from QueueDataset"
)
else:
img = get_mean_image(self.cfg["DATA"][self.split].DEFAULT_GRAY_IMG_SIZE)
return img, is_success
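
The sharding and skip arithmetic above is easiest to see with numbers. A self-contained sketch (illustrative values only, not part of the diff) that mirrors _open_iterator and _calculate_skip_samples:

def worker_shard(global_rank, global_world_size, worker_id, num_workers):
    # every dataloading worker of every replica gets its own airstore shard
    airstore_world_size = global_world_size * num_workers
    airstore_rank = global_rank * num_workers + worker_id
    return airstore_rank, airstore_world_size

def skip_for_worker(start_iter, batch_size, worker_id, num_workers):
    # samples the replica already consumed in the interrupted epoch
    per_replica_skip = start_iter * batch_size
    # the dataloader drains workers round-robin, so the first
    # (per_replica_skip % num_workers) workers skip one extra sample
    per_worker_skip = per_replica_skip // num_workers
    if worker_id < per_replica_skip % num_workers:
        per_worker_skip += 1
    return per_worker_skip

# e.g. 2 nodes x 8 GPUs (global_world_size=16) with 4 dataloading workers each,
# resuming at iteration 7 with a per-replica batch size of 30:
assert worker_shard(global_rank=3, global_world_size=16, worker_id=2, num_workers=4) == (14, 64)
assert [skip_for_worker(7, 30, w, 4) for w in range(4)] == [53, 53, 52, 52]
assert sum(skip_for_worker(7, 30, w, 4) for w in range(4)) == 7 * 30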
12 changes: 12 additions & 0 deletions vissl/hooks/state_update_hooks.py
@@ -11,6 +11,7 @@
count_params,
)
from classy_vision.hooks.classy_hook import ClassyHook
from vissl.data import AirstoreDataset, GenericSSLDataset


class SSLModelComplexityHook(ClassyHook):
@@ -95,6 +96,17 @@ def on_phase_start(self, task: "tasks.ClassyTask") -> None:
if hasattr(task.dataloaders[phase_type].sampler, "set_epoch"):
# task.phase_idx is current running phase id
task.dataloaders[phase_type].sampler.set_epoch(task.phase_idx)

# call set_epoch for AirstoreDataset since it handles shuffle
# behavior internally
if hasattr(task.dataloaders[phase_type], "dataset"):
dataset = task.dataloaders[phase_type].dataset
if isinstance(dataset, GenericSSLDataset):
for data_obj in dataset.data_objs:
if isinstance(data_obj, AirstoreDataset):
# task.phase_idx is current running phase id
data_obj.set_epoch(task.phase_idx)

logging.info(f"Starting phase {task.phase_idx} [{phase_type}]")


45 changes: 33 additions & 12 deletions vissl/trainer/train_task.py
@@ -13,7 +13,13 @@
from fvcore.common.file_io import PathManager
from torch.cuda.amp import GradScaler as TorchGradScaler
from vissl.config import AttrDict
from vissl.data import build_dataset, get_loader, print_sampler_config
from vissl.data import (
build_dataset,
get_loader,
print_sampler_config,
AirstoreDataset,
GenericSSLDataset,
)
from vissl.models import build_model, convert_sync_bn
from vissl.optimizers import get_optimizer_param_groups
from vissl.utils.activation_checkpointing import manual_gradient_reduction
@@ -456,6 +462,17 @@ def _build_model(self):

return model

def _compute_start_iter_from_checkpoint(self, phase_type) -> int:
# used for calculating the start iteration (counted from the current epoch)
# when resuming from a checkpoint
if self.checkpoint is None or self.checkpoint["iteration"] <= 0:
return 0

num_iters_in_epochs = len(self.dataloaders[phase_type])
num_epochs = self.checkpoint["train_phase_idx"] + 1
num_train_iters_done = num_epochs * num_iters_in_epochs
return self.checkpoint["iteration"] - num_train_iters_done

def recreate_data_iterator(self, phase_type, epoch, compute_start_iter):
"""
Recreate data iterator (including multiprocessing workers) and destroy the
@@ -467,26 +484,30 @@ def recreate_data_iterator(self, phase_type, epoch, compute_start_iter):
epoch and start_iteration so that the data is deterministically shuffled,
so we call them here.
"""
start_iter = 0
if compute_start_iter:
start_iter = self._compute_start_iter_from_checkpoint(phase_type)

if hasattr(self.dataloaders[phase_type], "sampler"):
sampler = self.dataloaders[phase_type].sampler
# (Re-)Shuffle data: set epoch of distributed (or fairstore) sampler.
if hasattr(sampler, "set_epoch"):
sampler.set_epoch(epoch)
# Resume from the iteration if valid
if hasattr(sampler, "set_start_iter"):
if (
compute_start_iter
and self.checkpoint is not None
and self.checkpoint["iteration"] > 0
):
num_iters_in_epochs = len(self.dataloaders[phase_type])
num_epochs = self.checkpoint["train_phase_idx"] + 1
num_train_iters_done = num_epochs * num_iters_in_epochs
start_iter = self.checkpoint["iteration"] - num_train_iters_done
else:
start_iter = 0
sampler.set_start_iter(start_iter)
print_sampler_config(sampler)

# call set_epoch and set_start_iter for AirstoreDataset since it handles
# shuffle and sample skipping behavior internally
if hasattr(self.dataloaders[phase_type], "dataset"):
dataset = self.dataloaders[phase_type].dataset
if isinstance(dataset, GenericSSLDataset):
for data_obj in dataset.data_objs:
if isinstance(data_obj, AirstoreDataset):
data_obj.set_epoch(epoch)
data_obj.set_start_iter(start_iter)

# delete the old data iterator
del self.data_iterator
gc.collect()
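
To make the _compute_start_iter_from_checkpoint arithmetic concrete (illustrative numbers only; this assumes train_phase_idx in the checkpoint is the index of the last fully completed training phase):

num_iters_in_epoch = 1000
checkpoint = {"iteration": 2600, "train_phase_idx": 1}
num_train_iters_done = (checkpoint["train_phase_idx"] + 1) * num_iters_in_epoch
start_iter = checkpoint["iteration"] - num_train_iters_done
assert start_iter == 600  # resume 600 iterations into the current epoch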
