
Remove default independent sampler jitter but ensure positive variance #888

Merged (12 commits, from uri/cap_sampling_jitter into develop, Jan 21, 2025)

Conversation

uri-granta (Collaborator, Author) commented Jan 3, 2025

Related issue(s)/PRs:

Summary

As discussed elsewhere, jitter isn't necessary for independent reparametrization sampling, other than to ensure that the variance is non-zero.
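
For context, a minimal sketch of what independent reparametrization sampling does with the predicted variance (illustrative only, not the actual trieste implementation; the function name and clamping threshold are assumptions): samples are formed as mean + sqrt(var) * eps, so the only numerical requirement on the variance is that it is strictly positive.

```python
import tensorflow as tf

# Illustrative sketch (not the actual trieste code): independent reparametrization
# sampling draws eps ~ N(0, 1) once and reuses it, forming samples as
# mean + sqrt(var) * eps. The variance only needs to be strictly positive;
# no Cholesky factorisation is involved, so no extra jitter is required.
def independent_reparam_sample(mean: tf.Tensor, var: tf.Tensor, eps: tf.Tensor) -> tf.Tensor:
    var = tf.math.maximum(var, 1e-16)  # clamp rather than add jitter
    return mean + tf.sqrt(var) * eps


mean = tf.constant([[0.0], [1.0]], dtype=tf.float64)
var = tf.constant([[0.0], [0.25]], dtype=tf.float64)  # note the exactly-zero variance
eps = tf.random.normal([5, 2, 1], dtype=tf.float64)   # 5 sample paths
samples = independent_reparam_sample(mean, var, eps)  # [5, 2, 1], all finite
```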

Fully backwards compatible: yes

PR checklist

  • The quality checks are all passing
  • The bug case / new feature is covered by tests
  • Any new features are well-documented (in docstrings or notebooks)

uri-granta changed the title from "Cap reparam sampling jitter" to "Remove default reparam jitter but ensure positive variance" on Jan 3, 2025
uri-granta marked this pull request as ready for review on January 5, 2025 at 21:49
```diff
@@ -133,7 +134,7 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorTy
     tf.debugging.assert_greater_equal(jitter, 0.0)

     mean, var = self._model.predict(at[..., None, :, :])  # [..., 1, 1, L], [..., 1, 1, L]
-    var = var + jitter
+    var = ensure_positive(var + jitter)
```
uri-granta (Collaborator, Author):
(note that we could alternatively ignore the jitter argument here, even if it's explicitly provided, if we think that would be better)

Collaborator:

This version might be a bit difficult to read and debug, as we are potentially applying a correction twice (we apply the jitter with the sum, and then ensure_positive potentially adds an offset).

But I'm not sure a better alternative exists.

uri-granta (Collaborator, Author):
One solution to both this comment and the one at the end would be to change the default value to -1, and comment that this magic value doesn't add jitter but ensures that the variance is positive. And then if the user specifies an explicit non-negative jitter we can use that unmodified?

(Engineering-wise it would be nicer to make jitter an Optional[float] but that would necessitate changing the interface and modifying the other samplers too.)
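
For concreteness, a hypothetical sketch of the Optional[float] variant mentioned above (not the actual trieste interface; the helper name and threshold are purely illustrative):

```python
from typing import Optional

import tensorflow as tf

# Hypothetical sketch only: jitter=None would mean "no jitter, just clamp the
# variance to be strictly positive", while an explicit float would be added unmodified.
def _adjust_variance(var: tf.Tensor, jitter: Optional[float] = None) -> tf.Tensor:
    if jitter is None:
        return tf.math.maximum(var, 1e-16)  # ensure positive, no offset
    return var + jitter  # honour the user-supplied jitter as-is
```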

Collaborator:

I would explicitly ignore the jitter here and note in the docstrings that it is ignored - perhaps let's also do it properly and change it to be optional.

Collaborator:

Here I think there should be no reason for the user to want a different jitter, right @vpicheny?

Collaborator:

I think the main case for the jitter here is when the sampling is used with an acquisition function, possibly using sqrt(var), log(var) or cdf(mean, var), which would fail if the variance is numerically zero or negative.

Otherwise we would probably just want to avoid any offset that could get in the way - e.g. if the output is not rescaled and has very, very small values, then adding 1e-6 would change everything.

We could leave this logic to the acquisition function, or just ensure here that we are "just positive".
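
For illustration (not part of the PR), this is the kind of downstream failure being described: log or sqrt of a variance that is exactly zero or slightly negative produces -inf or NaN, while clamping to a tiny positive floor keeps both well defined.

```python
import tensorflow as tf

# Zero or slightly negative variances break downstream log/sqrt-based code.
var = tf.constant([0.0, -1e-18, 1e-12], dtype=tf.float64)
print(tf.math.log(var).numpy())   # [-inf, nan, -27.63...]
print(tf.math.sqrt(var).numpy())  # [0., nan, 1e-06]

# Clamping to a tiny positive floor keeps both well defined:
safe_var = tf.math.maximum(var, 1e-16)
print(tf.math.log(safe_var).numpy(), tf.math.sqrt(safe_var).numpy())
```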

uri-granta (Collaborator, Author):
@vpicheny so is your suggestion to ignore the jitter in IndependentReparametrizationSampler but still call ensure_positive?

@hstojic similarly, are you proposing to call ensure_positive in deep_ensemble_trajectory rather than adding DEFAULTS.JITTER?

Collaborator:

> @hstojic similarly, are you proposing to call ensure_positive in deep_ensemble_trajectory rather than adding DEFAULTS.JITTER?

Let's raise a separate PR for Keras and GPflux.

Collaborator:

I think @vpicheny is not sure either:

  1. no change, and leave it to whatever is using sample, or
  2. ignore it and make it barely positive when we find 0.

I would go with 2, but perhaps then make sure it can be overridden by the jitter argument?

uri-granta (Collaborator, Author):
Not sure what you mean by "ignore it but make sure it can be overridden by jitter argument". What does overriding mean? (If we allow users to specify a jitter value then isn't that option 1?)

uri-granta changed the title from "Remove default reparam jitter but ensure positive variance" to "Remove default independent sampler jitter but ensure positive variance" on Jan 6, 2025
```diff
@@ -285,6 +285,20 @@ def test_independent_reparametrization_sampler_reset_sampler(qmc: bool, qmc_skip
     npt.assert_array_less(1e-9, tf.abs(samples2 - samples1))


+@pytest.mark.parametrize("qmc", [True, False])
+@pytest.mark.parametrize("dtype", [tf.float32, tf.float64])
+def test_independent_reparametrization_sampler_sample_ensures_positive_variance(
```
Collaborator:

I am not sure what this test is doing... Does setting the kernel amplitude to 0 make the model variance equal to zero? Should we then check that the model's predicted variance is zero, but that the sampler applies the right fix?

uri-granta (Collaborator, Author):
Yes, that's right. I've now added an assert that the model variance is zero.
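
A rough sketch of the idea behind such a test (illustrative only, not the actual trieste test; the fake model and the inline clamp are stand-ins for the real fixtures and ensure_positive): assert that the degenerate model really does predict zero variance, then check that sampling still produces finite values because the variance is clamped.

```python
from typing import Tuple

import numpy as np
import numpy.testing as npt
import tensorflow as tf


class FakeZeroVarianceModel:
    """Stand-in for a GP whose kernel amplitude is zero: predicts zero variance everywhere."""

    def predict(self, x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        return tf.zeros_like(x[..., :1]), tf.zeros_like(x[..., :1])


def test_sample_ensures_positive_variance() -> None:
    xs = tf.constant([[0.0], [0.5], [1.0]], dtype=tf.float64)
    mean, var = FakeZeroVarianceModel().predict(xs)
    npt.assert_array_equal(var, 0.0)  # the model variance really is zero
    clamped = tf.math.maximum(var, 1e-16)  # what ensure_positive does
    samples = mean + tf.sqrt(clamped) * tf.random.normal([10, *mean.shape], dtype=tf.float64)
    assert np.all(np.isfinite(samples))  # clamping keeps sqrt well defined
```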

```python
def ensure_positive(x: TensorType) -> TensorType:
    """Ensure that all the elements in `x` are strictly positive (using a dtype-dependent
    capping threshold)."""
    return tf.math.maximum(x, 1e-6 if x.dtype == tf.float32 else 1e-16)
```
Collaborator:

Naive question: is 1e-6 the lowest we can have with single precision?

uri-granta (Collaborator, Author):

Not at all. This was just based on scaling up the suggested value of 1e-16 for float64. Both numbers can go significantly smaller if we want: float32 can go down to around 1e-38 and float64 to 2e-308. Do you have any intuition for how small we should make these?
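
For reference (not part of the PR), the limits being discussed can be read off numpy's finfo; the 1e-6/1e-16 defaults are close to the machine epsilons, while the smallest positive normal values are far smaller:

```python
import numpy as np

print(np.finfo(np.float32).tiny)  # ~1.18e-38, smallest positive normal float32
print(np.finfo(np.float64).tiny)  # ~2.23e-308, smallest positive normal float64
print(np.finfo(np.float32).eps)   # ~1.19e-07, float32 machine epsilon
print(np.finfo(np.float64).eps)   # ~2.22e-16, float64 machine epsilon
```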

Collaborator:

I think we may be fine with the smallest number for each precision that keeps it positive, though it may depend on the usage downstream. At the moment we are just taking the sqrt and doing some multiplication, which could take it back to 0, but in this use case I think that should be fine? The eps contribution would be removed in those cases, but I'm not sure if that's relevant.

Collaborator:

I agree, I would probably vote for a very small value in both cases. 1e-6 is way too high.

And maybe we do not need to differentiate between single and double precision? Both could be e.g. 1e-32 or something.

Collaborator:

We do want to avoid "jumps", so I would go with something close to the end of the range - imagine having an adjacent point that is 1e-300 but then you swap 0 with 1e-32; that would create a jump, no?

uri-granta (Collaborator, Author):
ensure_positive will turn both 0 and 1e-300 to 1e-32, so there won't be a jump (but there won't be any gradient either)
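
To illustrate the "no gradient" point (sketch only, using tf.math.maximum directly as ensure_positive does): below the clamping threshold, the gradient is routed to the constant floor, so the variance receives zero gradient there.

```python
import tensorflow as tf

var = tf.Variable([0.0, 1e-300, 1e-3], dtype=tf.float64)
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.sqrt(tf.math.maximum(var, 1e-32)))
# [0., 0., ~15.81]: zero gradient wherever clamping kicks in
print(tape.gradient(loss, var).numpy())
```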

vpicheny (Collaborator) left a comment:

Looks good, but something that bothers me is that there is no way of bypassing the correction applied by ensure_positive. If someone really wants to use the "true" variance, which could be exactly zero, or to control the amount of correction manually, there is no way to do so. But maybe it's OK as it is?

hstojic (Collaborator) left a comment:

see comments

Collaborator:

Can you also please check the GPflux and Keras samplers? In Keras (https://github.com/secondmind-labs/trieste/blob/25d2a038fc1a74485337afac4fa45f29a4c4a311/trieste/models/keras/sampler.py#L171C59-L171C67), we have the same use case and we should use the new ensure_positive function there as well.

hstojic (Collaborator) left a comment:

done

uri-granta merged commit 2365c50 into develop on Jan 21, 2025
12 checks passed
uri-granta deleted the uri/cap_sampling_jitter branch on January 21, 2025 at 11:42