Skip to content

Commit

Permalink
memory usage in basenet.text
Browse the repository at this point in the history
  • Loading branch information
bkj committed Jun 4, 2018
1 parent c974d66 commit 1fa5d08
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
34 changes: 28 additions & 6 deletions basenet/basenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def _to_device(x, device):
else:
return x.to(device)
else:
return x.cuda()
if device == 'cuda':
return x.cuda()
elif device == 'cpu':
return x.cpu()
else:
raise Exception

class Metrics:
@staticmethod
Expand Down Expand Up @@ -72,7 +77,16 @@ def __init__(self, loss_fn=F.cross_entropy, verbose=False):

def to(self, device=None):
    """Move this model's parameters/buffers to `device` and remember it.

    Works on torch >= 0.4 via `nn.Module.to`, and falls back to
    `.cuda()` / `.cpu()` on older torch releases that lack `Module.to`.

    Args:
        device: 'cuda', 'cpu', or (on torch >= 0.4) anything
            `nn.Module.to` accepts.

    Returns:
        self, so calls can be chained.

    Raises:
        Exception: on pre-0.4 torch when `device` is neither
            'cuda' nor 'cpu'.
    """
    # Remember the target device so helpers (e.g. _to_device on input
    # batches) can move data to the same place as the model.
    self.device = device

    if TORCH_VERSION_4:
        # torch >= 0.4: delegate to the native Module.to
        super().to(device=device)
    else:
        # Older torch only exposes cuda()/cpu()
        if device == 'cuda':
            self.cuda()
        elif device == 'cpu':
            self.cpu()
        else:
            raise Exception('unsupported device: %r' % (device,))

    return self

# --
Expand Down Expand Up @@ -252,11 +266,19 @@ def predict(self, dataloaders, mode='val'):
if self.verbose:
gen = tqdm(gen, total=len(loader), desc='predict:%s' % mode)

if hasattr(self, 'reset'):
self.reset()

for _, (data, target) in gen:
with torch.no_grad():
data = _to_device(data, self.device)
all_output.append(self(data).cpu())
all_target.append(target)
if TORCH_VERSION_4:
with torch.no_grad():
output = self(_to_device(data, self.device)).cpu()
else:
data = Variable(data, volatile=True)
output = self(_to_device(data, self.device)).cpu()

all_output.append(output)
all_target.append(target)

return torch.cat(all_output), torch.cat(all_target)

Expand Down
2 changes: 1 addition & 1 deletion basenet/text/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def text_collate_fn(batch, pad_value=1):
X, y = zip(*batch)

max_len = max([len(xx) for xx in X])
X = [F.pad(xx, pad=(max_len - len(xx), 0), value=pad_value) for xx in X]
X = [F.pad(xx, pad=(max_len - len(xx), 0), value=pad_value).data for xx in X]

X = torch.stack(X, dim=-1)
y = torch.LongTensor(y)
Expand Down

0 comments on commit 1fa5d08

Please sign in to comment.