Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes HalfCauchyPrior #2137

Merged
merged 2 commits into from
Sep 15, 2022
Merged

Fixes HalfCauchyPrior #2137

merged 2 commits into from
Sep 15, 2022

Conversation

feynmanliang
Copy link
Contributor

Using HalfCauchyPrior in https://docs.gpytorch.ai/en/stable/examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.html raises

'HalfCauchyPrior' object has no attribute 'loc'
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-79-6f8875ee8806> in <module>
     15 nuts_kernel = NUTS(pyro_model)
     16 mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)
---> 17 mcmc_run.run(train_X, train_Y)
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    380             # requires_grad", which happens with `jit_compile` under PyTorch 1.7
    381             args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 382             for x, chain_id in self.sampler.run(*args, **kwargs):
    383                 if num_samples[chain_id] == 0:
    384                     num_samples[chain_id] += 1
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    162             logger = initialize_logger(logger, "", progress_bar)
    163             hook_w_logging = _add_logging_hook(logger, progress_bar, self.hook)
--> 164             for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
    165                                        i if self.num_chains > 1 else None,
    166                                        *args, **kwargs):
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
    106 
    107 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 108     kernel.setup(warmup_steps, *args, **kwargs)
    109     params = kernel.initial_params
    110     # yield structure (key, value.shape) of params
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    301         self._warmup_steps = warmup_steps
    302         if self.model is not None:
--> 303             self._initialize_model_properties(args, kwargs)
    304         if self.initial_params:
    305             z = {k: v.detach() for k, v in self.initial_params.items()}
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
    243 
    244     def _initialize_model_properties(self, model_args, model_kwargs):
--> 245         init_params, potential_fn, transforms, trace = initialize_model(
    246             self.model,
    247             model_args,
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
    380         automatic_transform_enabled = False
    381     if max_plate_nesting is None:
--> 382         max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    383     # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
    384     # No-op if model does not have any discrete latents.
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/infer/mcmc/util.py in _guess_max_plate_nesting(model, args, kwargs)
    234     """
    235     with poutine.block():
--> 236         model_trace = poutine.trace(model).get_trace(*args, **kwargs)
    237     sites = [site for site in model_trace.nodes.values()
    238              if site["type"] == "sample"]
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    185         Calls this poutine and returns its trace instead of the function's return value.
    186         """
--> 187         self(*args, **kwargs)
    188         return self.msngr.get_trace()
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError) as e:
    167                 exc_type, exc_value, traceback = sys.exc_info()
<ipython-input-79-6f8875ee8806> in pyro_model(x, y)
      8 def pyro_model(x, y):
      9     with gpytorch.settings.fast_computations(False, False, False):
---> 10         sampled_model = model.pyro_sample_from_prior()
     11         output = sampled_model.likelihood(sampled_model(x))
     12         pyro.sample("obs", output, obs=y)
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/gpytorch/module.py in pyro_sample_from_prior(self)
    381         """
    382         new_module = self.to_pyro_random_module()
--> 383         return _pyro_sample_from_prior(module=new_module, memo=None, prefix="")
    384 
    385     def local_load_samples(self, samples_dict, memo, prefix):
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
    491     for mname, module_ in module.named_children():
    492         submodule_prefix = prefix + ("." if prefix else "") + mname
--> 493         _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix)
    494 
    495     return module
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
    491     for mname, module_ in module.named_children():
    492         submodule_prefix = prefix + ("." if prefix else "") + mname
--> 493         _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix)
    494 
    495     return module
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
    485                     )
    486                 memo.add(prior)
--> 487                 prior = prior.expand(closure(module).shape)
    488                 value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior)
    489                 setting_closure(module, value)
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/gpytorch/priors/torch_priors.py in expand(self, batch_shape)
     72 
     73     def expand(self, batch_shape):
---> 74         return HalfCauchy(self.loc.expand(batch_shape), self.scale.expand(batch_shape))
     75 
     76 
/mnt/xarfuse/uid-227577/ac1f4bfb-seed-nspid4026531836_cgpid3197131-ns-4026531840/torch/nn/modules/module.py in __getattr__(self, name)
   1263             if name in modules:
   1264                 return modules[name]
-> 1265         raise AttributeError("'{}' object has no attribute '{}'".format(
   1266             type(self).__name__, name))
   1267 
AttributeError: 'HalfCauchyPrior' object has no attribute 'loc'

After this PR, there is no longer any error.

@Balandat Balandat enabled auto-merge September 15, 2022 00:19
@Balandat Balandat merged commit f66e957 into cornellius-gp:master Sep 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants