Skip to content

Commit

Permalink
Fix the old flatten method which use the size(0) to caculate the batc…
Browse files Browse the repository at this point in the history
…h size, the old method will intruduce Gather opertion in the onnx output, which will faild parsed by tensorRT 5.0 (#1134)
  • Loading branch information
apache2046 authored and fmassa committed Jul 19, 2019
1 parent bbd363c commit 2cae950
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 9 deletions.
3 changes: 2 additions & 1 deletion torchvision/models/alexnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

Expand Down Expand Up @@ -43,7 +44,7 @@ def __init__(self, num_classes=1000):
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x

Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
def forward(self, x):
features = self.features(x)
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)
out = self.classifier(out)
return out

Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def forward(self, x):

x = self.avgpool(x)
# N x 1024 x 1 x 1
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
# N x 1024
x = self.dropout(x)
x = self.fc(x)
Expand Down Expand Up @@ -208,7 +208,7 @@ def forward(self, x):
# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
x = self.conv(x)
# N x 128 x 4 x 4
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
# N x 2048
x = F.relu(self.fc1(x), inplace=True)
# N x 2048
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def forward(self, x):
# N x 2048 x 1 x 1
x = F.dropout(x, training=self.training)
# N x 2048 x 1 x 1
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
# N x 2048
x = self.fc(x)
# N x 1000 (num_classes)
Expand Down Expand Up @@ -334,7 +334,7 @@ def forward(self, x):
# Adaptive average pooling
x = F.adaptive_avg_pool2d(x, (1, 1))
# N x 768 x 1 x 1
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
# N x 768
x = self.fc(x)
# N x 1000
Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

Expand Down Expand Up @@ -203,7 +204,7 @@ def forward(self, x):
x = self.layer4(x)

x = self.avgpool(x)
x = x.reshape(x.size(0), -1)
x = torch.flatten(x, 1)
x = self.fc(x)

return x
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self, version='1_0', num_classes=1000):
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x.view(x.size(0), -1)
return torch.flatten(x, 1)


def _squeezenet(version, pretrained, progress, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/vgg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

Expand Down Expand Up @@ -41,7 +42,7 @@ def __init__(self, features, num_classes=1000, init_weights=True):
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x

Expand Down

0 comments on commit 2cae950

Please sign in to comment.