Skip to content

Commit

Permalink
add metrics.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dontLoveBugs committed Dec 21, 2018
1 parent 5228d11 commit 84a3708
Show file tree
Hide file tree
Showing 10 changed files with 620 additions and 52 deletions.
125 changes: 125 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
.idea/
/.idea
*.iml
*.ppt
*.pptx
*.caffemodel

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# JPG PNG
*.jpg
*.png
28 changes: 11 additions & 17 deletions DORN_nyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,48 +115,42 @@ def forward(self, x):
decode_label is the ordinal labels for each position of Image I
"""
N, C, H, W = x.size()
ord_num = C // 2

if torch.cuda.is_available():
decode_label = torch.zeros((N, 1, H, W), dtype=torch.float32).cuda()
ord_labels = torch.zeros((N, C // 2, H, W), dtype=torch.float32).cuda()
else:
decode_label = torch.zeros((N, 1, H, W), dtype=torch.float32)
ord_labels = torch.zeros((N, C // 2, H, W), dtype=torch.float32)
# print('#1 decode size:', decode_label.size())
ord_num = C // 2
for i in range(ord_num):
ord_i = x[:, 2 * i:2 * i + 2, :, :]
ord_i = nn.functional.softmax(ord_i, dim=1) # compute P(w, h) in paper
ord_i = ord_i[:, 1, :, :]
ord_labels[:, i, :, :] = ord_i
# print('ord_i >= 0.5 size:', (ord_i >= 0.5).size())
decode_label += (ord_i >= 0.5).view(N, 1, H, W).float() # sum(n(p_k >= 0.5))
# ord_num = C // 2
# for i in range(ord_num):
# ord_i = x[:, 2 * i:2 * i + 2, :, :]
# ord_i = nn.functional.softmax(ord_i, dim=1) # compute P(w, h) in paper
# ord_i = ord_i[:, 1, :, :]
# ord_labels[:, i, :, :] = ord_i
# # print('ord_i >= 0.5 size:', (ord_i >= 0.5).size())
# decode_label += (ord_i >= 0.5).view(N, 1, H, W).float() # sum(n(p_k >= 0.5))

"""
replace iter with matrix operation
fast speed methods
"""
A = x[:, ::2, :, :].clone()
B = x[:, 1::2, :, :].clone()
# print('A size:', A.size())
# print('B size:', B.size())

A = A.view(N, 1, ord_num * H * W)
B = B.view(N, 1, ord_num * H * W)

C = torch.cat((A, B), dim=1)
# C = torch.clamp(C, min = 1e-8, max = 1e8) # prevent nans

ord_c = nn.functional.softmax(C, dim=1)

# print('C size:', C.size())
# print('ord_c size:', ord_c.size())

ord_c1 = ord_c[:, 1, :].clone()
ord_c1 = ord_c1.view(-1, ord_num, H, W)
decode_c = torch.sum(ord_c1, dim=1).view(-1, 1, H, W)
# print('ord_c1 size:', ord_c1.size())
# print('decode_c size:', decode_c.size())

# print('decode_label size:', decode_label.size())
return decode_c, ord_c1


Expand Down
26 changes: 19 additions & 7 deletions NYUDepth_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,20 @@ def main():
model_dict = checkpoint['model'].state_dict()
model = DORN_nyu.DORN()
model.load_state_dict(model_dict)
del model_dict # 删除载入的模型
# 使用SGD进行优化
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

# in paper, aspp module's lr is 20 bigger than the other modules
aspp_params = list(map(id, model.aspp_module.parameters()))
base_params = filter(lambda p: id(p) not in aspp_params, model.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
optimizer = torch.optim.SGD([
{'params': base_params},
{'params': model.aspp_module.parameters(), 'lr': args.lr * 20},
], lr = args.lr, momentum = args.momentum)

print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
del checkpoint # 删除载入的模型
del model_dict
else:
print("=> creating Model")
model = DORN_nyu.DORN()
Expand Down Expand Up @@ -167,13 +176,16 @@ def train(train_loader, model, criterion, optimizer, epoch, logger):

# compute pred
end = time.time()
pred_d, pred_ord = model(input) # @wx 注意输出
with torch.autograd.detect_anomaly():
pred_d, pred_ord = model(input) # @wx 注意输出

loss = criterion(pred_ord, target)
optimizer.zero_grad()
loss.backward() # compute gradient and do SGD step
optimizer.step()

loss = criterion(pred_ord, target)
optimizer.zero_grad()
loss.backward() # compute gradient and do SGD step
optimizer.step()
torch.cuda.synchronize()

gpu_time = time.time() - end

# measure accuracy and record loss
Expand Down
92 changes: 69 additions & 23 deletions criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,30 @@ def forward(self, pred, target):
return self.loss


class ScaleInvariantError(nn.Module):
"""
Scale invariant error defined in Eigen's paper!
"""

def __init__(self, lamada=0.5):
super(ScaleInvariantError, self).__init__()
self.lamada = lamada
return

def forward(self, y_true, y_pred):
first_log = torch.log(torch.clamp(y_pred, min, max))
second_log = torch.log(torch.clamp(y_true, min, max))
d = first_log - second_log
loss = torch.mean(d * d) - self.lamada * torch.mean(d) * torch.mean(d)
return loss


class ordLoss(nn.Module):
"""
Ordinal loss is defined as the average of pixelwise ordinal loss F(h, w, X, O)
over the entire image domain:
"""

def __init__(self):
super(ordLoss, self).__init__()
self.loss = 0.0
Expand All @@ -75,35 +94,62 @@ def forward(self, ord_labels, target):
:return: ordinal loss
"""
# assert pred.dim() == target.dim()
# invalid_mask = target < 0
# target[invalid_mask] = 0

