Implement u-net models; tiktorch wrapper
constantinpape committed Oct 22, 2018
1 parent a999330 commit 0f579cb
Showing 6 changed files with 241 additions and 25 deletions.
12 changes: 12 additions & 0 deletions minitorch/data/transforms.py
@@ -0,0 +1,12 @@
import numpy as np


class Compose(object):
def __init__(self, *transforms):
assert all(callable(trafo) for trafo in transforms)
self.transforms = transforms

def __call__(self, data, target):
for trafo in self.transforms:
data, target = trafo(data, target)
return data, target
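
For reference, a minimal usage sketch of Compose; the two transforms here are hypothetical placeholders, not part of this commit:

import numpy as np

def flip_lr(data, target):
    # hypothetical transform: mirror data and target along the last axis
    return np.flip(data, axis=-1).copy(), np.flip(target, axis=-1).copy()

def to_float(data, target):
    # hypothetical transform: cast the raw data to float32
    return data.astype('float32'), target

trafos = Compose(flip_lr, to_float)
data, target = trafos(np.ones((1, 8, 8)), np.zeros((1, 8, 8)))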
4 changes: 3 additions & 1 deletion minitorch/model/__init__.py
@@ -1 +1,3 @@
from .unet import UNetBase
from .unet import UNet2d, UNet2dGN
from .unet import UNet3d, UNet3dGN, AnisotropicUNet
156 changes: 144 additions & 12 deletions minitorch/model/unet.py
@@ -1,20 +1,26 @@
import torch
import torch.nn as nn
from ..util import crop_tensor


class UNetBase(nn.Module):
""" UNet Base class implementation
Deriving classes must implement
- _conv_block(in_channels, out_channels, level, part)
return conv block for a U-Net level
- _pooler(level)
return pooling operation used for downsampling in-between encoders
- _upsampler(in_channels, out_channels, level)
return upsampling operation used for upsampling in-between decoders
- _out_conv(in_channels, out_channels)
return output conv layer
Arguments:
in_channels: number of input channels
out_channels: number of output channels
initial_features: number of features after first convolution
gain: growth factor of features
depth: depth of the u-net
final_activation: activation applied to the network output
"""

@@ -34,8 +40,6 @@ def __init__(self, in_channels=1, out_channels=1,
# modules of the encoder path
n_features = [in_channels] + [initial_features * gain ** level
for level in range(self.depth)]
self.encoder = nn.ModuleList([self._conv_block(n_features[level], n_features[level + 1],
level, part='encoder')
for level in range(self.depth)])
Expand All @@ -47,27 +51,39 @@ def __init__(self, in_channels=1, out_channels=1,
n_features = [initial_features * gain ** level
for level in range(self.depth + 1)]
n_features = n_features[::-1]
self.decoder = nn.ModuleList([self._conv_block(n_features[level], n_features[level + 1],
self.depth - level - 1, part='decoder')
for level in range(self.depth)])

# the pooling layers
self.poolers = nn.ModuleList([self._pooler(level) for level in range(self.depth)])
# the upsampling layers
self.upsamplers = nn.ModuleList([self._upsampler(n_features[level],
n_features[level + 1],
self.depth - level - 1)
for level in range(self.depth)])
# output conv and activation
# the output conv is not followed by a non-linearity, because we apply
# activation afterwards
self.out_conv = self._out_conv(n_features[-1], out_channels)
self.activation = final_activation

# NOTE we duplicate this from `minitorch.util.data` so that we can provide
# this file as a standalone module
@staticmethod
def _crop_tensor(input_, shape_to_crop):
input_shape = input_.shape
# get the difference between the shapes
shape_diff = tuple((ish - csh) // 2
for ish, csh in zip(input_shape, shape_to_crop))
# calculate the crop
crop = tuple(slice(sd, sh - sd)
for sd, sh in zip(shape_diff, input_shape))
return input_[crop]

# crop the `from_encoder` tensor and concatenate both
def _crop_and_concat(self, from_decoder, from_encoder):
cropped = crop_tensor(from_encoder, from_decoder.shape)
cropped = self._crop_tensor(from_encoder, from_decoder.shape)
return torch.cat((cropped, from_decoder), dim=1)

def forward(self, input):
@@ -96,7 +112,15 @@ def forward(self, input):
return x


#
# 2D U-Net implementations
#


class UNet2d(UNetBase):
""" 2d U-Net for segmentation as described in
https://arxiv.org/abs/1505.04597
"""
# Convolutional block for a single layer of the decoder / encoder
# we apply two 2d convolutions with ReLU activation
def _conv_block(self, in_channels, out_channels, level, part):
@@ -105,7 +129,6 @@ def _conv_block(self, in_channels, out_channels, level, part):
nn.Conv2d(out_channels, out_channels, kernel_size=3),
nn.ReLU())

# upsampling via transposed 2d convolutions
def _upsampler(self, in_channels, out_channels, level):
return nn.ConvTranspose2d(in_channels, out_channels,
@@ -117,3 +140,112 @@ def _pooler(self, level):

def _out_conv(self, in_channels, out_channels):
return nn.Conv2d(in_channels, out_channels, 1)


class UNet2dGN(UNet2d):
""" 2d U-Net with GroupNorm
"""
# Convolutional block for a single layer of the decoder / encoder
# we apply two 2d convolutions with ReLU activation
def _conv_block(self, in_channels, out_channels, level, part):
num_groups1 = min(in_channels, 32)
num_groups2 = min(out_channels, 32)
return nn.Sequential(nn.GroupNorm(num_groups1, in_channels),
nn.Conv2d(in_channels, out_channels, kernel_size=3),
nn.ReLU(),
nn.GroupNorm(num_groups2, out_channels),
nn.Conv2d(out_channels, out_channels, kernel_size=3),
nn.ReLU())


def unet_2d(pretrained=None, **kwargs):
net = UNet2dGN(**kwargs)
if pretrained is not None:
assert pretrained in ('isbi',)
# TODO implement download
return net


#
# 3D U-Net implementations
#


class UNet3d(UNetBase):
""" 3d U-Net for segmentation as described in
https://arxiv.org/abs/1606.06650
"""
# Convolutional block for a single layer of the decoder / encoder
# we apply two 3d convolutions with ReLU activation
def _conv_block(self, in_channels, out_channels, level, part):
return nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=3),
nn.ReLU(),
nn.Conv3d(out_channels, out_channels, kernel_size=3),
nn.ReLU())

# upsampling via transposed 3d convolutions
def _upsampler(self, in_channels, out_channels, level):
return nn.ConvTranspose3d(in_channels, out_channels,
kernel_size=2, stride=2)

# pooling via maxpool3d
def _pooler(self, level):
return nn.MaxPool3d(2)

def _out_conv(self, in_channels, out_channels):
return nn.Conv3d(in_channels, out_channels, 1)


class UNet3dGN(UNet3d):
""" 3d U-Net with GroupNorm
"""
# Convolutional block for a single layer of the decoder / encoder
# we apply two 3d convolutions with ReLU activation
def _conv_block(self, in_channels, out_channels, level, part):
num_groups1 = min(in_channels, 32)
num_groups2 = min(out_channels, 32)
return nn.Sequential(nn.GroupNorm(num_groups1, in_channels),
nn.Conv3d(in_channels, out_channels, kernel_size=3),
nn.ReLU(),
nn.GroupNorm(num_groups2, out_channels),
nn.Conv3d(out_channels, out_channels, kernel_size=3),
nn.ReLU())


class AnisotropicUNet(UNet3dGN):
""" 3D GroupNorm U-Net with anisotropic scaling
Arguments:
scale_factors: per-level scale factors; each entry is an int or a 3-tuple of ints
in_channels: number of input channels
out_channels: number of output channels
initial_features: number of features after first convolution
gain: growth factor of features
final_activation: activation applied to the network output
"""
@staticmethod
def _validate_scale_factors(scale_factors):
assert isinstance(scale_factors, (list, tuple))
for sf in scale_factors:
assert isinstance(sf, (int, tuple))
if not isinstance(sf, int):
assert len(sf) == 3
assert all(isinstance(sff, int) for sff in sf)

def __init__(self, scale_factors, in_channels=1,
out_channels=1, initial_features=64,
gain=2, final_activation=None):
self._validate_scale_factors(scale_factors)
self.scale_factors = scale_factors
super().__init__(in_channels=in_channels, out_channels=out_channels,
initial_features=initial_features, gain=gain,
depth=len(self.scale_factors), final_activation=final_activation)

def _upsampler(self, in_channels, out_channels, level):
scale_factor = self.scale_factors[level]
return nn.ConvTranspose3d(in_channels, out_channels,
kernel_size=scale_factor,
stride=scale_factor)

def _pooler(self, level):
return nn.MaxPool3d(self.scale_factors[level])
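
As a quick smoke test for the new model classes, a sketch along these lines checks the valid-convolution shrinkage; the expected output shape assumes the default depth of 4, as in the original U-Net:

import torch
from minitorch.model import UNet2dGN

net = UNet2dGN(in_channels=1, out_channels=1, initial_features=64)
x = torch.zeros(1, 1, 572, 572)  # classic U-Net input size
with torch.no_grad():
    y = net(x)
print(tuple(y.shape))  # (1, 1, 388, 388) for depth 4 and valid convolutions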
1 change: 1 addition & 0 deletions minitorch/util/__init__.py
@@ -1,3 +1,4 @@
from .util import train, validate, main
from .data import crop_tensor, normalize, pad
from .tensorboard import TensorBoard
from .tiktorch import checkpoint_to_tiktorch
70 changes: 66 additions & 4 deletions minitorch/util/tiktorch.py
@@ -1,5 +1,67 @@
import os
import inspect
from shutil import copyfile

import yaml
import torch


def _to_dynamic_shape(minimal_increment):
if len(minimal_increment) == 2:
dynamic_shape = '(%i * (nH + 1), %i * (nW + 1))' % minimal_increment
elif len(minimal_increment) == 3:
dynamic_shape = '(%i * (nD + 1), %i * (nH + 1), %i * (nW + 1))' % minimal_increment
else:
raise ValueError("Invalid length %i for minimal increment" % len(minimal_increment))
return dynamic_shape


def checkpoint_to_tiktorch(model, model_kwargs,
checkpoint_folder, output_folder,
input_shape, minimal_increment,
load_best=True):
""" Save checkpoint in tiktorch format:
TODO link
Arguments:
model: the model class (not an instance)
model_kwargs: keyword arguments used to construct the model
checkpoint_folder: folder containing the trained weights
output_folder: folder to which the tiktorch model is saved
input_shape: shape of a valid network input (without batch axis)
minimal_increment: minimal increment between valid input shapes
load_best: load the best weights instead of the latest ones (default: True)
"""
os.makedirs(output_folder, exist_ok=True)

# get the path to code and class name
code_path = inspect.getfile(model)
cls_name = model.__name__

# build the model, check the input and get output shape
model = model(**model_kwargs)
weight_path = os.path.join(checkpoint_folder,
'best_weights.torch' if load_best else 'weights.torch')
assert os.path.exists(weight_path), weight_path
model.load_state_dict(torch.load(weight_path))

# add a batch axis for the dummy forward pass and strip it off the output again
input_ = torch.zeros(1, *input_shape, dtype=torch.float32)
out = model(input_)
output_shape = tuple(out.shape)[1:]

# build the config
config = {'input_shape': input_shape,
'output_shape': output_shape,
'dynamic_input_shape': _to_dynamic_shape(minimal_increment),
'model_class_name': cls_name,
'model_init_kwargs': model_kwargs,
'torch_version': torch.__version__}

# serialize config
config_file = os.path.join(output_folder, 'tiktorch_config.yaml')
with open(config_file, 'w') as f:
yaml.dump(config, f)

# copy the state-dict and the code path
copyfile(weight_path, os.path.join(output_folder, 'state.nn'))
copyfile(code_path, os.path.join(output_folder, 'model.py'))
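
To make the dynamic shape format concrete, here is a short sketch of what _to_dynamic_shape produces for 2d and 3d increments:

from minitorch.util.tiktorch import _to_dynamic_shape

print(_to_dynamic_shape((32, 32)))
# (32 * (nH + 1), 32 * (nW + 1))
print(_to_dynamic_shape((8, 32, 32)))
# (8 * (nD + 1), 32 * (nH + 1), 32 * (nW + 1))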
23 changes: 15 additions & 8 deletions pretrain/isbi/pretrain_isbi.py
@@ -1,13 +1,14 @@
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from minitorch.util import main, TensorBoard, checkpoint_to_tiktorch
from minitorch.model import UNet2dGN
from minitorch.data import Isbi2012
from minitorch.criteria import SorensenDice, WeightedLoss


def pretrain_isbi(net, device, data_set, num_workers=0):
# TODO transforms
train_set = Isbi2012(train=True, root=data_set)
val_set = Isbi2012(train=False, root=data_set)
@@ -18,9 +19,6 @@ def pretrain_isbi(device, data_set, num_workers=0):
val_loader = DataLoader(val_set, batch_size=1,
num_workers=num_workers)

loss = WeightedLoss(SorensenDice(use_as_loss=True), crop_target=True)
metric = WeightedLoss(SorensenDice(), crop_target=True)

@@ -36,5 +34,14 @@ def pretrain_isbi(device, data_set, num_workers=0):

if __name__ == '__main__':
data_set = '/home/cpape/Work/data/isbi2012/isbi2012_train_volume.h5'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_kwargs = {'in_channels': 1, 'out_channels': 1,
'initial_features': 64}
net = UNet2dGN(**model_kwargs)
pretrain_isbi(net, device, data_set)

# TODO check that minimal increment is true
checkpoint_to_tiktorch(UNet2dGN, model_kwargs,
'./checkpoints', './ISBI2012_UNet_pretrained',
(1, 572, 572), (32, 32))
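
The minimal-increment TODO could be verified with a quick sketch like the following, which feeds the net two input sizes that differ by the claimed increment; this assumes the default depth of 4, where valid input sizes are 16 pixels apart, so an increment of 32 should hold as well:

with torch.no_grad():
    for size in (572, 572 + 32):
        out = net(torch.zeros(1, 1, size, size))
        print(size, '->', tuple(out.shape))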
