-
Notifications
You must be signed in to change notification settings - Fork 29
/
hsic.py
127 lines (98 loc) · 4.33 KB
/
hsic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""ReBias
Copyright (c) 2020-present NAVER Corp.
MIT license
Python Implementation of the finite sample estimator of Hilbert-Schmidt Independence Criterion (HSIC)
We provide both the biased and the unbiased estimator (the unbiased estimator is used in the paper)
"""
import torch
import torch.nn as nn
def to_numpy(x):
    """Return a NumPy copy of the PyTorch tensor ``x``.

    The tensor is cloned, detached from the autograd graph and moved to the
    CPU before conversion, so the returned array does not share memory with
    ``x``.
    """
    host_copy = x.clone().detach().cpu()
    return host_copy.numpy()
class HSIC(nn.Module):
    r"""Base class for the finite sample estimator of Hilbert-Schmidt Independence Criterion (HSIC)

    .. math:: HSIC (X, Y) := || C_{x, y} ||^2_{HS}, where HSIC (X, Y) = 0 iff X and Y are independent.

    Empirically, we use the finite sample estimator of HSIC (with m observations) by,
    (1) biased estimator (HSIC_0)
        Gretton, Arthur, et al. "Measuring statistical dependence with Hilbert-Schmidt norms." 2005.
        :math: (m - 1)^{-2} tr KHLH.
        where K_{ij} = kernel_x (x_i, x_j), L_{ij} = kernel_y (y_i, y_j),
        H = I - m^{-1} 1 1^\top (Hence, K, L, H are m by m matrices).
    (2) unbiased estimator (HSIC_1)
        Song, Le, et al. "Feature selection via dependence maximization." 2012.
        :math: \frac{1}{m (m - 3)} \bigg[ tr (\tilde K \tilde L) + \frac{1^\top \tilde K 1 1^\top \tilde L 1}{(m-1)(m-2)} - \frac{2}{m-2} 1^\top \tilde K \tilde L 1 \bigg].
        where \tilde K and \tilde L equal K and L except that their diagonal entries are set to zero.

    Subclasses implement ``_kernel_x`` / ``_kernel_y``, each mapping a batch of
    m samples to an (m, m) Gram matrix.

    Parameters
    ----------
    sigma_x : float
        the kernel size of the kernel function for X.
    sigma_y : float, optional
        the kernel size of the kernel function for Y. Defaults to ``sigma_x``.
    algorithm: str ('unbiased' / 'biased')
        the algorithm for the finite sample estimator. 'unbiased' is used for our paper.
    reduction: not used (for compatibility with other losses).

    Raises
    ------
    ValueError
        if ``algorithm`` is neither 'unbiased' nor 'biased'.
    """

    def __init__(self, sigma_x, sigma_y=None, algorithm='unbiased',
                 reduction=None):
        super(HSIC, self).__init__()

        if sigma_y is None:
            sigma_y = sigma_x

        self.sigma_x = sigma_x
        self.sigma_y = sigma_y

        # Bind the selected estimator once so forward() is a plain call.
        if algorithm == 'biased':
            self.estimator = self.biased_estimator
        elif algorithm == 'unbiased':
            self.estimator = self.unbiased_estimator
        else:
            raise ValueError('invalid estimator: {}'.format(algorithm))

    def _kernel_x(self, X):
        """Return the (m, m) Gram matrix of ``X``; must be implemented by subclasses."""
        raise NotImplementedError

    def _kernel_y(self, Y):
        """Return the (m, m) Gram matrix of ``Y``; must be implemented by subclasses."""
        raise NotImplementedError

    def biased_estimator(self, input1, input2):
        """Biased estimator of Hilbert-Schmidt Independence Criterion

        Gretton, Arthur, et al. "Measuring statistical dependence with Hilbert-Schmidt norms." 2005.

        Computes tr(K H L H) / (m - 1)^2 with H = I - 1 1^T / m, where m is the
        number of samples (``len(input1)``).

        Raises
        ------
        ValueError
            if fewer than 2 samples are given (the normalizer (m - 1)^2 would be zero).
        """
        N = len(input1)
        if N < 2:
            raise ValueError(
                'biased HSIC needs at least 2 samples, got {}'.format(N))

        K = self._kernel_x(input1)
        L = self._kernel_y(input2)

        # Subtracting the column means computes H @ K (resp. H @ L);
        # by cyclicity of the trace, tr(HK @ HL) == tr(K H L H) since H @ H == H.
        KH = K - K.mean(0, keepdim=True)
        LH = L - L.mean(0, keepdim=True)

        return torch.trace(KH @ LH / (N - 1) ** 2)

    def unbiased_estimator(self, input1, input2):
        """Unbiased estimator of Hilbert-Schmidt Independence Criterion

        Song, Le, et al. "Feature selection via dependence maximization." 2012.

        The kernel diagonals are zeroed before evaluating the estimator, so
        the result does not depend on k(x_i, x_i) / l(y_i, y_i).

        Raises
        ------
        ValueError
            if fewer than 4 samples are given (the estimator divides by
            m - 2 and m - 3).
        """
        N = len(input1)
        if N < 4:
            raise ValueError(
                'unbiased HSIC needs at least 4 samples, got {}'.format(N))

        kernel_XX = self._kernel_x(input1)
        kernel_YY = self._kernel_y(input2)

        # Zero only the diagonals. torch.diag(M) of a 2-D M is a 1-D vector;
        # subtracting it from M directly would broadcast it across rows
        # (M[i, j] - M[j, j]), so rebuild it into a diagonal matrix first.
        tK = kernel_XX - torch.diag(torch.diag(kernel_XX))
        tL = kernel_YY - torch.diag(torch.diag(kernel_YY))

        hsic = (
            torch.trace(tK @ tL)
            + (torch.sum(tK) * torch.sum(tL) / (N - 1) / (N - 2))
            - (2 * torch.sum(tK, 0).dot(torch.sum(tL, 0)) / (N - 2))
        )

        return hsic / (N * (N - 3))

    def forward(self, input1, input2, **kwargs):
        """Return the HSIC estimate of (input1, input2); extra kwargs are ignored."""
        return self.estimator(input1, input2)
class RbfHSIC(HSIC):
    """Radial Basis Function (RBF) kernel HSIC implementation.

    Both kernels are Gaussian, k(a, b) = exp(-||a - b||^2 / (2 sigma^2)),
    using ``sigma_x`` for the first input and ``sigma_y`` for the second.
    """

    def _kernel(self, X, sigma):
        """Return the (m, m) RBF Gram matrix of the m samples in ``X``.

        Each sample (indexed by the first dimension of ``X``) is flattened
        to a vector before distances are computed.
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted
        # feature maps, are accepted as well; contiguous inputs stay views.
        X = X.reshape(len(X), -1)
        XX = X @ X.t()
        X_sqnorms = torch.diag(XX)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>. Floating-point round-off
        # can make this slightly negative, which would push the kernel above 1.
        X_L2 = -2 * XX + X_sqnorms.unsqueeze(1) + X_sqnorms.unsqueeze(0)
        X_L2 = X_L2.clamp(min=0)
        gamma = 1 / (2 * sigma ** 2)
        kernel_XX = torch.exp(-gamma * X_L2)
        return kernel_XX

    def _kernel_x(self, X):
        """RBF Gram matrix of the first input, bandwidth ``self.sigma_x``."""
        return self._kernel(X, self.sigma_x)

    def _kernel_y(self, Y):
        """RBF Gram matrix of the second input, bandwidth ``self.sigma_y``."""
        return self._kernel(Y, self.sigma_y)
class MinusRbfHSIC(RbfHSIC):
    """Negated RBF-kernel HSIC, for the ``max'' side of the optimization.

    ``forward`` returns the negative of the selected HSIC estimate, so
    minimizing this loss maximizes HSIC. Extra keyword arguments are ignored.
    """

    def forward(self, input1, input2, **kwargs):
        hsic_value = self.estimator(input1, input2)
        return -hsic_value