-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathmodel.py
93 lines (83 loc) · 4.2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import functional as F
class EncoderCNN(nn.Module):
    """Image encoder: a pretrained ResNet-50 feature extractor plus a projection.

    The backbone's last child (its ``fc`` classification layer) is dropped.
    The remaining backbone output is flattened, projected to ``embed_size`` by
    ``self.embed`` and normalised by ``self.bn``.
    """

    def __init__(self, embed_size):
        """Load the pretrained ResNet-50 and replace its top fc layer.

        Args:
            embed_size: Dimensionality of the feature vectors returned by
                ``forward``.
        """
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Keep every child module except the final one (the fc classifier).
        *feature_layers, _classifier = backbone.children()
        self.resnet = nn.Sequential(*feature_layers)
        self.embed = nn.Linear(backbone.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Extract feature vectors from input images.

        Args:
            images: Image batch for the ResNet backbone. Presumably
                (batch, 3, H, W), normalised as torchvision's pretrained models
                expect; TODO confirm against the data loader.

        Returns:
            Tensor of shape (batch, embed_size).
        """
        # The backbone is a fixed feature extractor: no autograd graph is
        # recorded through it, so only ``embed`` and ``bn`` receive gradients.
        # NOTE(review): no_grad does not stop the backbone's BatchNorm layers
        # from updating their running statistics while in train mode.
        with torch.no_grad():
            pooled = self.resnet(images)
        batch_size = pooled.size(0)
        flat = pooled.view(batch_size, -1)
        return self.bn(self.embed(flat))


class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed to the LSTM as the first time step, followed by
    embedded caption tokens. A linear layer maps each hidden state to
    vocabulary logits.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        """Set the hyper-parameters and build the layers.

        Args:
            embed_size: Size of the word embeddings. Image features given to
                ``forward`` / ``sample*`` must have this size too, because they
                are fed to the same LSTM.
            hidden_size: Number of LSTM hidden units.
            vocab_size: Number of tokens in the vocabulary (size of the logits).
            num_layers: Number of stacked LSTM layers.
        """
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Decode image feature vectors into per-step vocabulary logits.

        Uses teacher forcing. Step 0 receives the image features, and step
        t > 0 receives the embedding of ``captions[:, t - 1]``.

        Args:
            features: (batch, embed_size) image features.
            captions: (batch, seq_len) token ids.

        Returns:
            (batch, seq_len, vocab_size) logits. Position t has seen the image
            and ``captions[:, :t]``.
        """
        # The last caption token is never fed as input. The output length is
        # therefore seq_len: one step for the image plus seq_len - 1 tokens.
        embeddings = self.embed(captions[:, :-1])
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        return self.linear(hiddens)

    def sample(self, inputs, states=None, max_len=20, end_idx=None):
        """Greedily decode a caption for one pre-processed image.

        Args:
            inputs: Image features shaped (1, 1, embed_size): a batch of one
                (required by ``.item()`` below) with a single time step.
            states: Optional initial LSTM ``(h, c)`` state.
            max_len: Maximum number of tokens to generate.
            end_idx: Optional end-of-sentence token id. When given, decoding
                stops right after that token is produced, and the token is
                included in the result. ``None`` always yields exactly
                ``max_len`` tokens.

        Returns:
            List of Python ``int`` token ids.
        """
        sampled_ids = []
        # Inference only, so no autograd graph is needed across the LSTM steps.
        # The result is plain ints, so this only saves memory.
        with torch.no_grad():
            for _ in range(max_len):
                hiddens, states = self.lstm(inputs, states)
                outputs = self.linear(hiddens.squeeze(1))
                # Vocabulary index of the most likely next word.
                predicted = outputs.argmax(1)
                token = predicted.item()
                sampled_ids.append(token)
                if end_idx is not None and token == end_idx:
                    break
                # (1,) -> (1, embed_size) -> (1, 1, embed_size): next step's input.
                inputs = self.embed(predicted).unsqueeze(1)
        return sampled_ids

    def sample_beam_search(self, inputs, states=None, max_len=20, beam_width=5,
                           end_idx=None):
        """Decode the most likely captions for one image with beam search.

        Hypotheses are ranked by the sum of their token log-probabilities
        (no length normalisation).

        Args:
            inputs: Image features shaped (1, 1, embed_size), as for ``sample``.
            states: Optional initial LSTM ``(h, c)`` state.
            max_len: Maximum number of decoding steps.
            beam_width: Number of hypotheses kept after every step. Each
                hypothesis is expanded by its top ``beam_width`` tokens, capped
                at the vocabulary size.
            end_idx: Optional end-of-sentence token id. A hypothesis that has
                produced it is not extended further but keeps competing with
                its final score. Decoding stops early once every kept
                hypothesis has ended. ``None`` disables this.

        Returns:
            Up to ``beam_width`` lists of Python ``int`` token ids, best first.
        """
        # Each beam entry: [token ids, cumulative log prob, next LSTM input, LSTM state].
        beams = [[[], 0.0, inputs, states]]
        with torch.no_grad():
            for _ in range(max_len):
                candidates = []
                for seq, score, step_input, step_states in beams:
                    if end_idx is not None and seq and seq[-1] == end_idx:
                        # Finished hypothesis: carry it over unchanged.
                        candidates.append([seq, score, step_input, step_states])
                        continue
                    hiddens, next_states = self.lstm(step_input, step_states)
                    outputs = self.linear(hiddens.squeeze(1))
                    # Sum log probabilities rather than multiplying probabilities,
                    # to avoid floating-point underflow.
                    log_probs = F.log_softmax(outputs, -1)
                    # topk raises if k exceeds the vocabulary size.
                    k = min(beam_width, log_probs.size(-1))
                    top_log_probs, top_idx = log_probs.topk(k, 1)
                    # Embed all k candidate tokens at once: (k,) -> (k, 1, embed_size).
                    next_inputs = self.embed(top_idx[0]).unsqueeze(1)
                    for i in range(k):
                        candidates.append([
                            seq + [top_idx[0, i].item()],
                            score + top_log_probs[0, i].item(),
                            next_inputs[i:i + 1],
                            next_states,
                        ])
                # Keep the best hypotheses by total log probability. The sort is
                # stable, so tied candidates keep their generation order.
                candidates.sort(key=lambda cand: cand[1], reverse=True)
                beams = candidates[:beam_width]
                if end_idx is not None and all(
                        seq and seq[-1] == end_idx for seq, *_ in beams):
                    break
        return [seq for seq, *_ in beams]