Skip to content

Commit

Permalink
Add support to use different numbers of samples in the queue (#795)
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar authored May 15, 2022
1 parent cc9ccb3 commit 8ee2ab2
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 30 deletions.
32 changes: 27 additions & 5 deletions tests/data/test_queue.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from torch.utils.data import DataLoader

import torch
import torchio as tio
from torchio.data import UniformSampler
from torchio import SubjectsDataset, Queue, DATA
from torchio.utils import create_dummy_dataset
from ..utils import TorchioTestCase

Expand All @@ -18,10 +20,10 @@ def setUp(self):
)

def run_queue(self, num_workers, **kwargs):
subjects_dataset = SubjectsDataset(self.subjects_list)
subjects_dataset = tio.SubjectsDataset(self.subjects_list)
patch_size = 10
sampler = UniformSampler(patch_size)
queue_dataset = Queue(
queue_dataset = tio.Queue(
subjects_dataset,
max_length=6,
samples_per_volume=2,
Expand All @@ -31,8 +33,8 @@ def run_queue(self, num_workers, **kwargs):
_ = str(queue_dataset)
batch_loader = DataLoader(queue_dataset, batch_size=4)
for batch in batch_loader:
_ = batch['one_modality'][DATA]
_ = batch['segmentation'][DATA]
_ = batch['one_modality'][tio.DATA]
_ = batch['segmentation'][tio.DATA]

# Smoke test with num_workers=0: run_queue builds the queue and iterates over
# every batch, so the test passes if that completes without raising.
def test_queue(self):
self.run_queue(num_workers=0)
Expand All @@ -42,3 +44,23 @@ def test_queue_multiprocessing(self):

# Same smoke test as test_queue, but start_background=False is forwarded
# through run_queue's **kwargs.
def test_queue_no_start_background(self):
self.run_queue(num_workers=0, start_background=False)

def test_different_samples_per_volume(self):
    """Check that a subject's ``num_samples`` overrides ``samples_per_volume``.

    Two single-voxel subjects are created with constant intensities 2 and 10
    and with ``num_samples`` set to 2 and 10, respectively. Counting how often
    each intensity comes out of the queue shows how many patches were
    extracted from each subject.
    """
    image2 = tio.ScalarImage(tensor=2 * torch.ones(1, 1, 1, 1))
    image10 = tio.ScalarImage(tensor=10 * torch.ones(1, 1, 1, 1))
    subject2 = tio.Subject(im=image2, num_samples=2)
    subject10 = tio.Subject(im=image10, num_samples=10)
    dataset = tio.SubjectsDataset([subject2, subject10])
    patch_size = 1
    sampler = UniformSampler(patch_size)
    queue_dataset = tio.Queue(
        dataset,
        max_length=12,
        samples_per_volume=3,  # should be ignored
        sampler=sampler,
    )
    batch_loader = DataLoader(queue_dataset, batch_size=6)
    batches = [batch['im'][tio.DATA] for batch in batch_loader]
    # Concatenate along the batch axis instead of stacking: torch.stack needs
    # identically shaped tensors and would raise if the last batch were
    # smaller than the others.
    all_numbers = torch.cat(batches).flatten().tolist()
    assert all_numbers.count(10) == 10
    assert all_numbers.count(2) == 2
3 changes: 3 additions & 0 deletions torchio/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@

# Floating point error
MIN_FLOAT_32 = torch.finfo(torch.float32).eps

# For the queue: name of the optional subject attribute that, when present,
# is used instead of the queue's default ``samples_per_volume``
NUM_SAMPLES = 'num_samples'
55 changes: 30 additions & 25 deletions torchio/data/queue.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import warnings
from itertools import islice
from typing import List, Iterator, Optional

import torch
import humanize
from tqdm import trange
from torch.utils.data import Dataset, DataLoader

from .. import NUM_SAMPLES
from .subject import Subject
from .sampler import PatchSampler
from .dataset import SubjectsDataset
Expand Down Expand Up @@ -59,7 +58,9 @@ class Queue(Dataset):
max_length: Maximum number of patches that can be stored in the queue.
Using a large number means that the queue needs to be filled less
often, but more CPU memory is needed to store the patches.
samples_per_volume: Number of patches to extract from each volume.
samples_per_volume: Default number of patches to extract from each
volume. If a subject contains an attribute :attr:`num_samples`, it
will be used instead of :attr:`samples_per_volume`.
A small number of patches ensures a large variability in the queue,
but training will be slower.
sampler: A subclass of :class:`~torchio.data.sampler.PatchSampler` used
Expand Down Expand Up @@ -206,34 +207,38 @@ def num_patches(self) -> int:

@property
def iterations_per_epoch(self) -> int:
# NOTE(review): this is a rendered diff, so the old and the new body are
# interleaved without +/- markers. The next line appears to be the
# pre-commit implementation, which assumed the same number of samples for
# every subject.
return self.num_subjects * self.samples_per_volume
# Replacement: sum the per-subject sample counts. dry_iter() is used to
# walk the subjects (presumably without loading the images; confirm in
# SubjectsDataset).
total_num_patches = sum(
self._get_subject_num_samples(subject)
for subject in self.subjects_dataset.dry_iter()
)
return total_num_patches

def _get_subject_num_samples(self, subject):
    """Return the number of patches to sample from ``subject``.

    The subject's own ``num_samples`` attribute takes precedence; if the
    subject does not have one, the queue-wide ``samples_per_volume`` is used.
    """
    return getattr(subject, NUM_SAMPLES, self.samples_per_volume)

def _fill(self) -> None:
# Fill self.patches_list with patches sampled from the next subjects.
# NOTE(review): this is a rendered diff, so lines of the old and the new
# implementation are interleaved below without +/- markers.
assert self.sampler is not None
# --- Presumably removed by this commit: the divisibility warning and the
# fixed-length subject loop, both of which rely on a single
# samples_per_volume shared by all subjects. ---
if self.max_length % self.samples_per_volume != 0:
message = (
f'Queue length ({self.max_length})'
' not divisible by the number of'
f' patches per volume ({self.samples_per_volume})'
)
warnings.warn(message, RuntimeWarning)

# If there are e.g. 4 subjects and 1 sample per volume and max_length
# is 6, we just need to load 4 subjects, not 6
max_num_subjects_for_queue = self.max_length // self.samples_per_volume
num_subjects_for_queue = min(
self.num_subjects, max_num_subjects_for_queue)

self._print(f'Filling queue from {num_subjects_for_queue} subjects...')
if self.verbose:
iterable = trange(num_subjects_for_queue, leave=False)
else:
iterable = range(num_subjects_for_queue)
for _ in iterable:

# --- Presumably added by this commit: keep sampling subjects until the
# patches list reaches max_length or as many subjects as the dataset
# contains have been sampled. ---
num_subjects = 0
while True:
subject = self._get_next_subject()
iterable = self.sampler(subject)
# Old version: always take samples_per_volume patches.
patches = list(islice(iterable, self.samples_per_volume))
# New version: use the per-subject count, capped by the number of free
# slots so that the list never grows beyond max_length.
num_samples = self._get_subject_num_samples(subject)
num_free_slots = self.max_length - len(self.patches_list)
num_samples = min(num_samples, num_free_slots)
patches = list(islice(iterable, num_samples))
self.patches_list.extend(patches)
num_subjects += 1
# Stop on whichever limit is hit first.
list_full = len(self.patches_list) >= self.max_length
all_subjects_sampled = num_subjects >= len(self.subjects_dataset)
if list_full or all_subjects_sampled:
break

# Optionally randomize the patch order so that patches from the same
# subject are not necessarily consecutive.
if self.shuffle_patches:
self._shuffle_patches_list()

Expand Down

0 comments on commit 8ee2ab2

Please sign in to comment.