Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
- Fix progress bar updating during bag generation
- Add `show_progress` keyword argument for the feature extractor `tfrecord_inference()` method, for optionally displaying a progress bar.
- Add new `run_batched_inference()` method to pytorch feature extractors, used for generating features from a large tensor dataset of images.
- When generating features (using a feature extractor) from a loaded pytorch tensor of images, add a new `batch_size` argument that runs batch inference, decreasing memory requirements:

```
ctranspath = sf.build_feature_extractor('ctranspath')
images = torch.load(...)
features = ctranspath(images, batch_size=32)
```
  • Loading branch information
jamesdolezal committed Oct 11, 2024
1 parent fec8fc5 commit a2d0edb
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
2 changes: 1 addition & 1 deletion slideflow/io/torch/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def random_jpeg_compression(
"""
q = (torch.rand(1)[0] * q_min) + (q_max - q_min)
img = torchvision.io.encode_jpeg(img, quality=q)
return torchvision.io.decode_image(img)
return torchvision.io.decode_jpeg(img)


def compose_color_distortion(s=1.0):
Expand Down
52 changes: 41 additions & 11 deletions slideflow/model/extractors/_factory_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect
import slideflow as sf
from tqdm import tqdm
from typing import Tuple, Generator, Optional, TYPE_CHECKING
from slideflow import errors

Expand Down Expand Up @@ -129,8 +130,8 @@ class TorchFeatureExtractor(BaseFeatureExtractor):
"""Feature extractor for PyTorch models."""

def __init__(
self,
channels_last: bool = False,
self,
channels_last: bool = False,
mixed_precision: bool = True,
**transform_kwargs
) -> None:
Expand All @@ -153,6 +154,7 @@ def __call__(self, obj, **kwargs):
"""Generate features for a batch of images or a WSI."""
import torch
from slideflow.model.torch import autocast
from slideflow.io.torch import as_cwh

if isinstance(obj, sf.WSI):
# Returns masked array of features
Expand All @@ -166,7 +168,7 @@ def __call__(self, obj, **kwargs):
features.append(batch_features)
locations.append(batch_locations)
return torch.cat(features), torch.cat(locations)
elif kwargs:
elif kwargs and not (len(kwargs) == 1 and 'batch_size' in kwargs):
raise ValueError(
f"{self.__class__.__name__} does not accept keyword arguments "
"when extracting features from a batch of images."
Expand All @@ -180,18 +182,45 @@ def __call__(self, obj, **kwargs):
raise RuntimeError("Expected input to be a uint8 tensor, got: {}".format(
obj.dtype
))
obj = obj.to(self.device)
obj = self.transform(obj)

# Determine batch size
batch_size = kwargs.get('batch_size', None)

with autocast(self.device.type, mixed_precision=self.mixed_precision):
with torch.inference_mode(self.inference_mode):
if self.channels_last:
obj = obj.to(memory_format=torch.channels_last)
return self._process_output(self.model(obj))
if batch_size:
return self.run_batched_inference(obj, batch_size)
else:
obj = self.transform(as_cwh(obj))
obj = obj.to(self.device)
return self._process_output(self.model(obj))

def run_batched_inference(self, obj, batch_size):
"""Run inference on a batch of images."""
import torch
from torch.utils.data import DataLoader
from slideflow.io.torch import as_cwh

# Create a DataLoader
dataset = torch.utils.data.TensorDataset(obj)
dl = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# Run inference
features = []
for batch in dl:
if isinstance(batch, list):
batch = batch[0]
batch = self.transform(as_cwh(batch))
batch = batch.to(self.device)
features.append(self._process_output(self.model(batch)))
return torch.cat(features)

def _verify_transform_args(self, kwargs):
sig = inspect.signature(self.get_transforms)
valid_kwargs = [
p.name for p in sig.parameters.values()
p.name for p in sig.parameters.values()
if (p.kind == p.KEYWORD_ONLY
and p.name != 'img_size')
]
Expand Down Expand Up @@ -257,13 +286,14 @@ def build_transform(self, **kwargs):
from torchvision import transforms
kwargs.update(self.transform_kwargs)
return transforms.Compose(self.get_transforms(**kwargs))


def tfrecord_inference(
self,
tfrecord_path: str,
batch_size: int = 32,
num_workers: int = 2
*,
num_workers: int = 2,
show_progress: bool = False
) -> Generator[Tuple["torch.Tensor", "torch.Tensor"], None, None]:
"""Generate features from a TFRecord file."""
import torch
Expand All @@ -275,7 +305,7 @@ def tfrecord_inference(
batch_size=batch_size,
num_workers=num_workers
)
for batch in tfr_dl:
for batch in tqdm(tfr_dl, desc="Generating...", disable=not show_progress):
features = self(sf.io.torch.whc_to_cwh(batch['image_raw']))
locations = torch.stack([batch['loc_x'], batch['loc_y']], dim=1)
yield features, locations
Expand Down Expand Up @@ -310,7 +340,7 @@ def __init__(self, model_name: str, tile_px: int, device='cuda', **kwargs):
self.model_kw = {k:v for k,v in kwargs.items() if k in _model_kwarg_names}

# Build the imagenet-pretrained model
device = torch_utils.get_device(device)
device = torch_utils.get_device(device)
_hp = ModelParams(tile_px=tile_px, model=model_name, include_top=False, hidden_layers=0)
model = _hp.build_model(num_classes=1, pretrain='imagenet').to(device)
self.ftrs = Features.from_model(model, tile_px=tile_px, **self.model_kw)
Expand Down
2 changes: 1 addition & 1 deletion slideflow/model/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,7 +1748,7 @@ def batch_worker():
model_output = self._calculate_feature_batch(batch_img)
q.put((model_output, batch_slides, (batch_loc_x, batch_loc_y)))
if pb:
pb.advance(task, self.batch_size)
pb.advance(task, batch_img.shape[0])
q.put((None, None, None))
batch_proc_thread.join()
if hasattr(dataset, 'close'):
Expand Down

0 comments on commit a2d0edb

Please sign in to comment.