[CherryPick] Fixes for distribution validation checks #53763

Merged · 4 commits · Mar 15, 2021
Changes from 1 commit
Fix distributions which don't properly honor validate_args=False
A number of derived distributions use base distributions in their
implementation.

We add what we hope is a comprehensive test of whether all distributions
actually honor skipping argument validation in log_prob, and then fix the
bugs we found. These bugs are particularly cumbersome in PyTorch 1.8 and
master, where validate_args is turned on by default. Conversely, one might
argue that validate_args does less than it should when the default is not
to validate but validation is turned on at instantiation: the base
distribution, constructed without validate_args, still skips validation.

Arguably, there is another set of bugs, or at least inconsistencies:
sample validation does not always catch invalid indices, so even with
validation enabled the test sees an IndexError rather than a ValueError.
We would encourage the implementors to be more ambitious when validation
is turned on and amend sample validation to raise a ValueError
consistently.
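As a sketch of that inconsistency (again ours, not part of the commit):
for Categorical, an out-of-range index that is not caught by validation
reaches the underlying gather and surfaces as an indexing error, which is
why the new test below tolerates (IndexError, RuntimeError):

import torch
from torch.distributions import Categorical

# Three categories: valid indices are 0, 1, 2.
d = Categorical(torch.ones(3), validate_args=False)
try:
    d.log_prob(torch.tensor(5))
except (IndexError, RuntimeError) as e:
    # An indexing error from the gather, not the ValueError
    # one might expect.
    print(type(e).__name__, e)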
t-vi authored and neerajprad committed Mar 10, 2021
commit 4051b739f80205258c5130049b88c22b3c5e497d
23 changes: 23 additions & 0 deletions test/distributions/test_distributions.py
@@ -4514,6 +4514,29 @@ def test_valid(self):
             for param in params:
                 Dist(validate_args=True, **param)
 
+    def test_invalid_log_probs_arg(self):
+        # Check that validation errors are indeed disabled,
+        # but they might raise another error
+        for Dist, params in EXAMPLES:
+            if Dist == TransformedDistribution:
+                # TransformedDistribution has a distribution instance
+                # as the argument, so we cannot do much about that
+                continue
+            for param in params:
+                d_nonval = Dist(validate_args=False, **param)
+                d_val = Dist(validate_args=True, **param)
+                for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
+                    try:
+                        log_prob = d_val.log_prob(v)
+                    except IndexError:
+                        pass
+                    except ValueError as e:
+                        if e.args and 'must be within the support' in e.args[0]:
+                            try:
+                                log_prob = d_nonval.log_prob(v)
+                            except (IndexError, RuntimeError):
+                                pass
+
     @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
     def test_invalid(self):
         for Dist, params in BAD_EXAMPLES:
2 changes: 1 addition & 1 deletion torch/distributions/beta.py
@@ -33,7 +33,7 @@ def __init__(self, concentration1, concentration0, validate_args=None):
         else:
             concentration1, concentration0 = broadcast_all(concentration1, concentration0)
             concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
-        self._dirichlet = Dirichlet(concentration1_concentration0)
+        self._dirichlet = Dirichlet(concentration1_concentration0, validate_args=validate_args)
         super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
 
     def expand(self, batch_shape, _instance=None):
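To see what forwarding validate_args buys here, a quick sketch (ours,
assuming post-fix behavior): with validation disabled, an out-of-support
value no longer triggers a ValueError from the internal Dirichlet:

import torch
from torch.distributions import Beta

d = Beta(torch.tensor(2.0), torch.tensor(2.0), validate_args=False)

# 2.0 lies outside Beta's support (0, 1). Pre-fix, the internal Dirichlet
# validated anyway and raised ValueError; post-fix the computation runs
# through and returns nan, as validate_args=False permits.
print(d.log_prob(torch.tensor(2.0)))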
3 changes: 2 additions & 1 deletion torch/distributions/kumaraswamy.py
@@ -39,7 +39,8 @@ def __init__(self, concentration1, concentration0, validate_args=None):
         self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
         finfo = torch.finfo(self.concentration0.dtype)
         base_dist = Uniform(torch.full_like(self.concentration0, 0),
-                            torch.full_like(self.concentration0, 1))
+                            torch.full_like(self.concentration0, 1),
+                            validate_args=validate_args)
         transforms = [PowerTransform(exponent=self.concentration0.reciprocal()),
                       AffineTransform(loc=1., scale=-1.),
                       PowerTransform(exponent=self.concentration1.reciprocal())]
2 changes: 1 addition & 1 deletion torch/distributions/log_normal.py
@@ -27,7 +27,7 @@ class LogNormal(TransformedDistribution):
     has_rsample = True
 
     def __init__(self, loc, scale, validate_args=None):
-        base_dist = Normal(loc, scale)
+        base_dist = Normal(loc, scale, validate_args=validate_args)
         super(LogNormal, self).__init__(base_dist, ExpTransform(), validate_args=validate_args)
 
     def expand(self, batch_shape, _instance=None):
2 changes: 1 addition & 1 deletion torch/distributions/pareto.py
@@ -23,7 +23,7 @@ class Pareto(TransformedDistribution):
 
     def __init__(self, scale, alpha, validate_args=None):
         self.scale, self.alpha = broadcast_all(scale, alpha)
-        base_dist = Exponential(self.alpha)
+        base_dist = Exponential(self.alpha, validate_args=validate_args)
         transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
         super(Pareto, self).__init__(base_dist, transforms, validate_args=validate_args)
 
2 changes: 1 addition & 1 deletion torch/distributions/weibull.py
@@ -27,7 +27,7 @@ class Weibull(TransformedDistribution):
     def __init__(self, scale, concentration, validate_args=None):
         self.scale, self.concentration = broadcast_all(scale, concentration)
         self.concentration_reciprocal = self.concentration.reciprocal()
-        base_dist = Exponential(torch.ones_like(self.scale))
+        base_dist = Exponential(torch.ones_like(self.scale), validate_args=validate_args)
         transforms = [PowerTransform(exponent=self.concentration_reciprocal),
                       AffineTransform(loc=0, scale=self.scale)]
         super(Weibull, self).__init__(base_dist,
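The five fixes above share one pattern: a derived distribution must
forward its validate_args both to its base distribution and to its own
super().__init__. A minimal sketch with a hypothetical derived
distribution (ours, not part of this PR):

import torch
from torch.distributions import Exponential, TransformedDistribution
from torch.distributions.transforms import ExpTransform

class MyDerived(TransformedDistribution):
    # Hypothetical example of the pattern: forward validate_args to the
    # base distribution as well as to TransformedDistribution itself.
    def __init__(self, rate, validate_args=None):
        base_dist = Exponential(rate, validate_args=validate_args)
        super(MyDerived, self).__init__(base_dist, ExpTransform(),
                                        validate_args=validate_args)

d = MyDerived(torch.tensor(1.0), validate_args=False)
print(d.log_prob(torch.tensor(-1.0)))  # nan, not a ValueError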