Automatic Guide Generation
AutoDiagonalNormal
class AutoDiagonalNormal(model, prefix='auto', init_strategy=<function init_to_median>)
Bases: numpyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoDiagonalNormal(model, ...)
    svi = SVI(model, guide, ...)
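The following stdlib-only Python sketch illustrates the idea behind a diagonal Normal guide over the entire latent space: all latent sites are flattened into one vector, one loc and one scale per dimension parameterize independent Normals, and a reparameterized sample is unpacked back into per-site values. The site layout `SITES` and all helper names are hypothetical, and the sketch works directly in unconstrained space; it is not NumPyro's implementation.

```python
import math
import random

# Hypothetical layout of the flattened latent vector: site name -> (offset, size).
SITES = {"mu": (0, 1), "tau": (1, 1), "theta": (2, 3)}
LATENT_SIZE = 5

def sample_diag_normal(loc, scale, rng):
    """Reparameterized draw z = loc + scale * eps, with eps ~ N(0, I)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]

def log_prob_diag_normal(z, loc, scale):
    """Log density of a Normal with covariance diag(scale ** 2)."""
    return sum(
        -0.5 * ((zi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        for zi, m, s in zip(z, loc, scale)
    )

def unpack(z):
    """Split the flat latent vector back into per-site values."""
    return {name: z[off:off + size] for name, (off, size) in SITES.items()}

# Guide parameters: one loc and one (positive) scale per latent dimension.
loc = [0.0] * LATENT_SIZE
scale = [1.0] * LATENT_SIZE
z = sample_diag_normal(loc, scale, random.Random(0))
print(unpack(z))
print(log_prob_diag_normal(z, loc, scale))
```

Because the covariance is diagonal, the log density is a sum of independent one-dimensional terms, which is what keeps this guide cheap but unable to capture correlations between latent variables.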
median(params)
Returns the posterior median value of each latent variable.
Parameters: params (dict) – A dict containing parameter values.
Returns: A dict mapping sample site name to median tensor.
Return type: dict
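For a diagonal Normal guide the median is straightforward: the median of Normal(loc, scale) is loc, and a strictly increasing transform into constrained space (for example exp for a positive site) maps the median to the median. The stdlib sketch below assumes hypothetical param names `auto_loc`/`auto_scale` (echoing the 'auto' prefix), a hypothetical site layout, and hypothetical per-site transforms; it illustrates the idea rather than the library code.

```python
import math

# Hypothetical flat layout and unconstrained -> constrained transforms.
SITES = {"mu": (0, 1), "sigma": (1, 1)}
TRANSFORMS = {"mu": lambda x: x, "sigma": math.exp}

def median(params):
    """Median of each latent site under a diagonal Normal guide.

    The median of Normal(loc, scale) is loc, and a strictly increasing
    transform maps the median to the median, so loc is pushed through.
    """
    loc = params["auto_loc"]  # hypothetical param name
    return {
        name: [TRANSFORMS[name](v) for v in loc[off:off + size]]
        for name, (off, size) in SITES.items()
    }

params = {"auto_loc": [0.5, 0.0], "auto_scale": [1.0, 2.0]}
print(median(params))  # {'mu': [0.5], 'sigma': [1.0]}
```

Note that the scale parameters play no role here; they only matter once quantiles other than 0.5 are requested.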
quantiles(params, quantiles)
Returns posterior quantiles of each latent variable. Example:
    print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
- params (dict) – A dict containing parameter values.
- quantiles (list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values.
Return type: dict
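Under a diagonal Normal guide each quantile has a closed form: the q-quantile of dimension i is loc_i + scale_i · Φ⁻¹(q), which is then mapped through the site's (increasing) constraining transform. The sketch below uses Python's `statistics.NormalDist.inv_cdf` (Python 3.8+) for Φ⁻¹; the `auto_loc`/`auto_scale` names, site layout, and transforms are hypothetical and match no particular library version.

```python
import math
from statistics import NormalDist

# Hypothetical flat layout and unconstrained -> constrained transforms.
SITES = {"mu": (0, 1), "sigma": (1, 1)}
TRANSFORMS = {"mu": lambda x: x, "sigma": math.exp}

def quantiles(params, qs):
    """Return {site: [values at qs[0], values at qs[1], ...]}."""
    loc, scale = params["auto_loc"], params["auto_scale"]  # hypothetical names
    out = {}
    for name, (off, size) in SITES.items():
        f = TRANSFORMS[name]
        out[name] = [
            # q-quantile in unconstrained space, then mapped to the constrained space
            [f(NormalDist(loc[i], scale[i]).inv_cdf(q)) for i in range(off, off + size)]
            for q in qs
        ]
    return out

params = {"auto_loc": [0.5, 0.0], "auto_scale": [1.0, 2.0]}
print(quantiles(params, [0.05, 0.5, 0.95]))
```

Mapping quantiles through a monotonic transform is valid only because the transform preserves order; the mean, by contrast, could not be transformed this way.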
AutoIAFNormal
class AutoIAFNormal(model, prefix='auto', init_strategy=<function init_to_median>, num_flows=3, **arn_kwargs)
Bases: numpyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an InverseAutoregressiveTransform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...)
    svi = SVI(model, guide, ...)
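A minimal stdlib sketch of the affine inverse autoregressive flow idea: each step computes y_i = x_i · exp(s_i) + m_i, where (m_i, s_i) depend only on x_{<i}. Because every (m_i, s_i) is a function of the input, sampling takes one pass, while inversion is sequential; the Jacobian is triangular, so log|det J| = Σ s_i. The function `toy_arn` is a hypothetical stand-in for the autoregressive neural network, the standard Normal base and the order reversal between flows are assumptions, and `num_flows` plays the role of the parameter documented below. This is not NumPyro's InverseAutoregressiveTransform.

```python
import math
import random

def toy_arn(x, i, a, b):
    """Hypothetical stand-in for an autoregressive network.

    Returns (shift, log_scale) for dimension i using only x[:i].
    """
    s = sum(x[:i])
    return a * s, b * s

def iaf_forward(x, a, b):
    """One affine IAF step: y_i = x_i * exp(log_scale_i) + shift_i.

    All (shift, log_scale) pairs depend on the input x only, so the step is
    a single pass. The Jacobian is triangular: log|det J| = sum of log-scales.
    """
    y, log_det = [], 0.0
    for i, xi in enumerate(x):
        shift, log_scale = toy_arn(x, i, a, b)
        y.append(xi * math.exp(log_scale) + shift)
        log_det += log_scale
    return y, log_det

def iaf_inverse(y, a, b):
    """Inverse step: solved sequentially, since dimension i needs x[:i]."""
    x = []
    for i, yi in enumerate(y):
        shift, log_scale = toy_arn(x, i, a, b)  # x holds the i values recovered so far
        x.append((yi - shift) * math.exp(-log_scale))
    return x

def sample_iaf_guide(latent_size, num_flows, flow_params, rng):
    """Sample a standard diagonal Normal, push it through num_flows IAF steps,
    and track the log density of the result via the change of variables."""
    x = [rng.gauss(0.0, 1.0) for _ in range(latent_size)]
    log_q = sum(-0.5 * v * v - 0.5 * math.log(2 * math.pi) for v in x)
    for k in range(num_flows):
        a, b = flow_params[k]
        x, log_det = iaf_forward(x, a, b)
        log_q -= log_det  # change of variables
        if k < num_flows - 1:
            x = x[::-1]  # assumed permutation between flows (log|det| = 0)
    return x, log_q

z, log_q = sample_iaf_guide(3, 3, [(0.5, 0.1)] * 3, random.Random(0))
print(z, log_q)
```

Reversing the order between steps lets the first dimension, which a single step leaves unconditioned, be influenced by the others in later steps.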
Parameters: - rng (jax.random.PRNGKey) – random key to be used as the source of randomness to initialize the guide.
- model (callable) – a generative model.
- prefix (str) – a prefix that will be prepended to the names of all internal param sites.
- init_strategy (callable) – A per-site initialization function.
- num_flows (int) – the number of flows to be used, defaults to 3.
- **arn_kwargs – keywords for constructing autoregressive neural networks, which include:
    - hidden_dims (list[int]) – the dimensionality of the hidden units per layer. Defaults to [latent_size, latent_size].
    - skip_connections (bool) – whether to add skip connections from the input to the output of each flow. Defaults to False.
    - nonlinearity (callable) – the nonlinearity to use in the feedforward network. Defaults to jax.experimental.stax.Relu.
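To see how hidden_dims shapes an autoregressive network, the stdlib sketch below builds MADE-style binary masks: each unit gets a "degree", hidden units may read units of lower or equal degree, and output i may read only hidden units of strictly lower degree, so output i never depends on inputs j ≥ i. The deterministic degree assignment is an assumption (implementations may assign degrees differently), and the nonlinearity and skip connections are omitted; this illustrates the masking technique, not NumPyro's network code.

```python
def made_masks(latent_size, hidden_dims=None):
    """Binary masks for a MADE-style autoregressive network (illustrative).

    Assumption: hidden-unit degrees cycle through 1..D-1.
    """
    if hidden_dims is None:
        hidden_dims = [latent_size, latent_size]  # the documented default
    d = latent_size
    in_deg = list(range(1, d + 1))
    degs = [in_deg]
    for h in hidden_dims:
        degs.append([k % max(d - 1, 1) + 1 for k in range(h)])
    masks = []
    for prev, cur in zip(degs, degs[1:]):
        # a hidden unit of degree c may read units of degree <= c
        masks.append([[1 if c >= p else 0 for p in prev] for c in cur])
    # output i (degree i) may read hidden units of degree strictly < i
    masks.append([[1 if o > p else 0 for p in degs[-1]] for o in in_deg])
    return masks

def connectivity(masks):
    """conn[i][j] == 1 iff some path links input j to output i."""
    conn = masks[0]
    for m in masks[1:]:
        conn = [
            [1 if any(m[i][k] and conn[k][j] for k in range(len(conn))) else 0
             for j in range(len(conn[0]))]
            for i in range(len(m))
        ]
    return conn

print(connectivity(made_masks(3)))  # [[0, 0, 0], [1, 0, 0], [1, 1, 0]]
```

The strictly lower-triangular connectivity is what makes the flow's Jacobian triangular; in an IAF the network emits both a shift and a log-scale per dimension, and each output head would reuse the same output mask.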