Skip to content

Commit

Permalink
add fast model version
Browse files Browse the repository at this point in the history
  • Loading branch information
dontLoveBugs committed Nov 27, 2018
1 parent 597fe6c commit c57c55d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
34 changes: 31 additions & 3 deletions DORN_nyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# @Time : 2018/11/22 12:33
# @Author : Wang Xin
# @Email : wangxin_buaa@163.com
import os

import torch
import torch.nn as nn
Expand Down Expand Up @@ -108,11 +109,11 @@ def __init__(self):
super(OrdinalRegressionLayer, self).__init__()

def forward(self, x):
'''
"""
:param x: N X H X W X C, N is batch_size, C is channels of features
:return: ord_labels is ordinal outputs for each spatial locations , size is N x H X W X C (C = 2K, K is interval of SID)
decode_label is the ordinal labels for each position of Image I
'''
"""
N, C, H, W = x.size()
if torch.cuda.is_available():
decode_label = torch.zeros((N, 1, H, W), dtype=torch.float32).cuda()
Expand All @@ -130,8 +131,33 @@ def forward(self, x):
# print('ord_i >= 0.5 size:', (ord_i >= 0.5).size())
decode_label += (ord_i >= 0.5).view(N, 1, H, W).float() # sum(n(p_k >= 0.5))

"""
replace iter with matrix operation
fast speed methods
"""
A = x[:, ::2, :, :].clone()
B = x[:, 1::2, :, :].clone()
# print('A size:', A.size())
# print('B size:', B.size())

A = A.view(N, 1, ord_num * H * W)
B = B.view(N, 1, ord_num * H * W)

C = torch.cat((A, B), dim=1)

ord_c = nn.functional.softmax(C, dim=1)

# print('C size:', C.size())
# print('ord_c size:', ord_c.size())

ord_c1 = ord_c[:, 1, :].clone()
ord_c1 = ord_c1.view(-1, ord_num, H, W)
decode_c = torch.sum(ord_c1, dim=1).view(-1, 1, H, W)
# print('ord_c1 size:', ord_c1.size())
# print('decode_c size:', decode_c.size())

# print('decode_label size:', decode_label.size())
return decode_label, ord_labels
return decode_c, ord_c1


class ResNet(nn.Module):
Expand Down Expand Up @@ -228,6 +254,8 @@ def forward(self, x):
return depth_labels, ord_labels


# os.environ["CUDA_VISIBLE_DEVICES"] = "1" # 默认使用GPU 0

if __name__ == "__main__":
model = DORN()
model = model.cuda()
Expand Down
1 change: 1 addition & 0 deletions NYUDepth_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def main():
model_dict = checkpoint['model'].state_dict()
model = DORN_nyu.DORN()
model.load_state_dict(model_dict)
del model_dict # 删除载入的模型
# 使用SGD进行优化
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

Expand Down

0 comments on commit c57c55d

Please sign in to comment.