Skip to content

Commit

Permalink
Fix gradient issue in tfpk.Matern when input points are equal.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 485744124
  • Loading branch information
emilyfertig authored and tensorflower-gardener committed Nov 3, 2022
1 parent be8de53 commit 40f9ac3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
19 changes: 1 addition & 18 deletions tensorflow_probability/python/math/psd_kernels/internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,24 +232,7 @@ def pairwise_square_distance_matrix(x1, x2, feature_ndims):
ps.reduce_prod(ps.shape(x2)[-feature_ndims:])]], axis=0))
pairwise_sq = row_norm_x1 + row_norm_x2 - 2 * tf.linalg.matmul(
reshaped_x1, reshaped_x2, transpose_b=True)
pairwise_sq = tf.clip_by_value(pairwise_sq, 0., np.inf)

# If we statically know that `x1` and `x2` have the same number of examples,
# then we check if they are equal so that we can ensure that the diagonal
# distances are zero in this case.
num_examples1 = tf.compat.dimension_value(x1.shape[-feature_ndims - 1])
num_examples2 = tf.compat.dimension_value(x2.shape[-feature_ndims - 1])
if num_examples1 is not None and num_examples2 is not None:
if num_examples1 == num_examples2:
all_equal = tf.reduce_all(
tf.equal(x1, x2), axis=range(-1, -feature_ndims - 2, -1))
eye = tf.eye(num_examples1, dtype=pairwise_sq.dtype)
pairwise_sq = tf.where(
all_equal[..., tf.newaxis, tf.newaxis] & (eye == 1.),
tf.zeros([], dtype=pairwise_sq.dtype),
pairwise_sq)

return pairwise_sq
return tf.clip_by_value(pairwise_sq, 0., np.inf)


def pairwise_square_distance_tensor(
Expand Down
36 changes: 28 additions & 8 deletions tensorflow_probability/python/math/psd_kernels/matern.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ def _parameter_properties(cls, dtype):

def _apply_with_distance(
self, x1, x2, pairwise_square_distance, example_ndims=0):
# Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
norm = util.sqrt_with_finite_grads(pairwise_square_distance)
norm = tf.math.sqrt(pairwise_square_distance)
inverse_length_scale = self._inverse_length_scale_parameter()
if inverse_length_scale is not None:
inverse_length_scale = util.pad_shape_with_ones(
Expand Down Expand Up @@ -319,8 +318,15 @@ def __init__(self,

def _apply_with_distance(
self, x1, x2, pairwise_square_distance, example_ndims=0):
# Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
norm = util.sqrt_with_finite_grads(pairwise_square_distance)
# Where the pairwise square distance is 0, the gradients with respect to each
# of the inputs should be 0 as well. Replace those entries with a constant 0
# so that no gradient flows through them (otherwise taking the square root
# at 0 would produce infinity/NaN gradients).
pairwise_sq = tf.where(
tf.equal(pairwise_square_distance, 0.),
tf.zeros([], dtype=pairwise_square_distance.dtype),
pairwise_square_distance)
norm = tf.math.sqrt(pairwise_sq)
inverse_length_scale = self._inverse_length_scale_parameter()
if inverse_length_scale is not None:
inverse_length_scale = util.pad_shape_with_ones(
Expand Down Expand Up @@ -393,8 +399,15 @@ def __init__(self,

def _apply_with_distance(
self, x1, x2, pairwise_square_distance, example_ndims=0):
# Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
norm = util.sqrt_with_finite_grads(pairwise_square_distance)
# Where the pairwise square distance is 0, the gradients with respect to each
# of the inputs should be 0 as well. Replace those entries with a constant 0
# so that no gradient flows through them (otherwise taking the square root
# at 0 would produce infinity/NaN gradients).
pairwise_sq = tf.where(
tf.equal(pairwise_square_distance, 0.),
tf.zeros([], dtype=pairwise_square_distance.dtype),
pairwise_square_distance)
norm = tf.math.sqrt(pairwise_sq)
np_dtype = dtype_util.as_numpy_dtype(norm.dtype)
inverse_length_scale = self._inverse_length_scale_parameter()
if inverse_length_scale is not None:
Expand Down Expand Up @@ -468,8 +481,15 @@ def __init__(self,

def _apply_with_distance(
self, x1, x2, pairwise_square_distance, example_ndims=0):
# Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
norm = util.sqrt_with_finite_grads(pairwise_square_distance)
# Where the pairwise square distance is 0, the gradients with respect to each
# of the inputs should be 0 as well. Replace those entries with a constant 0
# so that no gradient flows through them (otherwise taking the square root
# at 0 would produce infinity/NaN gradients).
pairwise_sq = tf.where(
tf.equal(pairwise_square_distance, 0.),
tf.zeros([], dtype=pairwise_square_distance.dtype),
pairwise_square_distance)
norm = tf.math.sqrt(pairwise_sq)
inverse_length_scale = self._inverse_length_scale_parameter()
if inverse_length_scale is not None:
inverse_length_scale = util.pad_shape_with_ones(
Expand Down

0 comments on commit 40f9ac3

Please sign in to comment.