forked from constantinpape/minitorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement u-net models; tiktorch wrapper
- Loading branch information
1 parent
a999330
commit 0f579cb
Showing
6 changed files
with
241 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import numpy as np | ||
|
||
|
||
class Compose(object): | ||
def __init__(self, *transforms): | ||
assert all(callable(trafo) for trafo in transforms) | ||
self.transforms = transforms | ||
|
||
def __call__(self, data, target): | ||
for trafo in self.transforms: | ||
data, target = trafo(data, target) | ||
return data, target |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .unet import UNetBase, Unet2d | ||
from .unet import UNetBase | ||
from .unet import UNet2d, UNet2dGN | ||
from .unet import UNet3d, UNet3dGN, AnisotropicUNet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .util import train, validate, main | ||
from .data import crop_tensor, normalize, pad | ||
from .tensorboard import TensorBoard | ||
from .tiktorch import checkpoint_to_tiktorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,67 @@ | ||
# TODO | ||
def checkpoint_to_tiktorch(save_folder, input_shape, minimal_increment): | ||
""" Save checkpoint in tiktorch format | ||
import os | ||
import inspect | ||
from shutil import copyfile | ||
|
||
import yaml | ||
import torch | ||
|
||
|
||
def _to_dynamic_shape(minimal_increment): | ||
if len(minimal_increment) == 2: | ||
dynamic_shape = '(%i * (nH + 1), %i * (nW + 1))' % minimal_increment | ||
elif len(minimal_increment) == 3: | ||
dynamic_shape = '(%i * (nD + 1), %i * (nH + 1), %i * (nW + 1))' % minimal_increment | ||
else: | ||
raise ValueError("Invald length %i for minimal increment" % len(minimal_increment)) | ||
return dynamic_shape | ||
|
||
|
||
def checkpoint_to_tiktorch(model, model_kwargs, | ||
checkpoint_folder, output_folder, | ||
input_shape, minimal_increment, | ||
load_best=True): | ||
""" Save checkpoint in tiktorch format: | ||
TODO link | ||
Arguments: | ||
model: | ||
model_kwargs: | ||
checkpoint_folder: | ||
output_folder: | ||
input_shape: | ||
minimal_increment: | ||
load_best: | ||
""" | ||
pass | ||
os.makedirs(output_folder, exists_ok=True) | ||
|
||
# get the path to code and class name | ||
code_path = inspect.getfile(model) | ||
cls_name = model.__name__ | ||
|
||
# build the model, check the input and get output shape | ||
model = model(**model_kwargs) | ||
weight_path = os.path.join(checkpoint_folder, | ||
'best_weights.torch' if load_best else 'weights.torch') | ||
assert os.path.exists(weight_path), weight_path | ||
model.load_state_dict(torch.load(weight_path)) | ||
|
||
input_ = torch.zeros(*input_shape, dtype=torch.float()) | ||
out = model(input_) | ||
output_shape = tuple(out.shape) | ||
|
||
# build the config | ||
config = {'input_shape': input_shape, | ||
'output_shape': output_shape, | ||
'dynamic_input_shape': _to_dynamic_shape(minimal_increment), | ||
'model_class_name': cls_name, | ||
'model_init_kwargs': model_kwargs, | ||
'torch_version': torch.__version__} | ||
|
||
# serialize config | ||
config_file = os.path.join(checkpoint_folder, 'tiktorch_config.yaml') | ||
with open(config_file, 'w') as f: | ||
yaml.dump(config, f) | ||
|
||
# copy the state-dict and the code path | ||
copyfile(weight_path, os.path.join(output_folder, 'state.nn')) | ||
copyfile(code_path, os.path.join(output_folder, 'model.py')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters