diff --git a/basenet/basenet.py b/basenet/basenet.py
index ee917ee..58a3b79 100644
--- a/basenet/basenet.py
+++ b/basenet/basenet.py
@@ -37,7 +37,12 @@ def _to_device(x, device):
         else:
             return x.to(device)
     else:
-        return x.cuda()
+        if device == 'cuda':
+            return x.cuda()
+        elif device == 'cpu':
+            return x.cpu()
+        else:
+            raise Exception
 
 class Metrics:
     @staticmethod
@@ -72,7 +77,16 @@ def __init__(self, loss_fn=F.cross_entropy, verbose=False):
     
     def to(self, device=None):
         self.device = device
-        super().to(device=device)
+        if TORCH_VERSION_4:
+            super().to(device=device)
+        else:
+            if device == 'cuda':
+                self.cuda()
+            elif device == 'cpu':
+                self.cpu()
+            else:
+                raise Exception
+        return self
     
     # --
     
@@ -252,11 +266,19 @@ def predict(self, dataloaders, mode='val'):
         if self.verbose:
             gen = tqdm(gen, total=len(loader), desc='predict:%s' % mode)
         
+        if hasattr(self, 'reset'):
+            self.reset()
+
         for _, (data, target) in gen:
-            with torch.no_grad():
-                data = _to_device(data, self.device)
-                all_output.append(self(data).cpu())
-                all_target.append(target)
+            if TORCH_VERSION_4:
+                with torch.no_grad():
+                    output = self(_to_device(data, self.device)).cpu()
+            else:
+                data = Variable(data, volatile=True)
+                output = self(_to_device(data, self.device)).cpu()
+
+            all_output.append(output)
+            all_target.append(target)
         
         return torch.cat(all_output), torch.cat(all_target)
     
diff --git a/basenet/text/data.py b/basenet/text/data.py
index bd1c51d..a8b02d6 100644
--- a/basenet/text/data.py
+++ b/basenet/text/data.py
@@ -27,7 +27,7 @@ def text_collate_fn(batch, pad_value=1):
     X, y = zip(*batch)
     
     max_len = max([len(xx) for xx in X])
-    X = [F.pad(xx, pad=(max_len - len(xx), 0), value=pad_value) for xx in X]
+    X = [F.pad(xx, pad=(max_len - len(xx), 0), value=pad_value).data for xx in X]
     X = torch.stack(X, dim=-1)
     y = torch.LongTensor(y)
     
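
Reviewer notes on the basenet.py hunks: `to()` now returns `self`, matching `nn.Module.to`, so construction and device placement chain in one expression, e.g. `model = Net().to(device)`. The pre-0.4 branches dispatch on a plain device string because tensors and modules had no generic `.to()` before PyTorch 0.4; likewise, `torch.no_grad()` is the 0.4+ replacement for `Variable(data, volatile=True)`, which was how autograd was disabled at inference in older releases. The `self.reset()` call presumably clears per-prediction state (e.g. RNN hidden state) on models that define it.

Assembled for reference, a minimal sketch of what `_to_device` looks like after this patch. The `TORCH_VERSION_4` flag is defined elsewhere in basenet.py, and the list-handling branch is inferred from the hunk's context lines, so both are assumptions here rather than verbatim source:

import torch

# Assumption: basenet computes this flag from torch.__version__;
# hard-coded here only to keep the sketch self-contained.
TORCH_VERSION_4 = True

def _to_device(x, device):
    if TORCH_VERSION_4:
        if isinstance(x, (tuple, list)):
            return [xx.to(device) for xx in x]  # move each element
        else:
            return x.to(device)                 # 0.4+ generic .to()
    else:
        # pre-0.4: no .to(), so dispatch on the device string
        if device == 'cuda':
            return x.cuda()
        elif device == 'cpu':
            return x.cpu()
        else:
            raise Exception                     # unrecognized device string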
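
On the text/data.py hunk: the trailing `.data` unwraps the Variable that `F.pad` returns under pre-0.4 PyTorch, so the padded sequences stack as plain tensors (on 0.4+ it is effectively a no-op here). For reference, a runnable sketch of the patched collate behavior: left-pad each sequence to the batch max length, then stack sequence-first into a (max_len, batch_size) LongTensor. The example token ids are illustrative, and the trailing return is assumed since it is not shown in the hunk:

import torch
import torch.nn.functional as F

def text_collate_fn(batch, pad_value=1):
    # batch: list of (variable-length LongTensor, label) pairs
    X, y = zip(*batch)
    max_len = max([len(xx) for xx in X])
    # left-pad every sequence to max_len with pad_value
    X = [F.pad(xx, pad=(max_len - len(xx), 0), value=pad_value) for xx in X]
    X = torch.stack(X, dim=-1)  # -> (max_len, batch_size)
    y = torch.LongTensor(y)
    return X, y                 # return assumed; not shown in the hunk

batch = [
    (torch.LongTensor([5, 6, 7]), 0),  # illustrative token ids
    (torch.LongTensor([8, 9]), 1),
]
X, y = text_collate_fn(batch)
print(X)  # tensor([[5, 1],
          #         [6, 8],
          #         [7, 9]])  (shorter sequence left-padded with 1)
print(y)  # tensor([0, 1])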