N, C, H, W = ord_labels.size()
ord_num = C
# print('ord_num = ', ord_num)

self.loss = 0.0

for k in range(ord_num):
'''
p^k_(w, h) = e^y(w, h, 2k+1) / [e^(w, h, 2k) + e^(w, h, 2k+1)]
'''
p_k = ord_labels[:, k, :, :]
p_k = p_k.view(N, 1, H, W)

mask_0 = (target <= k).detach()
mask_1 = (target > k).detach()
# print('p_k size:', p_k.size())
# print('mask 0 size:', mask_0.size())
# print('mask 1 size:', mask_1.size())
'''
对每个像素而言,
如果k小于l(w, h), log(p_k)
如果k大于l(w, h), log(1-p_k)
'''

one = torch.ones(p_k[mask_1].size())
if torch.cuda.is_available():
one = one.cuda()
self.loss += torch.sum(torch.log(p_k[mask_0])) + torch.sum(torch.log(one - p_k[mask_1]))
# for k in range(ord_num):
# '''
# p^k_(w, h) = e^y(w, h, 2k+1) / [e^(w, h, 2k) + e^(w, h, 2k+1)]
# '''
# p_k = ord_labels[:, k, :, :]
# p_k = p_k.view(N, 1, H, W)
#
# '''
# 对每个像素而言,
# 如果k小于l(w, h), log(p_k)
# 如果k大于l(w, h), log(1-p_k)
# 希望分类正确的p_k越大越好
# '''
# mask_0 = (target >= k).detach() # 分类正确
# mask_1 = (target < k).detach() # 分类错误
#
# one = torch.ones(p_k[mask_1].size())
# if torch.cuda.is_available():
# one = one.cuda()
# self.loss += torch.sum(torch.log(torch.clamp(p_k[mask_0], min = 1e-7, max = 1e7))) \
# + torch.sum(torch.log(torch.clamp(one - p_k[mask_1], min = 1e-7, max = 1e7)))

# faster version
if torch.cuda.is_available():
K = torch.zeros((N, C, H, W), dtype=torch.float32).cuda()
for i in range(ord_num):
K[:, i, :, :] = K[:, i, :, :] + i * torch.ones((N, H, W), dtype=torch.float32).cuda()
else:
K = torch.zeros((N, C, H, W), dtype=torch.float32)
for i in range(ord_num):
K[:, i, :, :] = K[:, i, :, :] + i * torch.ones((N, H, W), dtype=torch.float32)

mask_0 = (K <= target).detach()
mask_1 = (K > target).detach()

one = torch.ones(ord_labels[mask_1].size())
if torch.cuda.is_available():
one = one.cuda()

self.loss += torch.sum(torch.log(torch.clamp(ord_labels[mask_0], min=1e-7, max=1e7))) \
+ torch.sum(torch.log(torch.clamp(one - ord_labels[mask_1], min=1e-7, max=1e7)))

# del K
# del one
# del mask_0
# del mask_1

N = N * H * W
self.loss /= N
self.loss /= (-N) # negative
return self.loss
4 changes: 4 additions & 0 deletions dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Time : 2018/10/21 20:43
# @Author : Wang Xin
# @Email : wangxin_buaa@163.com
Loading

0 comments on commit 84a3708

Please sign in to comment.