forked from vinija/AI-Coding
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ContrastiveLoss.py
43 lines (31 loc) · 1.4 KB
/
ContrastiveLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
import torch.nn.functional as F
# Contrastive Loss
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for siamese networks.

    For each pair with Euclidean distance ``d`` and label ``y`` the per-pair
    loss is::

        (1 - y) * d**2 + y * max(0, margin - d)**2

    Pairs labelled 0 (similar) are pulled together by their squared distance.
    Pairs labelled 1 (dissimilar) are pushed apart until they are at least
    ``margin`` away, after which they contribute nothing.

    Args:
        margin (float): Distance beyond which dissimilar pairs (label 1) stop
            contributing to the loss.
        reduction (str): How per-pair losses are combined: ``"mean"``
            (default, the original behavior), ``"sum"``, or ``"none"`` to
            return the per-pair losses unreduced.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, margin=1.0, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.margin = margin
        self.reduction = reduction

    def forward(self, output1, output2, label):
        """
        Compute the contrastive loss for a batch of pairs.

        Args:
            output1 (torch.Tensor): Embeddings from one branch of the siamese
                network.
            output2 (torch.Tensor): Embeddings from the other branch, paired
                element-wise with ``output1``.
            label (torch.Tensor): Pair labels: 0 for similar, 1 for dissimilar.
                Bool, integer, or floating values are accepted. If it holds
                exactly one entry per pair (e.g. shape ``(N,)`` or ``(N, 1)``),
                it is reshaped to match the per-pair distances.

        Returns:
            torch.Tensor: A scalar for ``reduction="mean"``/``"sum"``, or the
            per-pair losses for ``reduction="none"``.
        """
        # Euclidean distance between corresponding embeddings, one per pair.
        distance = F.pairwise_distance(output1, output2)

        # `1 - label` raises on bool tensors, so cast labels to the distance's
        # floating dtype (and device) first.
        label = torch.as_tensor(label, dtype=distance.dtype, device=distance.device)
        # An (N, 1) label against (N,) distances would silently broadcast to
        # (N, N) and yield a wrong loss; align shapes when the sizes agree.
        if label.numel() == distance.numel():
            label = label.reshape(distance.shape)

        similar_term = (1 - label) * distance.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - distance, min=0.0).pow(2)
        losses = similar_term + dissimilar_term

        if self.reduction == "mean":
            return losses.mean()
        if self.reduction == "sum":
            return losses.sum()
        return losses