[CherryPick] Fixes for distribution validation checks #53763

Merged
merged 4 commits on Mar 15, 2021
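For context: `validate_args` controls whether a distribution checks its constructor parameters and its `log_prob` arguments. With validation enabled by default in the 1.8 release line, flags that failed to propagate into internally constructed base distributions became user-visible bugs, which is what this cherry-pick addresses. A minimal sketch of the toggle itself (an illustration, not code from this PR):

```python
import torch
from torch.distributions import Normal

# Validation on: an invalid parameter fails fast at construction.
try:
    Normal(loc=torch.tensor(0.0), scale=torch.tensor(-1.0), validate_args=True)
except ValueError as e:
    print("rejected:", e)

# Validation off: construction succeeds and the error surfaces later
# as NaN (log of a negative scale inside log_prob).
d = Normal(loc=torch.tensor(0.0), scale=torch.tensor(-1.0), validate_args=False)
print(d.log_prob(torch.tensor(0.0)))  # tensor(nan)
```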
64 changes: 56 additions & 8 deletions test/distributions/test_distributions.py
@@ -588,6 +588,20 @@ def is_all_nan(tensor):
         {'scale': torch.tensor([0., 1.], requires_grad=True)},
         {'scale': torch.tensor([1., -1.], requires_grad=True)},
     ]),
+    Example(LKJCholesky, [
+        {
+            'dim': -2,
+            'concentration': 0.1
+        },
+        {
+            'dim': 1,
+            'concentration': 2.,
+        },
+        {
+            'dim': 2,
+            'concentration': 0.,
+        },
+    ]),
     Example(Laplace, [
         {
             'loc': torch.tensor([1., 1.], requires_grad=True),
@@ -1376,7 +1390,7 @@ def test_relaxed_one_hot_categorical_1d(self):
         self.assertFalse(RelaxedOneHotCategorical(probs=p, temperature=temp).sample().requires_grad)
         self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((2, 2)).size(), (2, 2, 3))
         self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(), (1, 3))
-        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp, p))
+        self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p))
 
     def test_relaxed_one_hot_categorical_2d(self):
         probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
@@ -1390,8 +1404,8 @@ def test_relaxed_one_hot_categorical_2d(self):
         self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
         self.assertEqual(RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
         self.assertEqual(RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3))
-        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp, p))
-        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp_2, p))
+        self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p))
+        self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp_2, p))
 
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_argmax_relaxed_categorical(self):
@@ -1627,10 +1641,11 @@ def test_lognormal_sample(self):
                                     'LogNormal(loc={}, scale={})'.format(mean, std))
 
     def test_logisticnormal(self):
+        set_rng_seed(1)  # see Note [Randomized statistical tests]
         mean = torch.randn(5, 5).requires_grad_()
         std = torch.randn(5, 5).abs().requires_grad_()
         mean_1d = torch.randn(1).requires_grad_()
-        std_1d = torch.randn(1).requires_grad_()
+        std_1d = torch.randn(1).abs().requires_grad_()
         mean_delta = torch.tensor([1.0, 0.0])
         std_delta = torch.tensor([1e-5, 1e-5])
         self.assertEqual(LogisticNormal(mean, std).sample().size(), (5, 6))
@@ -1648,9 +1663,11 @@ def test_logisticnormal(self):
                                       1. / (1. + 1. + math.exp(1))]),
                         atol=1e-4, rtol=0)
 
-        self._gradcheck_log_prob(LogisticNormal, (mean, std))
-        self._gradcheck_log_prob(LogisticNormal, (mean, 1.0))
-        self._gradcheck_log_prob(LogisticNormal, (0.0, std))
+        # TODO: gradcheck seems to mutate the sample values so that the simplex
+        # constraint fails by a very small margin.
+        self._gradcheck_log_prob(lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, std))
+        self._gradcheck_log_prob(lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, 1.0))
+        self._gradcheck_log_prob(lambda m, s: LogisticNormal(m, s, validate_args=False), (0.0, std))
 
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     def test_logisticnormal_logprob(self):
@@ -2578,7 +2595,7 @@ def tril_cholesky_to_tril_corr(x):
 
         for dim in range(2, 5):
             log_probs = []
-            lkj = LKJCholesky(dim, concentration=1.)
+            lkj = LKJCholesky(dim, concentration=1., validate_args=True)
             for i in range(2):
                 sample = lkj.sample()
                 sample_tril = tril_matrix_to_vec(sample, diag=-1)
@@ -2591,6 +2608,8 @@ def tril_cholesky_to_tril_corr(x):
             # for dim=2, pdf = 0.5 (jacobian adjustment factor is 0.)
             self.assertTrue(all([x == torch.tensor(0.5).log() for x in log_probs]))
             self.assertEqual(log_probs[0], log_probs[1])
+            invalid_sample = torch.cat([sample, sample.new_ones(1, dim)], dim=0)
+            self.assertRaises(ValueError, lambda: lkj.log_prob(invalid_sample))
 
     def test_independent_shape(self):
         for Dist, params in EXAMPLES:
@@ -4498,6 +4517,35 @@ def test_valid(self):
             for param in params:
                 Dist(validate_args=True, **param)
 
+    def test_invalid_log_probs_arg(self):
+        # Check that validation errors are indeed disabled,
+        # but they might raise another error
+        for Dist, params in EXAMPLES:
+            if Dist == TransformedDistribution:
+                # TransformedDistribution has a distribution instance
+                # as the argument, so we cannot do much about that
+                continue
+            for param in params:
+                d_nonval = Dist(validate_args=False, **param)
+                d_val = Dist(validate_args=True, **param)
+                for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
+                    # samples with incorrect shape must throw ValueError only
+                    try:
+                        log_prob = d_val.log_prob(v)
+                    except ValueError:
+                        pass
+                    # get sample of correct shape
+                    val = torch.full(d_val.batch_shape + d_val.event_shape, v)
+                    # check samples with incorrect support
+                    try:
+                        log_prob = d_val.log_prob(val)
+                    except ValueError as e:
+                        if e.args and 'must be within the support' in e.args[0]:
+                            try:
+                                log_prob = d_nonval.log_prob(val)
+                            except RuntimeError:
+                                pass
+
     @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
     def test_invalid(self):
         for Dist, params in BAD_EXAMPLES:
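The lambda wrappers added throughout these tests follow one pattern: `gradcheck` perturbs each differentiable input, including the sample itself, by a finite eps, and a perturbed simplex-constrained sample no longer sums exactly to 1 (the TODO above). A standalone sketch of that pattern, simplified from the suite's `_gradcheck_log_prob` helper:

```python
import torch
from torch.autograd import gradcheck
from torch.distributions import LogisticNormal

loc = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
scale = torch.randn(2, 2, dtype=torch.double).abs().requires_grad_()

# Draw a sample once, then treat it as a differentiable input.
sample = LogisticNormal(loc, scale, validate_args=False).sample()
sample.requires_grad_()

# gradcheck nudges `sample` slightly off the simplex, so validation
# inside log_prob must stay disabled for the check to run.
gradcheck(
    lambda s, m, sc: LogisticNormal(m, sc, validate_args=False).log_prob(s),
    (sample, loc, scale),
    raise_exception=True,
)
```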
1 change: 1 addition & 0 deletions torch/distributions/__init__.py
@@ -135,6 +135,7 @@
     'HalfNormal',
     'Independent',
     'Kumaraswamy',
+    'LKJCholesky',
     'Laplace',
     'LogNormal',
     'LogisticNormal',
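The only change here is exporting the newly added `LKJCholesky` from the package root. A quick usage sketch:

```python
import torch
from torch.distributions import LKJCholesky

# Sample the Cholesky factor L of a random 3x3 correlation matrix.
lkj = LKJCholesky(dim=3, concentration=1.0)
L = lkj.sample()

corr = L @ L.T
print(torch.diagonal(corr))  # all ~1: rows of L have unit norm
```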
2 changes: 1 addition & 1 deletion torch/distributions/beta.py
@@ -33,7 +33,7 @@ def __init__(self, concentration1, concentration0, validate_args=None):
         else:
             concentration1, concentration0 = broadcast_all(concentration1, concentration0)
             concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
-        self._dirichlet = Dirichlet(concentration1_concentration0)
+        self._dirichlet = Dirichlet(concentration1_concentration0, validate_args=validate_args)
         super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
 
     def expand(self, batch_shape, _instance=None):
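This one-line change matters because, before it, the internal `Dirichlet` fell back to the global validation default rather than the flag given to `Beta`, so opting out of validation on the wrapper did not actually opt out. The same one-line pattern recurs in the remaining files below. A sketch of the fixed behavior (assuming this patched build):

```python
import torch
from torch.distributions import Beta, Distribution

Distribution.set_default_validate_args(True)

# Before the fix, the internal Dirichlet validated the stacked
# concentrations under the global default and raised ValueError
# despite validate_args=False on the wrapper; now the flag propagates.
b = Beta(torch.tensor(-1.0), torch.tensor(2.0), validate_args=False)
print(b.concentration1)  # tensor(-1.) -- accepted, unvalidated

Distribution.set_default_validate_args(False)  # restore
```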
3 changes: 2 additions & 1 deletion torch/distributions/kumaraswamy.py
@@ -39,7 +39,8 @@ def __init__(self, concentration1, concentration0, validate_args=None):
         self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
         finfo = torch.finfo(self.concentration0.dtype)
         base_dist = Uniform(torch.full_like(self.concentration0, 0),
-                            torch.full_like(self.concentration0, 1))
+                            torch.full_like(self.concentration0, 1),
+                            validate_args=validate_args)
         transforms = [PowerTransform(exponent=self.concentration0.reciprocal()),
                       AffineTransform(loc=1., scale=-1.),
                       PowerTransform(exponent=self.concentration1.reciprocal())]
2 changes: 2 additions & 0 deletions torch/distributions/lkj_cholesky.py
@@ -109,6 +109,8 @@ def log_prob(self, value):
         # So the probability of a Cholesky factor is propotional to
         # prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
         # with order_i = 2 * concentration - 2 + D - i
+        if self._validate_args:
+            self._validate_sample(value)
         diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
         order = torch.arange(2, self.dim + 1)
         order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
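With `_validate_sample` wired into `LKJCholesky.log_prob`, a value outside the correlation-Cholesky support is rejected up front instead of yielding a meaningless density (this is what the new `invalid_sample` test above exercises). A sketch:

```python
import torch
from torch.distributions import LKJCholesky

lkj = LKJCholesky(dim=3, concentration=1.0, validate_args=True)

# Not a valid Cholesky factor of a correlation matrix: not lower
# triangular, and its rows do not have unit norm.
bad = torch.ones(3, 3)
try:
    lkj.log_prob(bad)
except ValueError as e:
    print("rejected:", e)
```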
2 changes: 1 addition & 1 deletion torch/distributions/log_normal.py
@@ -27,7 +27,7 @@ class LogNormal(TransformedDistribution):
     has_rsample = True
 
     def __init__(self, loc, scale, validate_args=None):
-        base_dist = Normal(loc, scale)
+        base_dist = Normal(loc, scale, validate_args=validate_args)
         super(LogNormal, self).__init__(base_dist, ExpTransform(), validate_args=validate_args)
 
     def expand(self, batch_shape, _instance=None):
2 changes: 1 addition & 1 deletion torch/distributions/logistic_normal.py
@@ -31,7 +31,7 @@ class LogisticNormal(TransformedDistribution):
     has_rsample = True
 
     def __init__(self, loc, scale, validate_args=None):
-        base_dist = Normal(loc, scale)
+        base_dist = Normal(loc, scale, validate_args=validate_args)
         if not base_dist.batch_shape:
             base_dist = base_dist.expand([1])
         super(LogisticNormal, self).__init__(base_dist,
2 changes: 2 additions & 0 deletions torch/distributions/mixture_same_family.py
@@ -144,6 +144,8 @@ def cdf(self, x):
         return torch.sum(cdf_x * mix_prob, dim=-1)
 
     def log_prob(self, x):
+        if self._validate_args:
+            self._validate_sample(x)
         x = self._pad(x)
         log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
         log_mix_prob = torch.log_softmax(self.mixture_distribution.logits,
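`MixtureSameFamily.log_prob` previously skipped argument validation entirely; it now checks the value against the component support. A sketch with an exponential mixture (names and parameters are my own):

```python
import torch
from torch.distributions import Categorical, Exponential, MixtureSameFamily

# A three-component mixture of exponentials, supported on [0, inf).
mix = Categorical(probs=torch.ones(3))
comp = Exponential(rate=torch.tensor([0.5, 1.0, 2.0]))
m = MixtureSameFamily(mix, comp, validate_args=True)

try:
    m.log_prob(torch.tensor(-1.0))  # outside the component support
except ValueError as e:
    print("rejected:", e)
```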
2 changes: 1 addition & 1 deletion torch/distributions/pareto.py
@@ -23,7 +23,7 @@ class Pareto(TransformedDistribution):
 
     def __init__(self, scale, alpha, validate_args=None):
         self.scale, self.alpha = broadcast_all(scale, alpha)
-        base_dist = Exponential(self.alpha)
+        base_dist = Exponential(self.alpha, validate_args=validate_args)
         transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
         super(Pareto, self).__init__(base_dist, transforms, validate_args=validate_args)
 
2 changes: 1 addition & 1 deletion torch/distributions/relaxed_categorical.py
@@ -109,7 +109,7 @@ class RelaxedOneHotCategorical(TransformedDistribution):
     has_rsample = True
 
     def __init__(self, temperature, probs=None, logits=None, validate_args=None):
-        base_dist = ExpRelaxedCategorical(temperature, probs, logits)
+        base_dist = ExpRelaxedCategorical(temperature, probs, logits, validate_args=validate_args)
         super(RelaxedOneHotCategorical, self).__init__(base_dist,
                                                        ExpTransform(),
                                                        validate_args=validate_args)
2 changes: 2 additions & 0 deletions torch/distributions/transformed_distribution.py
@@ -134,6 +134,8 @@ def log_prob(self, value):
         Scores the sample by inverting the transform(s) and computing the score
         using the score of the base distribution and the log abs det jacobian.
         """
+        if self._validate_args:
+            self._validate_sample(value)
         event_dim = len(self.event_shape)
         log_prob = 0.0
         y = value
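This is the catch-all version of the fix: every `TransformedDistribution`-based family now rejects out-of-support `log_prob` arguments when validation is on, instead of silently pushing NaN through the inverse transform. A sketch with a hand-rolled log-normal:

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

# exp(Normal) is supported on the positive reals.
d = TransformedDistribution(Normal(0.0, 1.0), ExpTransform(), validate_args=True)
try:
    d.log_prob(torch.tensor(-1.0))
except ValueError as e:
    print("rejected:", e)

# With validation off end to end, the inverse transform computes
# log(-1.0) and the result is silently NaN.
base = Normal(0.0, 1.0, validate_args=False)
d_nonval = TransformedDistribution(base, ExpTransform(), validate_args=False)
print(d_nonval.log_prob(torch.tensor(-1.0)))  # tensor(nan)
```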
2 changes: 1 addition & 1 deletion torch/distributions/weibull.py
@@ -27,7 +27,7 @@ class Weibull(TransformedDistribution):
     def __init__(self, scale, concentration, validate_args=None):
         self.scale, self.concentration = broadcast_all(scale, concentration)
         self.concentration_reciprocal = self.concentration.reciprocal()
-        base_dist = Exponential(torch.ones_like(self.scale))
+        base_dist = Exponential(torch.ones_like(self.scale), validate_args=validate_args)
         transforms = [PowerTransform(exponent=self.concentration_reciprocal),
                       AffineTransform(loc=0, scale=self.scale)]
         super(Weibull, self).__init__(base_dist,