[CherryPick] Fixes for distribution validation checks (#53763)
* Add sample validation for LKJCholesky.log_prob

* Fix distributions that do not properly honor validate_args=False

A number of derived distributions use base distributions in their
implementation but did not pass validate_args on to them.

We add a test, intended to be comprehensive, of whether all distributions
actually honor skipping validation of arguments in log_prob, and we fix
the bugs we found. These bugs are particularly cumbersome in PyTorch 1.8
and on master, where validate_args is turned on by default. In addition,
one might argue that validate_args does not work as it should when the
global default is not to validate but validation is turned on at
instantiation.

Arguably, there is another set of bugs, or at least inconsistencies,
where validating inputs does not prevent invalid indexing during sample
validation (with validation enabled, an IndexError is raised in the
test). We would encourage implementors to be more ambitious when
validation is turned on and to amend sample validation to throw a
ValueError for consistency.

* additional fixes to distributions

* address failing tests

Co-authored-by: neerajprad <neerajprad@devvm903.atn0.facebook.com>
Co-authored-by: Thomas Viehmann <tv.code@beamnet.de>
3 people authored Mar 15, 2021
1 parent 4596a8e commit e991cda
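
To make the intent of the fixes concrete before the per-file diffs, here is a minimal sketch (not part of the commit; LogNormal is just a convenient example of a derived distribution) of the behavior the new test checks: with validate_args=False, log_prob should not raise validation errors even for out-of-support values, while validate_args=True should reject them with a ValueError.

import torch
from torch.distributions import LogNormal

# Assumed post-fix behavior: LogNormal forwards validate_args to its base
# Normal, so disabling validation really disables it end to end.
d = LogNormal(torch.zeros(2), torch.ones(2), validate_args=False)
value = torch.tensor([-1.0, 2.0])   # -1.0 lies outside the positive support
print(d.log_prob(value))            # no ValueError; the first entry is nan

# With validation enabled, the out-of-support value is rejected instead.
d_val = LogNormal(torch.zeros(2), torch.ones(2), validate_args=True)
try:
    d_val.log_prob(value)
except ValueError as exc:
    print("rejected:", exc)
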
Showing 12 changed files with 71 additions and 15 deletions.
64 changes: 56 additions & 8 deletions test/distributions/test_distributions.py
@@ -588,6 +588,20 @@ def is_all_nan(tensor):
{'scale': torch.tensor([0., 1.], requires_grad=True)},
{'scale': torch.tensor([1., -1.], requires_grad=True)},
]),
Example(LKJCholesky, [
{
'dim': -2,
'concentration': 0.1
},
{
'dim': 1,
'concentration': 2.,
},
{
'dim': 2,
'concentration': 0.,
},
]),
Example(Laplace, [
{
'loc': torch.tensor([1., 1.], requires_grad=True),
@@ -1376,7 +1390,7 @@ def test_relaxed_one_hot_categorical_1d(self):
self.assertFalse(RelaxedOneHotCategorical(probs=p, temperature=temp).sample().requires_grad)
self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((2, 2)).size(), (2, 2, 3))
self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(), (1, 3))
self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp, p))
self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p))

def test_relaxed_one_hot_categorical_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
@@ -1390,8 +1404,8 @@ def test_relaxed_one_hot_categorical_2d(self):
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3))
self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp, p))
self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp_2, p))
self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p))
self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp_2, p))

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_argmax_relaxed_categorical(self):
@@ -1627,10 +1641,11 @@ def test_lognormal_sample(self):
'LogNormal(loc={}, scale={})'.format(mean, std))

def test_logisticnormal(self):
set_rng_seed(1) # see Note [Randomized statistical tests]
mean = torch.randn(5, 5).requires_grad_()
std = torch.randn(5, 5).abs().requires_grad_()
mean_1d = torch.randn(1).requires_grad_()
std_1d = torch.randn(1).requires_grad_()
std_1d = torch.randn(1).abs().requires_grad_()
mean_delta = torch.tensor([1.0, 0.0])
std_delta = torch.tensor([1e-5, 1e-5])
self.assertEqual(LogisticNormal(mean, std).sample().size(), (5, 6))
@@ -1648,9 +1663,11 @@ def test_logisticnormal(self):
1. / (1. + 1. + math.exp(1))]),
atol=1e-4, rtol=0)

self._gradcheck_log_prob(LogisticNormal, (mean, std))
self._gradcheck_log_prob(LogisticNormal, (mean, 1.0))
self._gradcheck_log_prob(LogisticNormal, (0.0, std))
# TODO: gradcheck seems to mutate the sample values so that the simplex
# constraint fails by a very small margin.
self._gradcheck_log_prob(lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, std))
self._gradcheck_log_prob(lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, 1.0))
self._gradcheck_log_prob(lambda m, s: LogisticNormal(m, s, validate_args=False), (0.0, std))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_logisticnormal_logprob(self):
@@ -2578,7 +2595,7 @@ def tril_cholesky_to_tril_corr(x):

for dim in range(2, 5):
log_probs = []
lkj = LKJCholesky(dim, concentration=1.)
lkj = LKJCholesky(dim, concentration=1., validate_args=True)
for i in range(2):
sample = lkj.sample()
sample_tril = tril_matrix_to_vec(sample, diag=-1)
@@ -2591,6 +2608,8 @@ def tril_cholesky_to_tril_corr(x):
# for dim=2, pdf = 0.5 (jacobian adjustment factor is 0.)
self.assertTrue(all([x == torch.tensor(0.5).log() for x in log_probs]))
self.assertEqual(log_probs[0], log_probs[1])
invalid_sample = torch.cat([sample, sample.new_ones(1, dim)], dim=0)
self.assertRaises(ValueError, lambda: lkj.log_prob(invalid_sample))

def test_independent_shape(self):
for Dist, params in EXAMPLES:
@@ -4498,6 +4517,35 @@ def test_valid(self):
for param in params:
Dist(validate_args=True, **param)

def test_invalid_log_probs_arg(self):
# Check that validation errors are indeed disabled,
# but they might raise another error
for Dist, params in EXAMPLES:
if Dist == TransformedDistribution:
# TransformedDistribution has a distribution instance
# as the argument, so we cannot do much about that
continue
for param in params:
d_nonval = Dist(validate_args=False, **param)
d_val = Dist(validate_args=True, **param)
for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
# samples with incorrect shape must throw ValueError only
try:
log_prob = d_val.log_prob(v)
except ValueError:
pass
# get sample of correct shape
val = torch.full(d_val.batch_shape + d_val.event_shape, v)
# check samples with incorrect support
try:
log_prob = d_val.log_prob(val)
except ValueError as e:
if e.args and 'must be within the support' in e.args[0]:
try:
log_prob = d_nonval.log_prob(val)
except RuntimeError:
pass

@unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
def test_invalid(self):
for Dist, params in BAD_EXAMPLES:
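
The gradcheck-related test changes above wrap the constructors so that the distributions built inside gradcheck skip validation; the TODO notes that gradcheck appears to nudge sampled values off the simplex by a tiny margin. A rough, self-contained sketch of that pattern (shapes and names here are illustrative, not taken from the test suite):

import torch
from torch.distributions import LogisticNormal

# Build the distribution through a wrapper that disables validation,
# mirroring the lambdas used in the tests above.
def make_dist(loc, scale):
    return LogisticNormal(loc, scale, validate_args=False)

loc = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
scale = torch.rand(2, 3, dtype=torch.double, requires_grad=True) + 0.1
value = make_dist(loc, scale).sample()

# gradcheck perturbs loc and scale while scoring the fixed sample.
print(torch.autograd.gradcheck(
    lambda l, s: make_dist(l, s).log_prob(value), (loc, scale)))
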
1 change: 1 addition & 0 deletions torch/distributions/__init__.py
@@ -135,6 +135,7 @@
'HalfNormal',
'Independent',
'Kumaraswamy',
'LKJCholesky',
'Laplace',
'LogNormal',
'LogisticNormal',
2 changes: 1 addition & 1 deletion torch/distributions/beta.py
@@ -33,7 +33,7 @@ def __init__(self, concentration1, concentration0, validate_args=None):
else:
concentration1, concentration0 = broadcast_all(concentration1, concentration0)
concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
self._dirichlet = Dirichlet(concentration1_concentration0)
self._dirichlet = Dirichlet(concentration1_concentration0, validate_args=validate_args)
super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
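
The Beta change above is representative of the pattern applied throughout: the internally constructed base distribution (here a Dirichlet) now inherits validate_args. A hedged sketch of the resulting behavior (assumed, not taken from the diff):

import torch
from torch.distributions import Beta

# With validation disabled, the stacked concentrations are no longer checked
# by the inner Dirichlet, so even an out-of-range parameter does not raise.
b = Beta(torch.tensor(-1.0), torch.tensor(1.0), validate_args=False)
print(b.log_prob(torch.tensor(0.5)))   # numerically meaningless (nan), but no error

# With validation enabled, the same parameters are rejected up front.
try:
    Beta(torch.tensor(-1.0), torch.tensor(1.0), validate_args=True)
except ValueError as exc:
    print("rejected:", exc)
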
3 changes: 2 additions & 1 deletion torch/distributions/kumaraswamy.py
@@ -39,7 +39,8 @@ def __init__(self, concentration1, concentration0, validate_args=None):
self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
finfo = torch.finfo(self.concentration0.dtype)
base_dist = Uniform(torch.full_like(self.concentration0, 0),
torch.full_like(self.concentration0, 1))
torch.full_like(self.concentration0, 1),
validate_args=validate_args)
transforms = [PowerTransform(exponent=self.concentration0.reciprocal()),
AffineTransform(loc=1., scale=-1.),
PowerTransform(exponent=self.concentration1.reciprocal())]
2 changes: 2 additions & 0 deletions torch/distributions/lkj_cholesky.py
@@ -109,6 +109,8 @@ def log_prob(self, value):
# So the probability of a Cholesky factor is propotional to
# prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
# with order_i = 2 * concentration - 2 + D - i
if self._validate_args:
self._validate_sample(value)
diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
order = torch.arange(2, self.dim + 1)
order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
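
For reference, a small sketch of the added check (assuming the post-fix behavior): with validation on, LKJCholesky.log_prob rejects values that are not valid Cholesky factors of a correlation matrix instead of silently scoring them.

import torch
from torch.distributions import LKJCholesky

lkj = LKJCholesky(3, concentration=1.0, validate_args=True)
sample = lkj.sample()
print(lkj.log_prob(sample))      # a valid sample scores fine

bad = torch.ones(3, 3)           # not lower-triangular with unit-norm rows
try:
    lkj.log_prob(bad)
except ValueError as exc:
    print("rejected:", exc)
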
2 changes: 1 addition & 1 deletion torch/distributions/log_normal.py
@@ -27,7 +27,7 @@ class LogNormal(TransformedDistribution):
has_rsample = True

def __init__(self, loc, scale, validate_args=None):
base_dist = Normal(loc, scale)
base_dist = Normal(loc, scale, validate_args=validate_args)
super(LogNormal, self).__init__(base_dist, ExpTransform(), validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
2 changes: 1 addition & 1 deletion torch/distributions/logistic_normal.py
@@ -31,7 +31,7 @@ class LogisticNormal(TransformedDistribution):
has_rsample = True

def __init__(self, loc, scale, validate_args=None):
base_dist = Normal(loc, scale)
base_dist = Normal(loc, scale, validate_args=validate_args)
if not base_dist.batch_shape:
base_dist = base_dist.expand([1])
super(LogisticNormal, self).__init__(base_dist,
2 changes: 2 additions & 0 deletions torch/distributions/mixture_same_family.py
@@ -144,6 +144,8 @@ def cdf(self, x):
return torch.sum(cdf_x * mix_prob, dim=-1)

def log_prob(self, x):
if self._validate_args:
self._validate_sample(x)
x = self._pad(x)
log_prob_x = self.component_distribution.log_prob(x) # [S, B, k]
log_mix_prob = torch.log_softmax(self.mixture_distribution.logits,
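
A minimal sketch of the effect of this added check (assumed behavior, not part of the diff): with validation enabled, MixtureSameFamily.log_prob now checks the value's shape and support before mixing the component log-probabilities.

import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal

# A scalar Gaussian mixture with three components.
gmm = MixtureSameFamily(Categorical(torch.ones(3)),
                        Normal(torch.randn(3), torch.rand(3) + 0.5),
                        validate_args=True)
print(gmm.log_prob(torch.tensor(0.25)))        # fine: a real-valued scalar

try:
    gmm.log_prob(torch.tensor(float("nan")))   # fails the real-support check
except ValueError as exc:
    print("rejected:", exc)
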
2 changes: 1 addition & 1 deletion torch/distributions/pareto.py
@@ -23,7 +23,7 @@ class Pareto(TransformedDistribution):

def __init__(self, scale, alpha, validate_args=None):
self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
super(Pareto, self).__init__(base_dist, transforms, validate_args=validate_args)

2 changes: 1 addition & 1 deletion torch/distributions/relaxed_categorical.py
@@ -109,7 +109,7 @@ class RelaxedOneHotCategorical(TransformedDistribution):
has_rsample = True

def __init__(self, temperature, probs=None, logits=None, validate_args=None):
base_dist = ExpRelaxedCategorical(temperature, probs, logits)
base_dist = ExpRelaxedCategorical(temperature, probs, logits, validate_args=validate_args)
super(RelaxedOneHotCategorical, self).__init__(base_dist,
ExpTransform(),
validate_args=validate_args)
2 changes: 2 additions & 0 deletions torch/distributions/transformed_distribution.py
@@ -134,6 +134,8 @@ def log_prob(self, value):
Scores the sample by inverting the transform(s) and computing the score
using the score of the base distribution and the log abs det jacobian.
"""
if self._validate_args:
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
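
A brief sketch of this change (assumed behavior): TransformedDistribution.log_prob now validates the value against the transformed support before inverting the transforms.

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

# exp(Normal) has positive support, so a negative value is rejected when
# validation is enabled instead of being pushed through log().
d = TransformedDistribution(Normal(0.0, 1.0), [ExpTransform()], validate_args=True)
print(d.log_prob(torch.tensor(2.0)))     # inside the support: fine

try:
    d.log_prob(torch.tensor(-1.0))       # outside the support
except ValueError as exc:
    print("rejected:", exc)
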
2 changes: 1 addition & 1 deletion torch/distributions/weibull.py
@@ -27,7 +27,7 @@ class Weibull(TransformedDistribution):
def __init__(self, scale, concentration, validate_args=None):
self.scale, self.concentration = broadcast_all(scale, concentration)
self.concentration_reciprocal = self.concentration.reciprocal()
base_dist = Exponential(torch.ones_like(self.scale))
base_dist = Exponential(torch.ones_like(self.scale), validate_args=validate_args)
transforms = [PowerTransform(exponent=self.concentration_reciprocal),
AffineTransform(loc=0, scale=self.scale)]
super(Weibull, self).__init__(base_dist,
