Contributed Code
Nested Sampling
class NestedSampler(model, *, num_live_points=1000, max_samples=100000, sampler_name='slice', depth=5, num_slices=5, termination_frac=0.01)
Bases: object
(EXPERIMENTAL) A wrapper for jaxns, a nested sampling package based on JAX.
See reference [1] for details on the meaning of each parameter. Please consider citing this reference if you use the nested sampler in your research.
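For intuition about what num_live_points and termination_frac control, the following is a minimal toy nested sampler written in plain NumPy. It is not jaxns and not how NestedSampler works internally. Assumptions made for this sketch: the prior is uniform on the unit hypercube, replacement live points are drawn by naive rejection sampling from the prior (jaxns instead uses the slice or multi-ellipsoid samplers chosen by sampler_name), and the stopping rule "stop once the live points could add at most a fraction termination_frac of the current evidence estimate" is a commonly used criterion assumed here, not necessarily the exact jaxns test. The names toy_nested_sampling and gaussian_log_likelihood are hypothetical.

```python
import numpy as np


def toy_nested_sampling(log_likelihood, ndim, num_live_points=100,
                        termination_frac=0.01, max_samples=100_000, seed=0):
    """Estimate the log-evidence for a uniform prior on [0, 1]^ndim (toy sketch)."""
    rng = np.random.default_rng(seed)
    # Initial live points are independent draws from the prior.
    live = rng.uniform(size=(num_live_points, ndim))
    live_logl = np.array([log_likelihood(x) for x in live])
    log_z = -np.inf  # running log-evidence estimate
    log_x = 0.0      # log prior volume, starting from X_0 = 1
    for i in range(1, max_samples + 1):
        worst = np.argmin(live_logl)
        logl_min = live_logl[worst]
        # Expected shrinkage of the prior volume: X_i = exp(-i / N).
        log_x_new = -i / num_live_points
        # The discarded point gets weight L_min * (X_{i-1} - X_i).
        log_z = np.logaddexp(log_z, logl_min + np.log(np.exp(log_x) - np.exp(log_x_new)))
        log_x = log_x_new
        # Replace it with a prior draw whose likelihood exceeds L_min
        # (naive rejection sampling; slow once the constrained volume is small).
        while True:
            candidate = rng.uniform(size=ndim)
            candidate_logl = log_likelihood(candidate)
            if candidate_logl > logl_min:
                break
        live[worst] = candidate
        live_logl[worst] = candidate_logl
        # Stop once the live points can add at most a small fraction of the current Z.
        if np.max(live_logl) + log_x < np.log(termination_frac) + log_z:
            break
    # Each remaining live point carries roughly X_final / N of the prior volume.
    log_z = np.logaddexp(log_z, np.logaddexp.reduce(live_logl) + log_x - np.log(num_live_points))
    return log_z, i


def gaussian_log_likelihood(x, mean=0.5, sigma=0.1):
    # Normalized Gaussian density, so the evidence over the unit cube is close to 1.
    return float(-0.5 * np.sum(((x - mean) / sigma) ** 2)
                 - x.size * np.log(sigma * np.sqrt(2.0 * np.pi)))


log_z, num_iterations = toy_nested_sampling(gaussian_log_likelihood, ndim=1)
print(log_z, num_iterations)
```

Because the likelihood here is a normalized density that sits well inside the unit interval, the estimated log Z should come out close to 0. Its statistical error shrinks roughly like 1/sqrt(num_live_points), and a smaller termination_frac trades more iterations for less evidence left unaccounted for in the final live set.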
Note
To enumerate over a discrete latent variable, add the keyword argument infer={"enumerate": "parallel"} to the corresponding sample statement.
Note
To improve performance, please consider enabling x64 mode at the beginning of your NumPyro program:
numpyro.enable_x64()
References
[1] JAXNS: a high-performance nested sampling package based on JAX, Joshua G. Albert (https://arxiv.org/abs/2012.15286)
Parameters:
- model (callable) – a Python callable containing NumPyro primitives.
- num_live_points (int) – the number of live points. As a rule of thumb, allocate around 50 live points per possible mode.
- max_samples (int) – the maximum number of iterations and samples.
- sampler_name (str) – either "slice" (default) or "multi_ellipsoid".
- depth (int) – determines the maximum number of ellipsoids to construct via hierarchical splitting (typical range: 3-9; default: 5).
- num_slices (int) – the number of slice sampling proposals at each sampling step (typical range: 1-5; default: 5).
- termination_frac (float) – the termination condition (typical range: 0.001-0.01; default: 0.01).
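The defaults and typical ranges above can be collected in one place before constructing the sampler. The helper below is a hypothetical sketch, not part of NumPyro: it returns keyword arguments matching the signature documented on this page, raises on an unknown sampler_name, and only warns (rather than failing) when a value falls outside the quoted typical ranges or below the 50-live-points-per-mode rule of thumb, since those ranges are guidance rather than hard limits. The expected_modes argument is an assumption introduced for this sketch.

```python
import warnings

# Typical ranges quoted in the parameter list above; guidance, not hard limits.
TYPICAL_RANGES = {
    "depth": (3, 9),
    "num_slices": (1, 5),
    "termination_frac": (0.001, 0.01),
}
LIVE_POINTS_PER_MODE = 50  # rule of thumb from the num_live_points description


def nested_sampler_kwargs(*, expected_modes=1, num_live_points=1000, max_samples=100000,
                          sampler_name="slice", depth=5, num_slices=5,
                          termination_frac=0.01):
    """Hypothetical helper: collect NestedSampler keyword arguments and sanity-check them."""
    if sampler_name not in ("slice", "multi_ellipsoid"):
        raise ValueError(f"sampler_name must be 'slice' or 'multi_ellipsoid', got {sampler_name!r}")
    if num_live_points < LIVE_POINTS_PER_MODE * expected_modes:
        warnings.warn(
            f"num_live_points={num_live_points} is below the rule of thumb of "
            f"{LIVE_POINTS_PER_MODE} per mode for {expected_modes} expected modes")
    kwargs = dict(num_live_points=num_live_points, max_samples=max_samples,
                  sampler_name=sampler_name, depth=depth, num_slices=num_slices,
                  termination_frac=termination_frac)
    # Warn, but do not fail, when a value leaves its typical range.
    for name, (low, high) in TYPICAL_RANGES.items():
        if not low <= kwargs[name] <= high:
            warnings.warn(f"{name}={kwargs[name]} is outside the typical range [{low}, {high}]")
    return kwargs


# Usage with the signature documented on this page (not executed here):
# ns = NestedSampler(model, **nested_sampler_kwargs(expected_modes=4, depth=7))
```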
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.nested_sampling import NestedSampler
>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>>
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
...                           obs=labels)
>>>
>>> ns = NestedSampler(model)
>>> ns.run(random.PRNGKey(2), data, labels)
>>> samples = ns.get_samples(random.PRNGKey(3), num_samples=1000)
>>> assert jnp.mean(jnp.abs(samples['intercept'])) < 0.05
>>> print(jnp.mean(samples['coefs'], axis=0))
[0.93661342 1.95034876 2.86123884]
run(rng_key, *args, **kwargs)
Run the nested sampler and collect weighted samples.
Parameters:
- rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.
- args – The arguments needed by the model.
- kwargs – The keyword arguments needed by the model.
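Nested sampling produces weighted samples, while the example above draws num_samples equally weighted samples through get_samples. One standard way to go from the former to the latter is multinomial resampling in proportion to the normalized weights. The sketch below shows that technique in JAX; it is not necessarily how get_samples is implemented, and the dictionary layout of samples, the log_weights argument, and the function name resample_weighted are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random


def resample_weighted(rng_key, samples, log_weights, num_samples):
    """Turn weighted samples into equally weighted ones by multinomial resampling.

    samples: dict mapping site names to arrays whose leading axis indexes samples.
    log_weights: unnormalized log weights, one per weighted sample.
    """
    # softmax normalizes the log-weights into probabilities in a numerically stable way.
    probs = jax.nn.softmax(log_weights)
    num_weighted = log_weights.shape[0]
    # Draw indices with replacement, proportionally to the weights.
    idx = random.choice(rng_key, num_weighted, shape=(num_samples,), replace=True, p=probs)
    return {name: value[idx] for name, value in samples.items()}


weighted = {"x": jnp.array([1.0, 2.0, 3.0])}
equal = resample_weighted(random.PRNGKey(0), weighted, jnp.zeros(3), 500)
print(equal["x"].shape)
```

Resampling discards the weights at the cost of some Monte Carlo noise; when only expectations are needed, weighting the original samples directly avoids that extra noise.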