From 8ee2ab2f0caf66f4a1338b115fab1d014c754f6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Sun, 15 May 2022 23:05:31 +0100 Subject: [PATCH] Add support to use different numbers of samples in the queue (#795) --- tests/data/test_queue.py | 32 +++++++++++++++++++---- torchio/constants.py | 3 +++ torchio/data/queue.py | 55 ++++++++++++++++++++++------------------ 3 files changed, 60 insertions(+), 30 deletions(-) diff --git a/tests/data/test_queue.py b/tests/data/test_queue.py index 28952588a..0384ea4a0 100644 --- a/tests/data/test_queue.py +++ b/tests/data/test_queue.py @@ -1,6 +1,8 @@ from torch.utils.data import DataLoader + +import torch +import torchio as tio from torchio.data import UniformSampler -from torchio import SubjectsDataset, Queue, DATA from torchio.utils import create_dummy_dataset from ..utils import TorchioTestCase @@ -18,10 +20,10 @@ def setUp(self): ) def run_queue(self, num_workers, **kwargs): - subjects_dataset = SubjectsDataset(self.subjects_list) + subjects_dataset = tio.SubjectsDataset(self.subjects_list) patch_size = 10 sampler = UniformSampler(patch_size) - queue_dataset = Queue( + queue_dataset = tio.Queue( subjects_dataset, max_length=6, samples_per_volume=2, @@ -31,8 +33,8 @@ def run_queue(self, num_workers, **kwargs): _ = str(queue_dataset) batch_loader = DataLoader(queue_dataset, batch_size=4) for batch in batch_loader: - _ = batch['one_modality'][DATA] - _ = batch['segmentation'][DATA] + _ = batch['one_modality'][tio.DATA] + _ = batch['segmentation'][tio.DATA] def test_queue(self): self.run_queue(num_workers=0) @@ -42,3 +44,23 @@ def test_queue_multiprocessing(self): def test_queue_no_start_background(self): self.run_queue(num_workers=0, start_background=False) + + def test_different_samples_per_volume(self): + image2 = tio.ScalarImage(tensor=2 * torch.ones(1, 1, 1, 1)) + image10 = tio.ScalarImage(tensor=10 * torch.ones(1, 1, 1, 1)) + subject2 = tio.Subject(im=image2, num_samples=2) + subject10 = tio.Subject(im=image10, num_samples=10) + dataset = tio.SubjectsDataset([subject2, subject10]) + patch_size = 1 + sampler = UniformSampler(patch_size) + queue_dataset = tio.Queue( + dataset, + max_length=12, + samples_per_volume=3, # should be ignored + sampler=sampler, + ) + batch_loader = DataLoader(queue_dataset, batch_size=6) + batches = [batch['im'][tio.DATA] for batch in batch_loader] + all_numbers = torch.stack(batches).flatten().tolist() + assert all_numbers.count(10) == 10 + assert all_numbers.count(2) == 2 diff --git a/torchio/constants.py b/torchio/constants.py index 206657b24..faf1b6d6b 100644 --- a/torchio/constants.py +++ b/torchio/constants.py @@ -30,3 +30,6 @@ # Floating point error MIN_FLOAT_32 = torch.finfo(torch.float32).eps + +# For the queue +NUM_SAMPLES = 'num_samples' diff --git a/torchio/data/queue.py b/torchio/data/queue.py index 518e99590..00c931a2d 100644 --- a/torchio/data/queue.py +++ b/torchio/data/queue.py @@ -1,12 +1,11 @@ -import warnings from itertools import islice from typing import List, Iterator, Optional import torch import humanize -from tqdm import trange from torch.utils.data import Dataset, DataLoader +from .. import NUM_SAMPLES from .subject import Subject from .sampler import PatchSampler from .dataset import SubjectsDataset @@ -59,7 +58,9 @@ class Queue(Dataset): max_length: Maximum number of patches that can be stored in the queue. Using a large number means that the queue needs to be filled less often, but more CPU memory is needed to store the patches. - samples_per_volume: Number of patches to extract from each volume. + samples_per_volume: Default number of patches to extract from each + volume. If a subject contains an attribute :attr:`num_samples`, it + will be used instead of :attr:`samples_per_volume`. A small number of patches ensures a large variability in the queue, but training will be slower. sampler: A subclass of :class:`~torchio.data.sampler.PatchSampler` used @@ -206,34 +207,38 @@ def num_patches(self) -> int: @property def iterations_per_epoch(self) -> int: - return self.num_subjects * self.samples_per_volume + total_num_patches = sum( + self._get_subject_num_samples(subject) + for subject in self.subjects_dataset.dry_iter() + ) + return total_num_patches + + def _get_subject_num_samples(self, subject): + num_samples = getattr( + subject, + NUM_SAMPLES, + self.samples_per_volume, + ) + return num_samples def _fill(self) -> None: assert self.sampler is not None - if self.max_length % self.samples_per_volume != 0: - message = ( - f'Queue length ({self.max_length})' - ' not divisible by the number of' - f' patches per volume ({self.samples_per_volume})' - ) - warnings.warn(message, RuntimeWarning) - - # If there are e.g. 4 subjects and 1 sample per volume and max_length - # is 6, we just need to load 4 subjects, not 6 - max_num_subjects_for_queue = self.max_length // self.samples_per_volume - num_subjects_for_queue = min( - self.num_subjects, max_num_subjects_for_queue) - - self._print(f'Filling queue from {num_subjects_for_queue} subjects...') - if self.verbose: - iterable = trange(num_subjects_for_queue, leave=False) - else: - iterable = range(num_subjects_for_queue) - for _ in iterable: + + num_subjects = 0 + while True: subject = self._get_next_subject() iterable = self.sampler(subject) - patches = list(islice(iterable, self.samples_per_volume)) + num_samples = self._get_subject_num_samples(subject) + num_free_slots = self.max_length - len(self.patches_list) + num_samples = min(num_samples, num_free_slots) + patches = list(islice(iterable, num_samples)) self.patches_list.extend(patches) + num_subjects += 1 + list_full = len(self.patches_list) >= self.max_length + all_subjects_sampled = num_subjects >= len(self.subjects_dataset) + if list_full or all_subjects_sampled: + break + if self.shuffle_patches: self._shuffle_patches_list()