Skip to content

Commit

Permalink
Pass allow_packing=True when converting input points to nested tensors in GP, GPRM, and SchurComplement.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 485888424
  • Loading branch information
emilyfertig authored and tensorflower-gardener committed Nov 3, 2022
1 parent 200200a commit 0b50a7e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def __init__(self,
if index_points is not None:
index_points = nest_util.convert_to_nested_tensor(
index_points, dtype=input_dtype, name='index_points',
convert_ref=False)
convert_ref=False, allow_packing=True)
jitter = tensor_util.convert_nonref_to_tensor(
jitter, dtype=dtype, name='jitter')
observation_noise_variance = tensor_util.convert_nonref_to_tensor(
Expand Down Expand Up @@ -589,7 +589,7 @@ def _get_index_points(self, index_points=None):
'computed.')
return nest_util.convert_to_nested_tensor(
index_points if index_points is not None else self._index_points,
dtype_hint=self.kernel.dtype)
dtype_hint=self.kernel.dtype, allow_packing=True)

@distribution_util.AppendDocstring(kwargs_dict={
'index_points':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,11 @@ def __init__(self,
if index_points is not None:
index_points = nest_util.convert_to_nested_tensor(
index_points, dtype=input_dtype, convert_ref=False,
name='index_points')
name='index_points', allow_packing=True)
if observation_index_points is not None:
observation_index_points = nest_util.convert_to_nested_tensor(
observation_index_points, dtype=input_dtype, convert_ref=False,
name='observation_index_points')
name='observation_index_points', allow_packing=True)
observations = tensor_util.convert_nonref_to_tensor(
observations, dtype=dtype,
name='observations')
Expand Down Expand Up @@ -583,7 +583,8 @@ def conditional_mean_fn(x):
"""Conditional mean."""
observations = tf.convert_to_tensor(self._observations)
observation_index_points = nest_util.convert_to_nested_tensor(
self._observation_index_points, dtype_hint=self.kernel.dtype)
self._observation_index_points, dtype_hint=self.kernel.dtype,
allow_packing=True)
k_x_obs_linop = tf.linalg.LinearOperatorFullMatrix(
kernel.matrix(x, observation_index_points))
chol_linop = tf.linalg.LinearOperatorLowerTriangular(
Expand Down Expand Up @@ -774,7 +775,7 @@ def precompute_regression_model(
jitter = tf.convert_to_tensor(jitter, dtype=dtype)

observation_index_points = nest_util.convert_to_nested_tensor(
observation_index_points, dtype=input_dtype)
observation_index_points, dtype=input_dtype, allow_packing=True)
observation_noise_variance = tf.convert_to_tensor(
observation_noise_variance, dtype=dtype)
observations = tf.convert_to_tensor(observations, dtype=dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ def __init__(self,
self._diag_shift = tensor_util.convert_nonref_to_tensor(
diag_shift, dtype=float_dtype, name='diag_shift')
self._fixed_inputs = nest_util.convert_to_nested_tensor(
fixed_inputs, dtype=dtype, name='fixed_inputs', convert_ref=False)
fixed_inputs, dtype=dtype, name='fixed_inputs', convert_ref=False,
allow_packing=True)
if ((fixed_inputs_mask is not None) and
(fixed_inputs_is_missing is not None)):
raise ValueError('Expected at most one of `fixed_inputs_mask` or '
Expand Down Expand Up @@ -378,7 +379,8 @@ def with_precomputed_divisor(
fixed_inputs,
diag_shift], tf.float32)
float_dtype = dtype
fixed_inputs = nest_util.convert_to_nested_tensor(fixed_inputs, dtype)
fixed_inputs = nest_util.convert_to_nested_tensor(
fixed_inputs, dtype=dtype, allow_packing=True)
if ((fixed_inputs_mask is not None) and
(fixed_inputs_is_missing is not None)):
raise ValueError('Expected at most one of `fixed_inputs_mask` or '
Expand Down Expand Up @@ -452,7 +454,7 @@ def _apply(self, x1, x2, example_ndims):
return k12

fixed_inputs = nest_util.convert_to_nested_tensor(
self._fixed_inputs, dtype_hint=self.dtype)
self._fixed_inputs, dtype_hint=self.dtype, allow_packing=True)
fixed_inputs_is_missing = self._get_fixed_inputs_is_missing()
if fixed_inputs_is_missing is not None:
fixed_inputs_is_missing = util.pad_shape_with_ones(
Expand Down Expand Up @@ -503,7 +505,7 @@ def _matrix(self, x1, x2):
return k12

fixed_inputs = nest_util.convert_to_nested_tensor(
self._fixed_inputs, dtype_hint=self.dtype)
self._fixed_inputs, dtype_hint=self.dtype, allow_packing=True)
fixed_inputs_is_missing = self._get_fixed_inputs_is_missing()
if fixed_inputs_is_missing is not None:
fixed_inputs_is_missing = fixed_inputs_is_missing[..., tf.newaxis, :]
Expand Down Expand Up @@ -574,7 +576,7 @@ def _parameter_properties(cls, dtype):
def _divisor_matrix(self, fixed_inputs=None, fixed_inputs_is_missing=None):
fixed_inputs = nest_util.convert_to_nested_tensor(
self._fixed_inputs if fixed_inputs is None else fixed_inputs,
dtype_hint=self.dtype)
dtype_hint=self.dtype, allow_packing=True)
if fixed_inputs_is_missing is None:
fixed_inputs_is_missing = self._get_fixed_inputs_is_missing()
# NOTE: Replacing masked-out rows/columns of the divisor matrix with
Expand Down

0 comments on commit 0b50a7e

Please sign in to comment.