Skip to content
This repository has been archived by the owner on Jul 21, 2021. It is now read-only.

Commit

Permalink
fix fine-tuning deformable conv
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway committed Apr 5, 2017
1 parent 571b20a commit e25e0af
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 26 deletions.
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# [WIP] PyTorch implementation of Deformable Convolution
* by Wei OUYANG @ Institut Pasteur
# PyTorch implementation of Deformable Convolution
* By Wei OUYANG @ Institut Pasteur
* The original implementation in Keras/TensorFlow: https://github.com/felixlaumon/deform-conv

### TODO List
- [x] implement offsets mapping in pytorch
- [x] all tests passed
- [ ] deformable convolution layers
- [ ] scaled mnist demo (no fine-tuning yet)
- [x] deformable convolution module
- [x] Fine-tuning the deformable convolution modules
- [x] scaled mnist demo
- [x] improve speed with cached grid array
- [ ] use MNIST dataset from pytorch (instead of Keras)
- [ ] support fine-tuning the deformable convolution layers
- [ ] support input image with different width and height
- [ ] benchmark with tensorflow implementation

## Understanding Deformable Convolution
> Dai, Jifeng, Haozhi Qi, Yuwen Xiong, Yi Li, Guodong Zhang, Han Hu, and Yichen
Expand Down
20 changes: 3 additions & 17 deletions scaled_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def test(model, generator, batch_num, epoch):


torch.save(model, 'models/cnn.th')
# 1875/1875 [==============================] - 24s - loss: 0.0090 - acc: 0.9969 - val_loss: 0.0528 - val_acc: 0.9858

# ---
# Evaluate normal CNN
Expand All @@ -109,6 +108,7 @@ def test(model, generator, batch_num, epoch):
# ---
# Deformable CNN

print('Finetune deformable CNN (ConvOffset2D and BatchNorm)')
model = get_deform_cnn(trainable=False)
model = model.cuda()
transfer_weights(model_cnn, model)
Expand All @@ -119,7 +119,6 @@ def test(model, generator, batch_num, epoch):


torch.save(model, 'models/deform_cnn.th')
# 1875/1875 [==============================] - 24s - loss: 0.0090 - acc: 0.9969 - val_loss: 0.0528 - val_acc: 0.9858

# ---
# Evaluate deformable CNN
Expand All @@ -128,19 +127,6 @@ def test(model, generator, batch_num, epoch):
model = torch.load('models/deform_cnn.th')

test(model, test_gen, validation_steps, epoch)
# 99.11%
# xx%
test(model, test_scaled_gen, validation_steps, epoch)
# 63.27%

# TODO: support fine-tuning
# deform_conv_layers = [l for l in model.layers if isinstance(l, ConvOffset2D)]
#
# Xb, Yb = next(test_gen)
# for l in deform_conv_layers:
# print(l)
# _model = Model(inputs=inputs, outputs=l.output)
# offsets = _model.predict(Xb)
# offsets = offsets.reshape(offsets.shape[0], offsets.shape[1], offsets.shape[2], -1, 2)
# print(offsets.min())
# print(offsets.mean())
# print(offsets.max())
# xx%
21 changes: 17 additions & 4 deletions torch_deform_conv/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,32 @@ def forward(self, x):
x = F.softmax(x)
return x

def freeze(self, skip=ConvOffset2D):
def freeze(self, module_classes):
'''
freeze modules for finetuning
'''
for k, m in self._modules.items():
if skip is None or not isinstance(m, skip):
if any([isinstance(m, mc) for mc in module_classes]):
for param in m.parameters():
param.requires_grad = False

def unfreeze(self, module_classes):
'''
unfreeze modules
'''
for k, m in self._modules.items():
if any([isinstance(m, mc) for mc in module_classes]):
for param in m.parameters():
param.requires_grad = True

def parameters(self):
return filter(lambda p: p.requires_grad, super(DeformConvNet, self).parameters())

def get_cnn():
return ConvNet()

def get_deform_cnn(trainable=True):
def get_deform_cnn(trainable=True, freeze_filter=[nn.Conv2d, nn.Linear]):
model = DeformConvNet()
model.freeze(skip=ConvOffset2D)
if not trainable:
model.freeze(freeze_filter)
return model

0 comments on commit e25e0af

Please sign in to comment.