LiarLiar.py
import torch
import torch.nn as nn
from Attention import MultiHeadAttention, PositionFeedforward
from Encoder import Encoder


class arePantsonFire(nn.Module):
    def __init__(self, sentence_encoder: Encoder, explanation_encoder: Encoder, multihead_Attention: MultiHeadAttention,
                 position_Feedforward: PositionFeedforward, hidden_dim: int, max_length_sentence,
                 max_length_justification, input_dim, device='cuda:0'):
        """
        Classifier that encodes a statement and its justification, attends between
        them, and fuses the result with the speaker's credit history.
        Pass device='cpu' to run the model on the CPU instead of the default 'cuda:0'.
        """
        super(arePantsonFire, self).__init__()
        self.device = device

        # Learned positional embeddings for the statement and the justification.
        self.sentence_pos_embedding = nn.Embedding(max_length_sentence, hidden_dim)
        self.justification_pos_embedding = nn.Embedding(max_length_justification, hidden_dim)

        # Sub-modules supplied by the caller.
        self.sentence_encoder = sentence_encoder
        self.explanation_encoder = explanation_encoder
        self.attention = multihead_Attention
        self.position_feedforward = position_Feedforward

        # Convolutional stack and linear/bilinear classification head.
        self.upscale_conv, self.first_conv, self.flatten_conv = self.get_convolutions(input_dim=input_dim, hidden_dim=hidden_dim)
        self.linear1, self.linear2, self.bilinear, self.classifier = self.get_linears_layers(max_length_sentence=max_length_sentence)

    def forward(self, sentence, justification, credit_history):  # 1 Mark
        # Position indices of shape (batch, seq_len), built from the length dimension of each input.
        sentence_pos = torch.arange(0, sentence.shape[2]).unsqueeze(0).repeat(sentence.shape[0], 1).to(self.device).long()
        justification_pos = torch.arange(0, justification.shape[2]).unsqueeze(0).repeat(justification.shape[0], 1).to(self.device).long()

        # Project both inputs from input_dim to hidden_dim channels and add positional embeddings.
        sentence = self.upscale_conv(sentence)
        sentence = sentence + self.sentence_pos_embedding(sentence_pos).permute(0, 2, 1)
        justification = self.upscale_conv(justification)
        justification = justification + self.justification_pos_embedding(justification_pos).permute(0, 2, 1)

        # Encode each sequence, then attend from the statement (query) to the justification (key/value).
        encoded_sentence = self.sentence_encoder(sentence)
        encoded_justification = self.explanation_encoder(justification)
        attention_output = self.attention(encoded_sentence, encoded_justification, encoded_justification)
        attention_output = self.position_feedforward(attention_output)

        # Collapse the hidden dimension down to a single channel, then flatten.
        attention_output = self.first_conv(attention_output)
        attention_output = nn.functional.relu(attention_output)
        attention_output = self.flatten_conv(attention_output)
        attention_output = torch.flatten(attention_output)  # flattens all dims, so this assumes batch size 1

        # Two linear layers, bilinear fusion with the credit-history vector, and the final projection to 6 classes.
        attention_output = self.linear1(attention_output)
        attention_output = self.linear2(attention_output)
        attention_output = self.bilinear(attention_output, torch.squeeze(credit_history))
        attention_output = self.classifier(attention_output)
        return attention_output

    def get_convolutions(self, input_dim, hidden_dim):  # 0.5 Marks
        # 1x1 convolution lifts the input from input_dim to hidden_dim channels.
        upscale_conv = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=1, stride=1)
        # Kernel-3 convolution halves the channel count; padding keeps the sequence length unchanged.
        first_conv = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1)
        # Kernel-5 convolution reduces to a single channel, again preserving the sequence length.
        flatten_conv = nn.Conv1d(in_channels=hidden_dim // 2, out_channels=1, kernel_size=5, stride=1, padding=2)
        return upscale_conv, first_conv, flatten_conv

    def get_linears_layers(self, max_length_sentence):  # 0.5 Marks
        linear1 = nn.Linear(in_features=max_length_sentence, out_features=max_length_sentence // 4)
        linear2 = nn.Linear(in_features=max_length_sentence // 4, out_features=6)
        # Bilinear layer fuses the 6-dim text representation with the 5-dim credit-history vector.
        bilinear = nn.Bilinear(in1_features=6, in2_features=5, out_features=12, bias=True)
        # Final projection to the 6 output classes.
        classifier = nn.Linear(in_features=12, out_features=6)
        return linear1, linear2, bilinear, classifier
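

# Minimal usage sketch (not part of the original module). The constructor calls for
# Encoder, MultiHeadAttention and PositionFeedforward below are assumptions: those
# classes are defined in Encoder.py / Attention.py and their signatures may differ,
# so adjust the arguments to match. Inputs are shaped (batch=1, input_dim, seq_len)
# because forward() flattens across all dimensions; the credit history is a 5-dim vector.
if __name__ == '__main__':
    hidden_dim, input_dim = 64, 32
    max_sentence, max_justification = 40, 60

    # Hypothetical constructor arguments; replace with the real signatures.
    sentence_encoder = Encoder(hidden_dim)
    explanation_encoder = Encoder(hidden_dim)
    attention = MultiHeadAttention(hidden_dim, num_heads=8)
    feedforward = PositionFeedforward(hidden_dim, hidden_dim * 4)

    model = arePantsonFire(sentence_encoder, explanation_encoder, attention, feedforward,
                           hidden_dim, max_sentence, max_justification, input_dim, device='cpu')

    sentence = torch.randn(1, input_dim, max_sentence)
    justification = torch.randn(1, input_dim, max_justification)
    credit_history = torch.randn(1, 5)

    logits = model(sentence, justification, credit_history)
    print(logits.shape)  # expected: a 6-element vector of class scores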