-
Notifications
You must be signed in to change notification settings - Fork 80
/
AttentionalClassify.py
43 lines (35 loc) · 1.46 KB
/
AttentionalClassify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Albert Berenguel
## Computer Vision Center (CVC). Universitat Autonoma de Barcelona
## Email: aberenguel@cvc.uab.es
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch.nn as nn
import unittest
class AttentionalClassify(nn.Module):
    """Attention-based classifier over a labelled support set.

    Similarity scores between each target image and the support set images are
    turned into attention weights with a softmax over the support set axis, and
    those weights are used to average the support set label vectors.
    """

    def __init__(self):
        super(AttentionalClassify, self).__init__()

    def forward(self, similarities, support_set_y):
        """
        Produces pdfs over the support set classes for the target set image.
        :param similarities: A tensor of similarity scores (e.g. cosine
            similarities) of size [batch_size, sequence_length], where
            sequence_length is the number of support set images.
        :param support_set_y: A tensor with the one hot vectors of the targets
            for each support set image, of size
            [batch_size, sequence_length, num_classes].
        :return: A tensor of size [batch_size, num_classes] holding the
            attention-weighted sum of the support set label vectors (a pdf over
            the classes when the labels are one hot).
        """
        # Normalise over the support-set axis. dim=1 is what PyTorch's
        # deprecated implicit-dim softmax selected for 2-D input; stating it
        # explicitly removes the deprecation warning and the ambiguity.
        softmax_similarities = nn.functional.softmax(similarities, dim=1)
        # [B, 1, S] x [B, S, C] -> [B, 1, C]. Squeeze only the singleton axis
        # added by unsqueeze(1), so batch_size == 1 or num_classes == 1 keep
        # their dimension instead of being silently dropped.
        preds = softmax_similarities.unsqueeze(1).bmm(support_set_y).squeeze(1)
        return preds
class AttentionalClassifyTest(unittest.TestCase):
    """Unit tests for AttentionalClassify."""

    def setUp(self):
        # A fresh module per test; it has no parameters, so this is cheap.
        self.model = AttentionalClassify()

    def test_forward(self):
        """Uniform similarities give the mean of the support set label vectors."""
        import torch  # the module header only imports torch.nn

        # batch_size=2, sequence_length=3, num_classes=2
        similarities = torch.zeros(2, 3)
        one_batch_labels = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
        support_set_y = one_batch_labels.unsqueeze(0).repeat(2, 1, 1)

        preds = self.model(similarities, support_set_y)

        self.assertEqual(tuple(preds.shape), (2, 2))
        expected = torch.tensor([[1.0 / 3.0, 2.0 / 3.0]] * 2)
        self.assertTrue(torch.allclose(preds, expected, atol=1e-6))
        # With one hot labels each row is a probability distribution.
        self.assertTrue(torch.allclose(preds.sum(dim=1), torch.ones(2), atol=1e-6))
if __name__ == "__main__":
    # Run this file's unit tests when it is executed directly as a script.
    unittest.main()