grad_check.py
import torch
import torch.nn as nn

import onmt
from onmt.modules.optimized.relative_self_attention_func import RelativeShiftFunction


class TestDecoder(nn.Module):
    """Thin wrapper so that RelativeShiftFunction can be passed to gradcheck as a callable."""

    def __init__(self):
        super().__init__()
        self.function = RelativeShiftFunction.apply

    def forward(self, input):
        return self.function(input)
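

# For reference only: a rough pure-PyTorch sketch of the usual Transformer-XL
# style relative shift, added here as an illustration of the operation being
# grad-checked. The (bsz, len_q, seq_len) layout and the shift direction are
# assumptions; the fused RelativeShiftFunction in this repo may use a different
# convention, so this sketch is not guaranteed to match its output exactly.
def naive_relative_shift(x):
    bsz, len_q, klen = x.size()
    # Prepend one zero column on the key dimension, then reinterpret the memory
    # so that consecutive query rows end up offset by one position relative to
    # each other.
    zero_pad = x.new_zeros(bsz, len_q, 1)
    x_padded = torch.cat([zero_pad, x], dim=2)       # (bsz, len_q, klen + 1)
    x_padded = x_padded.view(bsz, klen + 1, len_q)   # reinterpret the layout
    return x_padded[:, 1:, :].contiguous().view_as(x)  # drop the padding and reshape back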


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='gradcheck for RelativeShiftFunction')
    parser.add_argument('-model_size', type=int, default=16,
                        help='Size of embedding / transformer hidden')
    parser.add_argument('-gpu', default=0, type=int,
                        help='GPU device on which to run the check.')
    parser.add_argument('-test_decoder', action='store_true',
                        help='Test decoder')

    opt = parser.parse_args()

    torch.cuda.set_device(opt.gpu)

    onmt.constants.weight_norm = False
    onmt.constants.checkpointing = False
    onmt.constants.max_position_length = 4096
    onmt.constants.double_precision = True

    opt.layers = 2
    opt.variational_dropout = False
    opt.dropout = 0.0
    opt.attn_dropout = 0.0
    opt.n_heads = 1
    opt.inner_size = 16

    bsz = 3
    seq_len = 5
    len_q = 2

    # A small integer pattern to eyeball the forward pass of the relative shift.
    x = torch.arange(seq_len - 1, -1, -1).unsqueeze(0).unsqueeze(0)
    x = x.expand(bsz, len_q, seq_len)
    print(x)

    # Random double-precision input for the numerical gradient check.
    input_states = torch.randn(bsz, len_q, seq_len).double().cuda()

    net = TestDecoder()
    net = net.double().cuda()
    print(net)

    x = x.double().cuda()
    print(net(x))

    print("start gradchecking ...")
    input_states.requires_grad = True
    # gradcheck expects the inputs as a tuple; (input_states) alone is not one.
    torch.autograd.gradcheck(net, (input_states,))
    print("gradchecking completed.")
    # Placeholder left for a future gradcheck with an additional context input:
    # context.requires_grad = True
    # print("start gradchecking for context...")
    # torch.autograd.gradcheck(net, (input_states, context))
    # print("gradchecking completed.")