Distributions
Base Distribution
Distribution
-
class
Distribution
(batch_shape=(), event_shape=(), validate_args=None)[source]¶ Bases:
object
Base class for probability distributions in NumPyro. The design largely follows from torch.distributions.
Parameters: - batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
- event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
- validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
As an example:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)
-
arg_constraints
= {}¶
-
support
= None¶
-
has_enumerate_support
= False¶
-
reparametrized_params
= []¶
-
batch_shape
¶ Returns the shape over which the distribution parameters are batched.
Returns: batch shape of the distribution. Return type: tuple
-
event_shape
¶ Returns the shape of a single sample from the distribution without batching.
Returns: event shape of the distribution. Return type: tuple
-
has_rsample
¶
-
shape
(sample_shape=())[source]¶ The tensor shape of samples from this distribution.
Samples are of shape:
d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
Parameters: sample_shape (tuple) – the size of the iid batch to be drawn from the distribution. Returns: shape of samples. Return type: tuple
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
sample_with_intermediates
(key, sample_shape=())[source]¶ Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:
-
log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array with shape value.shape[:value.ndim - len(self.event_shape)] Return type: numpy.ndarray
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
to_event
(reinterpreted_batch_ndims=None)[source]¶ Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.
Parameters: reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims. Returns: An instance of Independent distribution. Return type: numpyro.distributions.distribution.Independent
-
enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.
-
expand
(batch_shape)[source]¶ Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters: batch_shape (tuple) – batch shape to expand to. Returns: an instance of ExpandedDistribution. Return type: ExpandedDistribution
-
expand_by
(sample_shape)[source]¶ Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.
Parameters: sample_shape (tuple) – The size of the iid batch to be drawn from the distribution. Returns: An expanded version of this distribution. Return type: ExpandedDistribution
-
mask
(mask)[source]¶ Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.
Parameters: mask (bool or jnp.ndarray) – A boolean or boolean-valued array (True includes a site, False excludes a site). Returns: A masked copy of this distribution. Return type: MaskedDistribution
Example:
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO
>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)
>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
-
classmethod
infer_shapes
(*args, **kwargs)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters: - *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
-
cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.
-
icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, should belong to [0, 1]. Returns: the values whose cdf values equal q.
-
is_discrete
¶
ExpandedDistribution
-
class
ExpandedDistribution
(base_dist, batch_shape=())[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {}¶
-
has_enumerate_support
-
has_rsample
¶
-
support
¶
-
sample_with_intermediates
(key, sample_shape=())[source]¶ Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array with shape value.shape[:value.ndim - len(self.event_shape)] Return type: numpy.ndarray
-
enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
FoldedDistribution
-
class
FoldedDistribution
(base_dist, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
Equivalent to TransformedDistribution(base_dist, AbsTransform()), but additionally supports log_prob().
Parameters: base_dist (Distribution) – A univariate distribution to reflect.
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
log_prob
(*args, **kwargs)¶
ImproperUniform
-
class
ImproperUniform
(support, batch_shape, event_shape, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
A helper distribution with zero log_prob() over the support domain.
Note
The sample method is not implemented for this distribution. In autoguides and MCMC, initial parameters for improper sites are derived from the init_to_uniform or init_to_value strategies.
Usage:
>>> from numpyro import sample
>>> from numpyro.distributions import ImproperUniform, Normal, constraints
>>>
>>> def model():
...     # ordered vector with length 10
...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
...
...     # real matrix with shape (3, 4)
...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))
If you want to set an improper prior over all values greater than a, where a is another random variable, you might use
>>> def model():
...     a = sample('a', Normal(0, 1))
...     x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))
or, if you want to reparameterize it,
>>> from numpyro.distributions import TransformedDistribution, transforms
>>> from numpyro.handlers import reparam
>>> from numpyro.infer.reparam import TransformReparam
>>>
>>> def model():
...     a = sample('a', Normal(0, 1))
...     with reparam(config={'x': TransformReparam()}):
...         x = sample('x',
...                    TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
...                                            transforms.AffineTransform(a, 1)))
Parameters: - support (Constraint) – the support of this distribution.
- batch_shape (tuple) – batch shape of this distribution. It is usually safe to set batch_shape=().
- event_shape (tuple) – event shape of this distribution.
-
arg_constraints
= {}¶
-
support
= <numpyro.distributions.constraints._Dependent object>¶
-
log_prob
(*args, **kwargs)¶
Independent
-
class
Independent
(base_dist, reinterpreted_batch_ndims, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.
From a practical standpoint, this is useful when changing the result of log_prob(). For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]
Parameters: - base_dist (numpyro.distributions.Distribution) – a distribution instance.
- reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.
-
arg_constraints
= {}¶
-
support
¶
-
has_enumerate_support
-
reparameterized_params
¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
has_rsample
¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array with shape value.shape[:value.ndim - len(self.event_shape)] Return type: numpy.ndarray
-
expand
(batch_shape)[source]¶ Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters: batch_shape (tuple) – batch shape to expand to. Returns: an instance of ExpandedDistribution. Return type: ExpandedDistribution
MaskedDistribution
-
class
MaskedDistribution
(base_dist, mask)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Masks a distribution by a boolean array that is broadcastable to the distribution’s Distribution.batch_shape. In the special case mask is False, computation of log_prob() is skipped, and constant zero values are returned instead.
Parameters: mask (jnp.ndarray or bool) – A boolean or boolean-valued array.
-
arg_constraints
= {}¶
-
has_enumerate_support
-
has_rsample
¶
-
support
¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array with shape value.shape[:value.ndim - len(self.event_shape)] Return type: numpy.ndarray
-
enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
TransformedDistribution
-
class
TransformedDistribution
(base_distribution, transforms, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.
Parameters: - base_distribution – the base distribution over which to apply transforms.
- transforms – a single transform or a list of transforms.
- validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
-
arg_constraints
= {}¶
-
has_rsample
¶
-
support
¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
sample_with_intermediates
(key, sample_shape=())[source]¶ Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
Delta
-
class
Delta
(v=0.0, log_density=0.0, event_dim=0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'log_density': <numpyro.distributions.constraints._Real object>, 'v': <numpyro.distributions.constraints._Dependent object>}¶
-
reparametrized_params
= ['v', 'log_density']¶
-
support
¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
Unit
-
class
Unit
(log_factor, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Trivial nonnormalized distribution representing the unit type.
The unit type has a single value with no data, i.e. value.size == 0.
This is used for numpyro.factor() statements.
-
arg_constraints
= {'log_factor': <numpyro.distributions.constraints._Real object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array with shape value.shape[:value.ndim - len(self.event_shape)] Return type: numpy.ndarray
Continuous Distributions
Beta
-
class
Beta
(concentration1, concentration0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['concentration1', 'concentration0']¶
-
support
= <numpyro.distributions.constraints._Interval object>¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
BetaProportion
-
class
BetaProportion
(mean, concentration, validate_args=None)[source]¶ Bases:
numpyro.distributions.continuous.Beta
The BetaProportion distribution is a reparameterization of the conventional Beta distribution in terms of the variate mean and a precision parameter.
Reference:
- Beta regression for modelling rates and proportions, Silvia Ferrari and Francisco Cribari-Neto. Journal of Applied Statistics 31.7 (2004): 799-815.
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'mean': <numpyro.distributions.constraints._Interval object>}¶
-
reparametrized_params
= ['mean', 'concentration']¶
-
support
= <numpyro.distributions.constraints._Interval object>¶
Cauchy
-
class
Cauchy
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
Chi2
-
class
Chi2
(df, validate_args=None)[source]¶ Bases:
numpyro.distributions.continuous.Gamma
-
arg_constraints
= {'df': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['df']¶
Dirichlet
-
class
Dirichlet
(concentration, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>}¶
-
reparametrized_params
= ['concentration']¶
-
support
= <numpyro.distributions.constraints._Simplex object>¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
static
infer_shapes
(concentration)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters: - *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
Exponential
-
class
Exponential
(rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
reparametrized_params
= ['rate']¶
-
arg_constraints
= {'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
Gamma
-
class
Gamma
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
reparametrized_params
= ['concentration', 'rate']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
Gumbel
-
class
Gumbel
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
GaussianRandomWalk
-
class
GaussianRandomWalk
(scale=1.0, num_steps=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
reparametrized_params
= ['scale']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
HalfCauchy
-
class
HalfCauchy
(scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
reparametrized_params
= ['scale']¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
arg_constraints
= {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.
-
icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, should belong to [0, 1]. Returns: the values whose cdf values equal q.
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
HalfNormal
-
class
HalfNormal
(scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
reparametrized_params
= ['scale']¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
arg_constraints
= {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.
-
icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, should belong to [0, 1]. Returns: the values whose cdf values equal q.
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
InverseGamma
-
class
InverseGamma
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
Note
We keep the same notation rate as in Pyro, but it plays the role of the scale parameter of InverseGamma in the literature (e.g. Wikipedia: https://en.wikipedia.org/wiki/Inverse-gamma_distribution).
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['concentration', 'rate']¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
Laplace
-
class
Laplace
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
LKJ¶
-
class
LKJ
(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
LKJ distribution for correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over correlation matrices.
When concentration > 1, the distribution favors samples with large determinant. This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small determinant. This is useful when we know a priori that some underlying variables are correlated.
Sample code for using LKJ in the context of a multivariate normal model:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(y):  # y has dimension N x d
    d = y.shape[1]
    N = y.shape[0]
    # Vector of variances for each of the d variables
    theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))

    concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
    corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
    sigma = jnp.sqrt(theta)
    # we can also use a faster formula `cov_mat = jnp.outer(sigma, sigma) * corr_mat`
    cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))

    # Vector of expectations
    mu = jnp.zeros(d)

    with numpyro.plate("observations", N):
        obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
    return obs
Parameters: - dimension (int) – dimension of the matrices
- concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
- sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices; they differ only in how samples are generated. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['concentration']¶
-
support
= <numpyro.distributions.constraints._CorrMatrix object>¶
-
mean
¶ Mean of the distribution.
LKJCholesky¶
-
class
LKJCholesky
(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.
When concentration > 1, the distribution favors samples with large diagonal entries (hence large determinant). This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small diagonal entries (hence small determinant). This is useful when we know a priori that some underlying variables are correlated.
Sample code for using LKJCholesky in the context of a multivariate normal model:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(y):  # y has dimension N x d
    d = y.shape[1]
    N = y.shape[0]
    # Vector of variances for each of the d variables
    theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
    # Lower Cholesky factor of a correlation matrix
    concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
    L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
    # Lower Cholesky factor of the covariance matrix
    sigma = jnp.sqrt(theta)
    # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
    L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)

    # Vector of expectations
    mu = jnp.zeros(d)

    with numpyro.plate("observations", N):
        obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs
Parameters: - dimension (int) – dimension of the matrices
- concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
- sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices; they differ only in how samples are generated. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['concentration']¶
-
support
= <numpyro.distributions.constraints._CorrCholesky object>¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
LogNormal¶
-
class
LogNormal
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
Logistic¶
-
class
Logistic
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
MultivariateNormal¶
-
class
MultivariateNormal
(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'covariance_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'loc': <numpyro.distributions.constraints._IndependentConstraint object>, 'precision_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'scale_tril': <numpyro.distributions.constraints._LowerCholesky object>}¶
-
support
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
reparametrized_params
= ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
static
infer_shapes
(loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters: - *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
-
LowRankMultivariateNormal¶
-
class
LowRankMultivariateNormal
(loc, cov_factor, cov_diag, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'cov_diag': <numpyro.distributions.constraints._IndependentConstraint object>, 'cov_factor': <numpyro.distributions.constraints._IndependentConstraint object>, 'loc': <numpyro.distributions.constraints._IndependentConstraint object>}¶
-
support
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
reparametrized_params
= ['loc', 'cov_factor', 'cov_diag']¶
-
mean
¶ Mean of the distribution.
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
static
infer_shapes
(loc, cov_factor, cov_diag)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters: - *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
-
Normal¶
-
class
Normal
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.
-
icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values; should lie in [0, 1]. Returns: the points at which the CDF equals q.
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
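The cdf and icdf methods are inverses of each other. A minimal sketch, assuming jax and numpyro are installed (loc = 1.0 and scale = 2.0 are arbitrary), including a central 95% interval read off the quantile function:

```python
import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.Normal(loc=1.0, scale=2.0)

# The cdf is 0.5 at the mean; icdf inverts cdf on (0, 1).
print(d.cdf(1.0))  # 0.5
x = jnp.array([-2.0, 0.0, 1.0, 4.0])
print(d.icdf(d.cdf(x)))  # recovers x up to floating-point error

# A central 95% interval from the quantile function.
print(d.icdf(jnp.array([0.025, 0.975])))  # approximately 1 -/+ 1.96 * 2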
-
Pareto¶
-
class
Pareto
(scale, alpha, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
-
arg_constraints
= {'alpha': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['scale', 'alpha']¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
support
¶
-
cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.
-
SoftLaplace¶
-
class
SoftLaplace
(loc, scale, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Smooth distribution with Laplace-like tail behavior.
This distribution corresponds to the log-convex density:
z = (value - loc) / scale
log_prob = log(2 / pi) - log(scale) - logaddexp(z, -z)
Like the Laplace density, this density has the heaviest possible tails (asymptotically) while still being log-convex. Unlike the Laplace distribution, this distribution is infinitely differentiable everywhere, and is thus suitable for HMC and Laplace approximation.
Parameters: - loc – Location parameter.
- scale – Scale parameter.
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
log_prob
(*args, **kwargs)¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.
-
icdf
(value)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: value – quantile values; should lie in [0, 1]. Returns: the points at which the CDF equals value.
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
StudentT¶
-
class
StudentT
(df, loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'df': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['df', 'loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
Uniform¶
-
class
Uniform
(low=0.0, high=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}¶
-
reparametrized_params
= ['low', 'high']¶
-
support
¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.
-
icdf
(value)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: value – quantile values; should lie in [0, 1]. Returns: the points at which the CDF equals value.
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
static
infer_shapes
(low=(), high=())[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters: - *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
-
Weibull¶
-
class
Weibull
(scale, concentration, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
reparametrized_params
= ['scale', 'concentration']¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
Discrete Distributions¶
BernoulliLogits¶
-
class
BernoulliLogits
(logits=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>}¶
-
support
= <numpyro.distributions.constraints._Boolean object>¶
-
has_enumerate_support
= True¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
BernoulliProbs¶
-
class
BernoulliProbs
(probs, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'probs': <numpyro.distributions.constraints._Interval object>}¶
-
support
= <numpyro.distributions.constraints._Boolean object>¶
-
has_enumerate_support
= True¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
BetaBinomial¶
-
class
BetaBinomial
(concentration1, concentration0, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and randomly drawn from a Beta distribution prior to a certain number of Bernoulli trials given by total_count.
Parameters: - concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.
- concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.
- total_count (numpy.ndarray) – number of Bernoulli trials.
-
arg_constraints
= {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
has_enumerate_support
= True¶
-
enumerate_support
(expand=True)¶ Returns an array with shape len(support) x batch_shape containing all values in the support.
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
support
¶
BinomialLogits¶
-
class
BinomialLogits
(logits, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
has_enumerate_support
= True¶
-
enumerate_support
(expand=True)¶ Returns an array with shape len(support) x batch_shape containing all values in the support.
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
support
¶
-
BinomialProbs¶
-
class
BinomialProbs
(probs, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
has_enumerate_support
= True¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
support
¶
-
CategoricalLogits¶
-
class
CategoricalLogits
(logits, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'logits': <numpyro.distributions.constraints._IndependentConstraint object>}¶
-
has_enumerate_support
= True¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
support
¶
-
CategoricalProbs¶
-
class
CategoricalProbs
(probs, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'probs': <numpyro.distributions.constraints._Simplex object>}¶
-
has_enumerate_support
= True¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
support
¶
-
DirichletMultinomial¶
-
class
DirichletMultinomial
(concentration, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a Dirichlet-multinomial pair. The probability of classes (probs for the Multinomial distribution) is unknown and randomly drawn from a Dirichlet distribution prior to a certain number of Categorical trials given by total_count.
Parameters: - concentration (numpy.ndarray) – concentration parameter (alpha) for the Dirichlet distribution.
- total_count (numpy.ndarray) – number of Categorical trials.
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
support
¶
-
static
infer_shapes
(concentration, total_count=())[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters: - *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
GammaPoisson¶
-
class
GammaPoisson
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a gamma-Poisson pair, also referred to as a gamma-Poisson mixture. The rate parameter for the Poisson distribution is unknown and randomly drawn from a Gamma distribution.
Parameters: - concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution.
- rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
GeometricLogits¶
-
class
GeometricLogits
(logits, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>}¶
-
support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
-
log_prob
(*args, **kwargs)¶
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
GeometricProbs
class GeometricProbs(probs, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean
Mean of the distribution.
variance
Variance of the distribution.
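The pmf behind GeometricProbs and GeometricLogits can be checked with a short plain-Python sketch. It assumes the support starts at 0 (k counts failures before the first success) and that logits map to probs through the sigmoid; the helper names are hypothetical and this is not NumPyro's implementation. The logits form uses log(probs) = -softplus(-logits) and log(1 - probs) = -softplus(logits), so probs never has to be formed explicitly.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def geometric_logpmf_probs(k, probs):
    # k failures before the first success, k = 0, 1, 2, ...
    return k * math.log1p(-probs) + math.log(probs)

def geometric_logpmf_logits(k, logits):
    # Same pmf written in terms of logits, assuming probs = sigmoid(logits)
    return -k * softplus(logits) - softplus(-logits)

def geometric_mean(probs):
    # (1 - p) / p under the failures-before-success convention
    return 1.0 / probs - 1.0

def geometric_variance(probs):
    # (1 - p) / p**2
    return (1.0 / probs - 1.0) / probs
```

For example, geometric_logpmf_probs(k, 0.3) and geometric_logpmf_logits(k, math.log(0.3 / 0.7)) agree for every k.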
MultinomialLogits
class MultinomialLogits(logits, total_count=1, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
arg_constraints = {'logits': <numpyro.distributions.constraints._IndependentConstraint object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean
Mean of the distribution.
variance
Variance of the distribution.
support
static infer_shapes(logits, total_count)
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:
- *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
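Because infer_shapes works only on shapes, its logic can be illustrated without creating any arrays. The sketch below is a plain-Python guess at the rule for MultinomialLogits, assuming the last axis of logits is the event (category) axis and the remaining axes broadcast against total_count to form the batch shape. broadcast_shapes and multinomial_infer_shapes are hypothetical helpers that re-implement NumPy-style broadcasting by hand; they are not NumPyro functions.

```python
def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples (right-aligned)."""
    ndim = max((len(s) for s in shapes), default=0)
    # Left-pad every shape with 1s so all have the same rank
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)

def multinomial_infer_shapes(logits_shape, total_count_shape=()):
    # The last axis of logits indexes categories, so it becomes the event shape;
    # the leading axes broadcast with total_count to give the batch shape.
    batch_shape = broadcast_shapes(logits_shape[:-1], total_count_shape)
    event_shape = tuple(logits_shape[-1:])
    return batch_shape, event_shape
```

For instance, logits of shape (2, 1, 3) with total_count of shape (4,) would give batch_shape (2, 4) and event_shape (3,).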
MultinomialProbs
class MultinomialProbs(probs, total_count=1, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean
Mean of the distribution.
variance
Variance of the distribution.
support
static infer_shapes(probs, total_count)
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:
- *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
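For MultinomialProbs, the log-pmf, mean and per-category variance follow the standard multinomial formulas. The plain-Python sketch below writes them out; it is not NumPyro's code, it assumes counts is a sequence of non-negative integers whose sum plays the role of total_count, and the helper names are hypothetical.

```python
import math

def multinomial_logpmf(counts, probs):
    # counts: non-negative ints summing to total_count; probs: a point on the simplex
    total_count = sum(counts)
    log_coef = math.lgamma(total_count + 1) - sum(math.lgamma(c + 1) for c in counts)
    # Categories with zero count contribute nothing, even when their prob is 0
    return log_coef + sum(c * math.log(p) for c, p in zip(counts, probs) if c > 0)

def multinomial_mean(probs, total_count):
    # Expected count per category: total_count * p_i
    return [total_count * p for p in probs]

def multinomial_variance(probs, total_count):
    # Marginal variance per category: total_count * p_i * (1 - p_i)
    return [total_count * p * (1 - p) for p in probs]
```

Enumerating every outcome for a small total_count, as in the check below, is a quick way to confirm the pmf sums to one.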
OrderedLogistic
class OrderedLogistic(predictor, cutpoints, validate_args=None)
Bases: numpyro.distributions.discrete.CategoricalProbs
A categorical distribution with ordered outcomes.
References:
- Stan Functions Reference, v2.20 section 12.6, Stan Development Team
Parameters:
- predictor (numpy.ndarray) – prediction in the real domain; typically this is the output of a linear model.
- cutpoints (numpy.ndarray) – positions in the real domain that separate the categories.
arg_constraints = {'cutpoints': <numpyro.distributions.constraints._OrderedVector object>, 'predictor': <numpyro.distributions.constraints._Real object>}
static infer_shapes(predictor, cutpoints)
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:
- *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
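Assuming the cumulative-logit form used by the cited Stan reference, P(Y <= k) = sigmoid(cutpoints[k] - predictor), so K cutpoints define K + 1 ordered categories and each category probability is the difference between consecutive cumulative probabilities. The plain-Python sketch below computes those probabilities for a scalar predictor; ordered_logistic_probs is a hypothetical helper, not NumPyro's implementation.

```python
import math

def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def ordered_logistic_probs(predictor, cutpoints):
    # cutpoints must be strictly increasing (cf. the _OrderedVector constraint)
    if any(b <= a for a, b in zip(cutpoints, cutpoints[1:])):
        raise ValueError("cutpoints must be strictly increasing")
    # P(Y <= k) = sigmoid(cutpoints[k] - predictor), padded with 0 and 1 at the ends
    cumulative = [0.0] + [sigmoid(c - predictor) for c in cutpoints] + [1.0]
    # Each category's probability is the gap between neighbouring cumulative values
    return [hi - lo for lo, hi in zip(cumulative, cumulative[1:])]
```

Raising the predictor shifts probability mass toward the higher categories, which is what makes the outcomes ordered.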
NegativeBinomial
NegativeBinomialLogits
class NegativeBinomialLogits(total_count, logits, validate_args=None)
Bases: numpyro.distributions.conjugate.GammaPoisson
arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
log_prob(*args, **kwargs)
NegativeBinomialProbs
class NegativeBinomialProbs(total_count, probs, validate_args=None)
Bases: numpyro.distributions.conjugate.GammaPoisson
arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
NegativeBinomial2
class NegativeBinomial2(mean, concentration, validate_args=None)
Bases: numpyro.distributions.conjugate.GammaPoisson
Another parameterization of GammaPoisson in which rate is replaced by mean.
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'mean': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
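In the GammaPoisson parameterization, a Poisson rate is drawn from Gamma(concentration, rate), whose mean is concentration / rate. Replacing rate by mean therefore presumably amounts to rate = concentration / mean, which gives variance mean + mean**2 / concentration. The plain-Python sketch below assumes that mapping; gamma_poisson_logpmf and negative_binomial2_logpmf are hypothetical helpers, not NumPyro functions.

```python
import math

def gamma_poisson_logpmf(k, concentration, rate):
    # Poisson(lam) marginalized over lam ~ Gamma(concentration, rate):
    # Gamma(k + a) / (Gamma(a) k!) * (b / (1 + b))**a * (1 / (1 + b))**k
    a, b = concentration, rate
    return (math.lgamma(k + a) - math.lgamma(a) - math.lgamma(k + 1)
            + a * math.log(b / (1.0 + b)) - k * math.log1p(b))

def negative_binomial2_logpmf(k, mean, concentration):
    # Assumed mapping: the Gamma mean concentration / rate equals `mean`
    return gamma_poisson_logpmf(k, concentration, concentration / mean)
```

Smaller concentration values inflate the variance term mean**2 / concentration, so concentration acts as an inverse over-dispersion parameter.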
Poisson
class Poisson(rate, *, is_sparse=False, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean
Mean of the distribution.
variance
Variance of the distribution.
PRNGIdentity
class PRNGIdentity
Bases: numpyro.distributions.distribution.Distribution
Distribution over PRNGKey(). This can be used to draw a batch of PRNGKey() using the seed handler. Only the sample method is supported.
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
ZeroInflatedDistribution
ZeroInflatedDistribution(base_dist, *, gate=None, gate_logits=None, validate_args=None)
Generic zero-inflated distribution.
Parameters:
- base_dist (Distribution) – the base distribution.
- gate (numpy.ndarray) – probability of extra zeros given via a Bernoulli distribution.
- gate_logits (numpy.ndarray) – logits of extra zeros given via a Bernoulli distribution.
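A zero-inflated model mixes a point mass at zero (weight gate) with the base distribution (weight 1 - gate), so P(0) = gate + (1 - gate) * p_base(0) and P(k) = (1 - gate) * p_base(k) for k > 0. The plain-Python sketch below shows that rule with a Poisson base. It assumes gate_logits means gate = sigmoid(gate_logits), and requiring exactly one of the two is an assumption of this sketch; all helper names are hypothetical.

```python
import math

def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def poisson_logpmf(k, rate):
    return k * math.log(rate) - rate - math.lgamma(k + 1)

def zero_inflated_logpmf(k, base_logpmf, gate=None, gate_logits=None):
    # Exactly one of gate / gate_logits is expected (assumption of this sketch)
    if (gate is None) == (gate_logits is None):
        raise ValueError("pass exactly one of gate or gate_logits")
    if gate is None:
        gate = sigmoid(gate_logits)
    log_base = base_logpmf(k)
    if k == 0:
        # A zero is produced either by the gate or by the base distribution
        return math.log(gate + (1.0 - gate) * math.exp(log_base))
    # Non-zero values can only come from the base distribution
    return math.log1p(-gate) + log_base
```

Usage: zero_inflated_logpmf(0, lambda k: poisson_logpmf(k, 2.0), gate=0.3) is larger than the plain Poisson log-probability of zero, reflecting the extra zeros.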
ZeroInflatedPoisson
class ZeroInflatedPoisson(gate, rate=1.0, validate_args=None)
Bases: numpyro.distributions.discrete.ZeroInflatedProbs
A zero-inflated Poisson distribution.
Parameters:
- gate (numpy.ndarray) – probability of extra zeros.
- rate (numpy.ndarray) – rate of the Poisson distribution.
arg_constraints = {'gate': <numpyro.distributions.constraints._Interval object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
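For ZeroInflatedPoisson the mixture has closed-form moments: mean (1 - gate) * rate and variance (1 - gate) * rate * (1 + gate * rate), so any gate > 0 makes the counts over-dispersed relative to a Poisson with the same mean. The sketch below checks these formulas against the pmf in plain Python; it is not NumPyro's implementation, and the helper names are hypothetical.

```python
import math

def zip_logpmf(k, gate, rate):
    # Log-pmf of the Poisson component
    log_poisson = k * math.log(rate) - rate - math.lgamma(k + 1)
    if k == 0:
        # Structural zero (gate) or Poisson zero (1 - gate) * exp(-rate)
        return math.log(gate + (1.0 - gate) * math.exp(-rate))
    return math.log1p(-gate) + log_poisson

def zip_mean(gate, rate):
    return (1.0 - gate) * rate

def zip_variance(gate, rate):
    # Exceeds the mean whenever gate > 0
    return (1.0 - gate) * rate * (1.0 + gate * rate)
```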
Directional Distributions
ProjectedNormal
class ProjectedNormal(concentration, *, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
Projected isotropic normal distribution of arbitrary dimension.
This distribution over directional data is qualitatively similar to the von Mises and von Mises-Fisher distributions, but permits tractable variational inference via reparametrized gradients.
To use this distribution with autoguides and HMC, use handlers.reparam with a ProjectedNormalReparam reparametrizer in the model, e.g.:
    import jax.numpy as jnp
    import numpyro
    from numpyro import handlers
    from numpyro.distributions import ProjectedNormal
    from numpyro.infer.reparam import ProjectedNormalReparam
    @handlers.reparam(config={"direction": ProjectedNormalReparam()})
    def model():
        direction = numpyro.sample("direction", ProjectedNormal(jnp.zeros(3)))
Note
This implements log_prob() only for dimensions {2,3}.
References:
[1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017), “The General Projected Normal Distribution of Arbitrary Dimension: Modeling and Bayesian Inference”, https://projecteuclid.org/euclid.ba/1453211962
arg_constraints = {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>}
reparametrized_params = ['concentration']
support = <numpyro.distributions.constraints._Sphere object>
mean
Note this is the mean in the sense of a centroid in the submanifold that minimizes expected squared geodesic distance.
mode
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(value)
Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution.
Returns: an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed
Return type: numpy.ndarray
static infer_shapes(concentration)
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:
- *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
- **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
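A projected normal draw is a normal vector rescaled to unit length. The sketch below assumes an isotropic normal centred at concentration with identity covariance, consistent with the "projected isotropic normal" description above; it uses only the standard library and is not NumPyro's sampler, and projected_normal_sample is a hypothetical name.

```python
import math
import random

def projected_normal_sample(concentration, rng):
    # Draw x ~ Normal(concentration, I) ...
    x = [c + rng.gauss(0.0, 1.0) for c in concentration]
    # ... and project it onto the unit sphere (x == 0 has probability zero)
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x]
```

Larger norms of concentration concentrate the draws around the direction of concentration, while a zero vector gives directions spread uniformly over the sphere.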
VonMises
class VonMises(loc, concentration, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>}
reparametrized_params = ['loc']
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())
Generates samples from the von Mises distribution.
Parameters:
- key – random number generator key
- sample_shape – shape of samples
Returns: samples from the von Mises distribution
log_prob(*args, **kwargs)
mean
Computes the circular mean of the distribution. Note: this equals loc when loc is mapped to the support [-pi, pi].
variance
Computes the circular variance of the distribution.
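The von Mises density on [-pi, pi] is exp(concentration * cos(x - loc)) / (2 * pi * I0(concentration)), and the standard circular variance is 1 - I1(concentration) / I0(concentration), where I0 and I1 are modified Bessel functions of the first kind. The sketch below evaluates them with a truncated power series, which is adequate only for moderate concentration. It assumes the variance property above follows this standard definition, and the helper names are hypothetical.

```python
import math

def bessel_i(order, x, terms=60):
    # Power series for the modified Bessel function I_order(x);
    # accurate for moderate x only
    return sum((x / 2.0) ** (2 * m + order) / (math.factorial(m) * math.factorial(m + order))
               for m in range(terms))

def von_mises_logpdf(x, loc, concentration):
    # log of exp(k * cos(x - loc)) / (2 * pi * I0(k))
    return concentration * math.cos(x - loc) - math.log(2.0 * math.pi * bessel_i(0, concentration))

def von_mises_circular_variance(concentration):
    # 1 minus the mean resultant length I1(k) / I0(k)
    return 1.0 - bessel_i(1, concentration) / bessel_i(0, concentration)
```

Numerically integrating cos and sin against the density recovers both the circular mean (loc) and the resultant length used in the variance.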
Truncated Distributions
LeftTruncatedDistribution
class LeftTruncatedDistribution(base_dist, low=0.0, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
arg_constraints = {'low': <numpyro.distributions.constraints._Real object>}
reparametrized_params = ['low']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
support
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(*args, **kwargs)
RightTruncatedDistribution
class RightTruncatedDistribution(base_dist, high=0.0, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
arg_constraints = {'high': <numpyro.distributions.constraints._Real object>}
reparametrized_params = ['high']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
support
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(*args, **kwargs)
TruncatedCauchy
class TruncatedCauchy(low=0.0, loc=0.0, scale=1.0, validate_args=None)
Bases: numpyro.distributions.truncated.LeftTruncatedDistribution
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
mean
Mean of the distribution.
variance
Variance of the distribution.
TruncatedDistribution
TruncatedDistribution(base_dist, low=None, high=None, validate_args=None)
A function to generate a truncated distribution.
Parameters:
- base_dist – The base distribution to be truncated. This should be a univariate distribution. Currently, only the following distributions are supported: Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.
- low – the value used to truncate the base distribution from below. Set this parameter to None to leave the distribution untruncated from below.
- high – the value used to truncate the base distribution from above. Set this parameter to None to leave the distribution untruncated from above.
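Truncation renormalizes the base density by the probability mass kept between the bounds: log p_trunc(x) = log p(x) - log(F(high) - F(low)) inside [low, high] and -inf outside, where F is the base CDF. The plain-Python sketch below uses a Normal base and treats None as "no bound", mirroring the low/high parameters above; truncated_logpdf and the normal helpers are hypothetical, not NumPyro's API.

```python
import math

def normal_logpdf(x, loc=0.0, scale=1.0):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale * math.sqrt(2.0 * math.pi))

def normal_cdf(x, loc=0.0, scale=1.0):
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))

def truncated_logpdf(x, logpdf, cdf, low=None, high=None):
    # low/high set to None mean "no truncation on that side"
    if (low is not None and x < low) or (high is not None and x > high):
        return -math.inf
    lower_mass = 0.0 if low is None else cdf(low)
    upper_mass = 1.0 if high is None else cdf(high)
    # Renormalize by the base probability mass kept inside [low, high]
    return logpdf(x) - math.log(upper_mass - lower_mass)
```

With both bounds set to None the function returns the base log-density unchanged, matching the untruncated case described for low and high.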
TruncatedNormal
class TruncatedNormal(low=0.0, loc=0.0, scale=1.0, validate_args=None)
Bases: numpyro.distributions.truncated.LeftTruncatedDistribution
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
mean
Mean of the distribution.
variance
Variance of the distribution.
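For a normal truncated from below at low, the mean and variance have standard closed forms in terms of alpha = (low - loc) / scale and the inverse Mills ratio lam = phi(alpha) / (1 - Phi(alpha)): mean = loc + scale * lam and variance = scale**2 * (1 + alpha * lam - lam**2). The sketch below computes them in plain Python with the same (low, loc, scale) defaults as TruncatedNormal; it does not claim to reproduce NumPyro's mean and variance properties, and the helper names are hypothetical.

```python
import math

def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def std_normal_sf(z):
    # Survival function 1 - Phi(z); erfc avoids cancellation for large z
    return 0.5 * math.erfc(z / math.sqrt(2.0))

def left_truncated_normal_moments(low=0.0, loc=0.0, scale=1.0):
    alpha = (low - loc) / scale
    lam = std_normal_pdf(alpha) / std_normal_sf(alpha)  # inverse Mills ratio
    mean = loc + scale * lam
    variance = scale ** 2 * (1.0 + alpha * lam - lam ** 2)
    return mean, variance
```

With the defaults (low=0, loc=0, scale=1) this reduces to the half-normal moments sqrt(2 / pi) and 1 - 2 / pi.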
TruncatedPolyaGamma
class TruncatedPolyaGamma(batch_shape=(), validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
truncation_point = 2.5
num_log_prob_terms = 7
num_gamma_variates = 8
arg_constraints = {}
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(*args, **kwargs)
TwoSidedTruncatedDistribution
class TwoSidedTruncatedDistribution(base_dist, low=0.0, high=1.0, validate_args=None)
Bases: numpyro.distributions.distribution.Distribution
arg_constraints = {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}
reparametrized_params = ['low', 'high']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
support
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:
- key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
log_prob(*args, **kwargs)
TensorFlow Distributions
Thin wrappers around TensorFlow Probability (TFP) distributions. For details on the TFP distribution interface, see its Distribution docs.
BijectorConstraint
BijectorTransform
TFPDistribution
class TFPDistribution(batch_shape=(), event_shape=(), validate_args=None)
A thin wrapper for TensorFlow Probability (TFP) distributions. The constructor has the same signature as the corresponding TFP distribution.
This class can be used to convert a TFP distribution to a NumPyro-compatible one as follows:
    d = TFPDistribution[tfd.Normal](0, 1)
Autoregressive
class Autoregressive(distribution_fn, sample0=None, num_steps=None, validate_args=False, allow_nan_stats=True, name='Autoregressive')
Wraps tensorflow_probability.substrates.jax.distributions.autoregressive.Autoregressive with TFPDistribution.
BatchBroadcast
class BatchBroadcast(distribution, with_shape=None, *, to_shape=None, validate_args=False, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.batch_broadcast.BatchBroadcast with TFPDistribution.
BatchConcat
class BatchConcat(distributions, axis, validate_args=False, allow_nan_stats=True, name='BatchConcat')
Wraps tensorflow_probability.substrates.jax.distributions.batch_concat.BatchConcat with TFPDistribution.
BatchReshape
class BatchReshape(distribution, batch_shape, validate_args=False, allow_nan_stats=True, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.batch_reshape.BatchReshape with TFPDistribution.
Bates
class Bates(total_count, low=0.0, high=1.0, validate_args=False, allow_nan_stats=True, name='Bates')
Wraps tensorflow_probability.substrates.jax.distributions.bates.Bates with TFPDistribution.
Bernoulli
class Bernoulli(logits=None, probs=None, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='Bernoulli')
Wraps tensorflow_probability.substrates.jax.distributions.bernoulli.Bernoulli with TFPDistribution.
Beta
class Beta(concentration1, concentration0, validate_args=False, allow_nan_stats=True, force_probs_to_zero_outside_support=False, name='Beta')
Wraps tensorflow_probability.substrates.jax.distributions.beta.Beta with TFPDistribution.
BetaBinomial
class BetaBinomial(total_count, concentration1, concentration0, validate_args=False, allow_nan_stats=True, name='BetaBinomial')
Wraps tensorflow_probability.substrates.jax.distributions.beta_binomial.BetaBinomial with TFPDistribution.
BetaQuotient
class BetaQuotient(concentration1_numerator, concentration0_numerator, concentration1_denominator, concentration0_denominator, validate_args=False, allow_nan_stats=True, name='BetaQuotient')
Wraps tensorflow_probability.substrates.jax.distributions.beta_quotient.BetaQuotient with TFPDistribution.
Binomial
class Binomial(total_count, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.binomial.Binomial with TFPDistribution.
Blockwise
class Blockwise(distributions, dtype_override=None, validate_args=False, allow_nan_stats=False, name='Blockwise')
Wraps tensorflow_probability.substrates.jax.distributions.blockwise.Blockwise with TFPDistribution.
Categorical
class Categorical(logits=None, probs=None, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='Categorical')
Wraps tensorflow_probability.substrates.jax.distributions.categorical.Categorical with TFPDistribution.
Cauchy
class Cauchy(loc, scale, validate_args=False, allow_nan_stats=True, name='Cauchy')
Wraps tensorflow_probability.substrates.jax.distributions.cauchy.Cauchy with TFPDistribution.
Chi
class Chi(df, validate_args=False, allow_nan_stats=True, name='Chi')
Wraps tensorflow_probability.substrates.jax.distributions.chi.Chi with TFPDistribution.
Chi2
class Chi2(df, validate_args=False, allow_nan_stats=True, name='Chi2')
Wraps tensorflow_probability.substrates.jax.distributions.chi2.Chi2 with TFPDistribution.
CholeskyLKJ
class CholeskyLKJ(dimension, concentration, validate_args=False, allow_nan_stats=True, name='CholeskyLKJ')
Wraps tensorflow_probability.substrates.jax.distributions.cholesky_lkj.CholeskyLKJ with TFPDistribution.
ContinuousBernoulli
class ContinuousBernoulli(logits=None, probs=None, lims=(0.499, 0.501), dtype=<class 'jax._src.numpy.lax_numpy.float32'>, validate_args=False, allow_nan_stats=True, name='ContinuousBernoulli')
Wraps tensorflow_probability.substrates.jax.distributions.continuous_bernoulli.ContinuousBernoulli with TFPDistribution.
DeterminantalPointProcess
class DeterminantalPointProcess(eigenvalues, eigenvectors, validate_args=False, allow_nan_stats=False, name='DeterminantalPointProcess')
Wraps tensorflow_probability.substrates.jax.distributions.dpp.DeterminantalPointProcess with TFPDistribution.
Deterministic
class Deterministic(loc, atol=None, rtol=None, validate_args=False, allow_nan_stats=True, name='Deterministic')
Wraps tensorflow_probability.substrates.jax.distributions.deterministic.Deterministic with TFPDistribution.
Dirichlet
class Dirichlet(concentration, validate_args=False, allow_nan_stats=True, force_probs_to_zero_outside_support=False, name='Dirichlet')
Wraps tensorflow_probability.substrates.jax.distributions.dirichlet.Dirichlet with TFPDistribution.
DirichletMultinomial
class DirichletMultinomial(total_count, concentration, validate_args=False, allow_nan_stats=True, name='DirichletMultinomial')
Wraps tensorflow_probability.substrates.jax.distributions.dirichlet_multinomial.DirichletMultinomial with TFPDistribution.
DoublesidedMaxwell
class DoublesidedMaxwell(loc, scale, validate_args=False, allow_nan_stats=True, name='doublesided_maxwell')
Wraps tensorflow_probability.substrates.jax.distributions.doublesided_maxwell.DoublesidedMaxwell with TFPDistribution.
Empirical
class Empirical(samples, event_ndims=0, validate_args=False, allow_nan_stats=True, name='Empirical')
Wraps tensorflow_probability.substrates.jax.distributions.empirical.Empirical with TFPDistribution.
ExpGamma
class ExpGamma(concentration, rate=None, log_rate=None, validate_args=False, allow_nan_stats=True, name='ExpGamma')
Wraps tensorflow_probability.substrates.jax.distributions.exp_gamma.ExpGamma with TFPDistribution.
ExpInverseGamma
class ExpInverseGamma(concentration, scale=None, log_scale=None, validate_args=False, allow_nan_stats=True, name='ExpInverseGamma')
Wraps tensorflow_probability.substrates.jax.distributions.exp_gamma.ExpInverseGamma with TFPDistribution.
ExpRelaxedOneHotCategorical
class ExpRelaxedOneHotCategorical(temperature, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='ExpRelaxedOneHotCategorical')
Wraps tensorflow_probability.substrates.jax.distributions.relaxed_onehot_categorical.ExpRelaxedOneHotCategorical with TFPDistribution.
Exponential
class Exponential(rate, force_probs_to_zero_outside_support=False, validate_args=False, allow_nan_stats=True, name='Exponential')
Wraps tensorflow_probability.substrates.jax.distributions.exponential.Exponential with TFPDistribution.
ExponentiallyModifiedGaussian
class ExponentiallyModifiedGaussian(loc, scale, rate, validate_args=False, allow_nan_stats=True, name='ExponentiallyModifiedGaussian')
Wraps tensorflow_probability.substrates.jax.distributions.exponentially_modified_gaussian.ExponentiallyModifiedGaussian with TFPDistribution.
FiniteDiscrete
class FiniteDiscrete(outcomes, logits=None, probs=None, rtol=None, atol=None, validate_args=False, allow_nan_stats=True, name='FiniteDiscrete')
Wraps tensorflow_probability.substrates.jax.distributions.finite_discrete.FiniteDiscrete with TFPDistribution.
Gamma
class Gamma(concentration, rate=None, log_rate=None, validate_args=False, allow_nan_stats=True, force_probs_to_zero_outside_support=False, name='Gamma')
Wraps tensorflow_probability.substrates.jax.distributions.gamma.Gamma with TFPDistribution.
GammaGamma
class GammaGamma(concentration, mixing_concentration, mixing_rate, validate_args=False, allow_nan_stats=True, name='GammaGamma')
Wraps tensorflow_probability.substrates.jax.distributions.gamma_gamma.GammaGamma with TFPDistribution.
GaussianProcess
class GaussianProcess(kernel, index_points=None, mean_fn=None, observation_noise_variance=0.0, marginal_fn=None, jitter=1e-06, validate_args=False, allow_nan_stats=False, name='GaussianProcess')
Wraps tensorflow_probability.substrates.jax.distributions.gaussian_process.GaussianProcess with TFPDistribution.
GaussianProcessRegressionModel
class GaussianProcessRegressionModel(kernel, index_points=None, observation_index_points=None, observations=None, observation_noise_variance=0.0, predictive_noise_variance=None, mean_fn=None, jitter=1e-06, validate_args=False, allow_nan_stats=False, name='GaussianProcessRegressionModel')
Wraps tensorflow_probability.substrates.jax.distributions.gaussian_process_regression_model.GaussianProcessRegressionModel with TFPDistribution.
GeneralizedExtremeValue
class GeneralizedExtremeValue(loc, scale, concentration, validate_args=False, allow_nan_stats=True, name='GeneralizedExtremeValue')
Wraps tensorflow_probability.substrates.jax.distributions.gev.GeneralizedExtremeValue with TFPDistribution.
GeneralizedNormal
class GeneralizedNormal(loc, scale, power, validate_args=False, allow_nan_stats=True, name='GeneralizedNormal')
Wraps tensorflow_probability.substrates.jax.distributions.generalized_normal.GeneralizedNormal with TFPDistribution.
GeneralizedPareto
class GeneralizedPareto(loc, scale, concentration, validate_args=False, allow_nan_stats=True, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.generalized_pareto.GeneralizedPareto with TFPDistribution.
Geometric
class Geometric(logits=None, probs=None, force_probs_to_zero_outside_support=False, validate_args=False, allow_nan_stats=True, name='Geometric')
Wraps tensorflow_probability.substrates.jax.distributions.geometric.Geometric with TFPDistribution.
Gumbel
class Gumbel(loc, scale, validate_args=False, allow_nan_stats=True, name='Gumbel')
Wraps tensorflow_probability.substrates.jax.distributions.gumbel.Gumbel with TFPDistribution.
HalfCauchy
class HalfCauchy(loc, scale, validate_args=False, allow_nan_stats=True, name='HalfCauchy')
Wraps tensorflow_probability.substrates.jax.distributions.half_cauchy.HalfCauchy with TFPDistribution.
HalfNormal
class HalfNormal(scale, validate_args=False, allow_nan_stats=True, name='HalfNormal')
Wraps tensorflow_probability.substrates.jax.distributions.half_normal.HalfNormal with TFPDistribution.
HalfStudentT
class HalfStudentT(df, loc, scale, validate_args=False, allow_nan_stats=True, name='HalfStudentT')
Wraps tensorflow_probability.substrates.jax.distributions.half_student_t.HalfStudentT with TFPDistribution.
Horseshoe
class Horseshoe(scale, validate_args=False, allow_nan_stats=True, name='Horseshoe')
Wraps tensorflow_probability.substrates.jax.distributions.horseshoe.Horseshoe with TFPDistribution.
Independent
class Independent(distribution, reinterpreted_batch_ndims=None, validate_args=False, experimental_use_kahan_sum=False, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.independent.Independent with TFPDistribution.
InverseGamma
class InverseGamma(concentration, scale=None, validate_args=False, allow_nan_stats=True, name='InverseGamma')
Wraps tensorflow_probability.substrates.jax.distributions.inverse_gamma.InverseGamma with TFPDistribution.
InverseGaussian
class InverseGaussian(loc, concentration, validate_args=False, allow_nan_stats=True, name='InverseGaussian')
Wraps tensorflow_probability.substrates.jax.distributions.inverse_gaussian.InverseGaussian with TFPDistribution.
JohnsonSU
class JohnsonSU(skewness, tailweight, loc, scale, validate_args=False, allow_nan_stats=True, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.johnson_su.JohnsonSU with TFPDistribution.
JointDistribution
class JointDistribution(dtype, reparameterization_type, validate_args, allow_nan_stats, parameters=None, graph_parents=None, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution.JointDistribution with TFPDistribution.
JointDistributionCoroutine
class JointDistributionCoroutine(model, sample_dtype=None, validate_args=False, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_coroutine.JointDistributionCoroutine with TFPDistribution.
JointDistributionCoroutineAutoBatched
class JointDistributionCoroutineAutoBatched(model, sample_dtype=None, batch_ndims=0, use_vectorized_map=True, validate_args=False, experimental_use_kahan_sum=False, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched with TFPDistribution.
JointDistributionNamed
class JointDistributionNamed(model, validate_args=False, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_named.JointDistributionNamed with TFPDistribution.
JointDistributionNamedAutoBatched
class JointDistributionNamedAutoBatched(model, batch_ndims=0, use_vectorized_map=True, validate_args=False, experimental_use_kahan_sum=False, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_auto_batched.JointDistributionNamedAutoBatched with TFPDistribution.
JointDistributionSequential
class JointDistributionSequential(model, validate_args=False, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_sequential.JointDistributionSequential with TFPDistribution.
JointDistributionSequentialAutoBatched
class JointDistributionSequentialAutoBatched(model, batch_ndims=0, use_vectorized_map=True, validate_args=False, experimental_use_kahan_sum=False, name=None)
Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_auto_batched.JointDistributionSequentialAutoBatched with TFPDistribution.
Kumaraswamy¶
-
class
Kumaraswamy
(concentration1=1.0, concentration0=1.0, validate_args=False, allow_nan_stats=True, name='Kumaraswamy')¶ Wraps tensorflow_probability.substrates.jax.distributions.kumaraswamy.Kumaraswamy with
TFPDistribution
.
LKJ¶
-
class
LKJ
(dimension, concentration, input_output_cholesky=False, validate_args=False, allow_nan_stats=True, name='LKJ')¶ Wraps tensorflow_probability.substrates.jax.distributions.lkj.LKJ with
TFPDistribution
.
LambertWDistribution¶
-
class
LambertWDistribution
(distribution, shift, scale, tailweight=None, validate_args=False, allow_nan_stats=True, name='LambertWDistribution')¶ Wraps tensorflow_probability.substrates.jax.distributions.lambertw_f.LambertWDistribution with
TFPDistribution
.
LambertWNormal¶
-
class
LambertWNormal
(loc, scale, tailweight=None, validate_args=False, allow_nan_stats=True, name='LambertWNormal')¶ Wraps tensorflow_probability.substrates.jax.distributions.lambertw_f.LambertWNormal with
TFPDistribution
.
Laplace¶
-
class
Laplace
(loc, scale, validate_args=False, allow_nan_stats=True, name='Laplace')¶ Wraps tensorflow_probability.substrates.jax.distributions.laplace.Laplace with
TFPDistribution
.
LinearGaussianStateSpaceModel¶
-
class
LinearGaussianStateSpaceModel
(num_timesteps, transition_matrix, transition_noise, observation_matrix, observation_noise, initial_state_prior, initial_step=0, mask=None, experimental_parallelize=False, validate_args=False, allow_nan_stats=True, name='LinearGaussianStateSpaceModel')¶ Wraps tensorflow_probability.substrates.jax.distributions.linear_gaussian_ssm.LinearGaussianStateSpaceModel with
TFPDistribution
.
LogLogistic¶
-
class
LogLogistic
(loc, scale, validate_args=False, allow_nan_stats=True, name='LogLogistic')¶ Wraps tensorflow_probability.substrates.jax.distributions.loglogistic.LogLogistic with
TFPDistribution
.
LogNormal¶
-
class
LogNormal
(loc, scale, validate_args=False, allow_nan_stats=True, name='LogNormal')¶ Wraps tensorflow_probability.substrates.jax.distributions.lognormal.LogNormal with
TFPDistribution
.
Logistic¶
-
class
Logistic
(loc, scale, validate_args=False, allow_nan_stats=True, name='Logistic')¶ Wraps tensorflow_probability.substrates.jax.distributions.logistic.Logistic with
TFPDistribution
.
LogitNormal¶
-
class
LogitNormal
(loc, scale, num_probit_terms_approx=2, validate_args=False, allow_nan_stats=True, name='LogitNormal')¶ Wraps tensorflow_probability.substrates.jax.distributions.logitnormal.LogitNormal with
TFPDistribution
.
Masked¶
-
class
Masked
(distribution, validity_mask, safe_sample_fn=<function _fixed_sample>, validate_args=False, allow_nan_stats=True, name=None)¶ Wraps tensorflow_probability.substrates.jax.distributions.masked.Masked with
TFPDistribution
.
MatrixNormalLinearOperator¶
-
class
MatrixNormalLinearOperator
(loc, scale_row, scale_column, validate_args=False, allow_nan_stats=True, name='MatrixNormalLinearOperator')¶ Wraps tensorflow_probability.substrates.jax.distributions.matrix_normal_linear_operator.MatrixNormalLinearOperator with
TFPDistribution
.
MatrixTLinearOperator¶
-
class
MatrixTLinearOperator
(df, loc, scale_row, scale_column, validate_args=False, allow_nan_stats=True, name='MatrixTLinearOperator')¶ Wraps tensorflow_probability.substrates.jax.distributions.matrix_t_linear_operator.MatrixTLinearOperator with
TFPDistribution
.
MixtureSameFamily¶
-
class
MixtureSameFamily
(mixture_distribution, components_distribution, reparameterize=False, validate_args=False, allow_nan_stats=True, name='MixtureSameFamily')¶ Wraps tensorflow_probability.substrates.jax.distributions.mixture_same_family.MixtureSameFamily with
TFPDistribution
.
Moyal¶
-
class
Moyal
(loc, scale, validate_args=False, allow_nan_stats=True, name='Moyal')¶ Wraps tensorflow_probability.substrates.jax.distributions.moyal.Moyal with
TFPDistribution
.
Multinomial¶
-
class
Multinomial
(total_count, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='Multinomial')¶ Wraps tensorflow_probability.substrates.jax.distributions.multinomial.Multinomial with
TFPDistribution
.
MultivariateNormalDiag¶
-
class
MultivariateNormalDiag
(loc=None, scale_diag=None, scale_identity_multiplier=None, validate_args=False, allow_nan_stats=True, experimental_use_kahan_sum=False, name='MultivariateNormalDiag')¶ Wraps tensorflow_probability.substrates.jax.distributions.mvn_diag.MultivariateNormalDiag with
TFPDistribution
.
MultivariateNormalDiagPlusLowRank¶
-
class
MultivariateNormalDiagPlusLowRank
(loc=None, scale_diag=None, scale_perturb_factor=None, scale_perturb_diag=None, validate_args=False, allow_nan_stats=True, name='MultivariateNormalDiagPlusLowRank')¶ Wraps tensorflow_probability.substrates.jax.distributions.mvn_diag_plus_low_rank.MultivariateNormalDiagPlusLowRank with
TFPDistribution
.
MultivariateNormalDiagPlusLowRankCovariance¶
-
class
MultivariateNormalDiagPlusLowRankCovariance
(loc=None, cov_diag_factor=None, cov_perturb_factor=None, validate_args=False, allow_nan_stats=True, name='MultivariateNormalDiagPlusLowRankCovariance')¶ Wraps tensorflow_probability.substrates.jax.distributions.mvn_diag_plus_low_rank_covariance.MultivariateNormalDiagPlusLowRankCovariance with
TFPDistribution
.
MultivariateNormalFullCovariance¶
-
class
MultivariateNormalFullCovariance
(loc=None, covariance_matrix=None, validate_args=False, allow_nan_stats=True, name='MultivariateNormalFullCovariance')¶ Wraps tensorflow_probability.substrates.jax.distributions.mvn_full_covariance.MultivariateNormalFullCovariance with
TFPDistribution
.
MultivariateNormalLinearOperator¶
-
class
MultivariateNormalLinearOperator
(loc=None, scale=None, validate_args=False, allow_nan_stats=True, experimental_use_kahan_sum=False, name='MultivariateNormalLinearOperator')¶ Wraps tensorflow_probability.substrates.jax.distributions.mvn_linear_operator.MultivariateNormalLinearOperator with
TFPDistribution
.
MultivariateNormalTriL¶
-
class
MultivariateNormalTriL
(loc=None, scale_tril=None, validate_args=False, allow_nan_stats=True, experimental_use_kahan_sum=False, name='MultivariateNormalTriL')¶ Wraps tensorflow_probability.substrates.jax.distributions.mvn_tril.MultivariateNormalTriL with
TFPDistribution
.
MultivariateStudentTLinearOperator¶
-
class
MultivariateStudentTLinearOperator
(df, loc, scale, validate_args=False, allow_nan_stats=True, name='MultivariateStudentTLinearOperator')¶ Wraps tensorflow_probability.substrates.jax.distributions.multivariate_student_t.MultivariateStudentTLinearOperator with
TFPDistribution
.
NegativeBinomial¶
-
class
NegativeBinomial
(total_count, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='NegativeBinomial')¶ Wraps tensorflow_probability.substrates.jax.distributions.negative_binomial.NegativeBinomial with
TFPDistribution
.
Normal¶
-
class
Normal
(loc, scale, validate_args=False, allow_nan_stats=True, name='Normal')¶ Wraps tensorflow_probability.substrates.jax.distributions.normal.Normal with
TFPDistribution
.
NormalInverseGaussian¶
-
class
NormalInverseGaussian
(loc, scale, tailweight, skewness, validate_args=False, allow_nan_stats=True, name='NormalInverseGaussian')¶ Wraps tensorflow_probability.substrates.jax.distributions.normal_inverse_gaussian.NormalInverseGaussian with
TFPDistribution
.
OneHotCategorical¶
-
class
OneHotCategorical
(logits=None, probs=None, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='OneHotCategorical')¶ Wraps tensorflow_probability.substrates.jax.distributions.onehot_categorical.OneHotCategorical with
TFPDistribution
.
OrderedLogistic¶
-
class
OrderedLogistic
(cutpoints, loc, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='OrderedLogistic')¶ Wraps tensorflow_probability.substrates.jax.distributions.ordered_logistic.OrderedLogistic with
TFPDistribution
.
PERT¶
-
class
PERT
(low, peak, high, temperature=4.0, validate_args=False, allow_nan_stats=False, name='PERT')¶ Wraps tensorflow_probability.substrates.jax.distributions.pert.PERT with
TFPDistribution
.
Pareto¶
-
class
Pareto
(concentration, scale=1.0, validate_args=False, allow_nan_stats=True, name='Pareto')¶ Wraps tensorflow_probability.substrates.jax.distributions.pareto.Pareto with
TFPDistribution
.
PlackettLuce¶
-
class
PlackettLuce
(scores, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='PlackettLuce')¶ Wraps tensorflow_probability.substrates.jax.distributions.plackett_luce.PlackettLuce with
TFPDistribution
.
Poisson¶
-
class
Poisson
(rate=None, log_rate=None, force_probs_to_zero_outside_support=None, interpolate_nondiscrete=True, validate_args=False, allow_nan_stats=True, name='Poisson')¶ Wraps tensorflow_probability.substrates.jax.distributions.poisson.Poisson with
TFPDistribution
.
PoissonLogNormalQuadratureCompound¶
-
class
PoissonLogNormalQuadratureCompound
(loc, scale, quadrature_size=8, quadrature_fn=<function quadrature_scheme_lognormal_quantiles>, validate_args=False, allow_nan_stats=True, name='PoissonLogNormalQuadratureCompound')¶ Wraps tensorflow_probability.substrates.jax.distributions.poisson_lognormal.PoissonLogNormalQuadratureCompound with
TFPDistribution
.
PowerSpherical¶
-
class
PowerSpherical
(mean_direction, concentration, validate_args=False, allow_nan_stats=True, name='PowerSpherical')¶ Wraps tensorflow_probability.substrates.jax.distributions.power_spherical.PowerSpherical with
TFPDistribution
.
ProbitBernoulli¶
-
class
ProbitBernoulli
(probits=None, probs=None, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='ProbitBernoulli')¶ Wraps tensorflow_probability.substrates.jax.distributions.probit_bernoulli.ProbitBernoulli with
TFPDistribution
.
QuantizedDistribution¶
-
class
QuantizedDistribution
(distribution, low=None, high=None, validate_args=False, name='QuantizedDistribution')¶ Wraps tensorflow_probability.substrates.jax.distributions.quantized_distribution.QuantizedDistribution with
TFPDistribution
.
RelaxedBernoulli¶
-
class
RelaxedBernoulli
(temperature, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='RelaxedBernoulli')¶ Wraps tensorflow_probability.substrates.jax.distributions.relaxed_bernoulli.RelaxedBernoulli with
TFPDistribution
.
RelaxedOneHotCategorical¶
-
class
RelaxedOneHotCategorical
(temperature, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='RelaxedOneHotCategorical')¶ Wraps tensorflow_probability.substrates.jax.distributions.relaxed_onehot_categorical.RelaxedOneHotCategorical with
TFPDistribution
.
Sample¶
-
class
Sample
(distribution, sample_shape=(), validate_args=False, experimental_use_kahan_sum=False, name=None)¶ Wraps tensorflow_probability.substrates.jax.distributions.sample.Sample with
TFPDistribution
.
SigmoidBeta¶
-
class
SigmoidBeta
(concentration1, concentration0, validate_args=False, allow_nan_stats=True, name='SigmoidBeta')¶ Wraps tensorflow_probability.substrates.jax.distributions.sigmoid_beta.SigmoidBeta with
TFPDistribution
.
SinhArcsinh¶
-
class
SinhArcsinh
(loc, scale, skewness=None, tailweight=None, distribution=None, validate_args=False, allow_nan_stats=True, name='SinhArcsinh')¶ Wraps tensorflow_probability.substrates.jax.distributions.sinh_arcsinh.SinhArcsinh with
TFPDistribution
.
Skellam¶
-
class
Skellam
(rate1=None, rate2=None, log_rate1=None, log_rate2=None, force_probs_to_zero_outside_support=False, validate_args=False, allow_nan_stats=True, name='Skellam')¶ Wraps tensorflow_probability.substrates.jax.distributions.skellam.Skellam with
TFPDistribution
.
SphericalUniform¶
-
class
SphericalUniform
(dimension, batch_shape=(), dtype=<class 'jax._src.numpy.lax_numpy.float32'>, validate_args=False, allow_nan_stats=True, name='SphericalUniform')¶ Wraps tensorflow_probability.substrates.jax.distributions.spherical_uniform.SphericalUniform with
TFPDistribution
.
StoppingRatioLogistic¶
-
class
StoppingRatioLogistic
(cutpoints, loc, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='StoppingRatioLogistic')¶ Wraps tensorflow_probability.substrates.jax.distributions.stopping_ratio_logistic.StoppingRatioLogistic with
TFPDistribution
.
StudentT¶
-
class
StudentT
(df, loc, scale, validate_args=False, allow_nan_stats=True, name='StudentT')¶ Wraps tensorflow_probability.substrates.jax.distributions.student_t.StudentT with
TFPDistribution
.
StudentTProcess¶
-
class
StudentTProcess
(df, kernel, index_points=None, mean_fn=None, observation_noise_variance=0.0, marginal_fn=None, jitter=1e-06, validate_args=False, allow_nan_stats=False, name='StudentTProcess')¶ Wraps tensorflow_probability.substrates.jax.distributions.student_t_process.StudentTProcess with
TFPDistribution
.
TransformedDistribution¶
-
class
TransformedDistribution
(distribution, bijector, kwargs_split_fn=<function _default_kwargs_split_fn>, validate_args=False, parameters=None, name=None)¶ Wraps tensorflow_probability.substrates.jax.distributions.transformed_distribution.TransformedDistribution with
TFPDistribution
.
Triangular¶
-
class
Triangular
(low=0.0, high=1.0, peak=0.5, validate_args=False, allow_nan_stats=True, name='Triangular')¶ Wraps tensorflow_probability.substrates.jax.distributions.triangular.Triangular with
TFPDistribution
.
TruncatedCauchy¶
-
class
TruncatedCauchy
(loc, scale, low, high, validate_args=False, allow_nan_stats=True, name='TruncatedCauchy')¶ Wraps tensorflow_probability.substrates.jax.distributions.truncated_cauchy.TruncatedCauchy with
TFPDistribution
.
TruncatedNormal¶
-
class
TruncatedNormal
(loc, scale, low, high, validate_args=False, allow_nan_stats=True, name='TruncatedNormal')¶ Wraps tensorflow_probability.substrates.jax.distributions.truncated_normal.TruncatedNormal with
TFPDistribution
.
Uniform¶
-
class
Uniform
(low=0.0, high=1.0, validate_args=False, allow_nan_stats=True, name='Uniform')¶ Wraps tensorflow_probability.substrates.jax.distributions.uniform.Uniform with
TFPDistribution
.
VariationalGaussianProcess¶
-
class
VariationalGaussianProcess
(kernel, index_points, inducing_index_points, variational_inducing_observations_loc, variational_inducing_observations_scale, mean_fn=None, observation_noise_variance=None, predictive_noise_variance=None, jitter=1e-06, validate_args=False, allow_nan_stats=False, name='VariationalGaussianProcess')¶ Wraps tensorflow_probability.substrates.jax.distributions.variational_gaussian_process.VariationalGaussianProcess with
TFPDistribution
.
VectorDeterministic¶
-
class
VectorDeterministic
(loc, atol=None, rtol=None, validate_args=False, allow_nan_stats=True, name='VectorDeterministic')¶ Wraps tensorflow_probability.substrates.jax.distributions.deterministic.VectorDeterministic with
TFPDistribution
.
VectorExponentialDiag¶
-
class
VectorExponentialDiag
(loc=None, scale_diag=None, scale_identity_multiplier=None, validate_args=False, allow_nan_stats=True, name='VectorExponentialDiag')¶ Wraps tensorflow_probability.substrates.jax.distributions.vector_exponential_diag.VectorExponentialDiag with
TFPDistribution
.
VonMises¶
-
class
VonMises
(loc, concentration, validate_args=False, allow_nan_stats=True, name='VonMises')¶ Wraps tensorflow_probability.substrates.jax.distributions.von_mises.VonMises with
TFPDistribution
.
VonMisesFisher¶
-
class
VonMisesFisher
(mean_direction, concentration, validate_args=False, allow_nan_stats=True, name='VonMisesFisher')¶ Wraps tensorflow_probability.substrates.jax.distributions.von_mises_fisher.VonMisesFisher with
TFPDistribution
.
Weibull¶
-
class
Weibull
(concentration, scale, validate_args=False, allow_nan_stats=True, name='Weibull')¶ Wraps tensorflow_probability.substrates.jax.distributions.weibull.Weibull with
TFPDistribution
.
WishartLinearOperator¶
-
class
WishartLinearOperator
(df, scale, input_output_cholesky=False, validate_args=False, allow_nan_stats=True, name='WishartLinearOperator')¶ Wraps tensorflow_probability.substrates.jax.distributions.wishart.WishartLinearOperator with
TFPDistribution
.
WishartTriL¶
-
class
WishartTriL
(df, scale_tril=None, input_output_cholesky=False, validate_args=False, allow_nan_stats=True, name='WishartTriL')¶ Wraps tensorflow_probability.substrates.jax.distributions.wishart.WishartTriL with
TFPDistribution
.
Zipf¶
-
class
Zipf
(power, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, force_probs_to_zero_outside_support=None, interpolate_nondiscrete=True, sample_maximum_iterations=100, validate_args=False, allow_nan_stats=False, name='Zipf')¶ Wraps tensorflow_probability.substrates.jax.distributions.zipf.Zipf with
TFPDistribution
.
Constraints¶
Constraint¶
-
class
Constraint
[source]¶ Bases:
object
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
-
is_discrete
= False¶
-
event_dim
= 0¶
dependent¶
-
dependent
= <numpyro.distributions.constraints._Dependent object>¶ Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints.
Parameters: - is_discrete (bool) – Optional value of .is_discrete in case this can be computed statically. If not provided, access to the .is_discrete attribute will raise a NotImplementedError.
- event_dim (int) – Optional value of .event_dim in case this can be computed statically. If not provided, access to the .event_dim attribute will raise a NotImplementedError.
greater_than¶
-
greater_than
(lower_bound)¶ Constrains to real values strictly greater than lower_bound.
integer_interval¶
-
integer_interval
(lower_bound, upper_bound)¶ Constrains to integer values between lower_bound and upper_bound, both inclusive.
integer_greater_than¶
-
integer_greater_than
(lower_bound)¶ Constrains to integer values greater than or equal to lower_bound.
interval¶
-
interval
(lower_bound, upper_bound)¶ Constrains to real values between lower_bound and upper_bound, both inclusive.
less_than¶
-
less_than
(upper_bound)¶ Constrains to real values strictly less than upper_bound.
multinomial¶
-
multinomial
(upper_bound)¶ Constrains to vectors of nonnegative counts that sum to upper_bound. This constraint has event_dim = 1, so it is checked over the trailing dimension.
nonnegative_integer¶
-
nonnegative_integer
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
positive_definite¶
-
positive_definite
= <numpyro.distributions.constraints._PositiveDefinite object>¶
positive_integer¶
-
positive_integer
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
positive_ordered_vector¶
-
positive_ordered_vector
= <numpyro.distributions.constraints._PositiveOrderedVector object>¶ Constrains to a positive real-valued tensor where the elements are monotonically increasing along the event_shape dimension.
real_vector¶
-
real_vector
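= <numpyro.distributions.constraints._IndependentConstraint object>¶ Wraps a constraint by aggregating over reinterpreted_batch_ndims-many dims in check(), so that an event is valid only if all its independent entries are valid. For real_vector, the wrapped constraint is real and one trailing dimension is reinterpreted, so it constrains real-valued vectors.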
softplus_positive¶
-
softplus_positive
= <numpyro.distributions.constraints._SoftplusPositive object>¶
softplus_lower_cholesky¶
-
softplus_lower_cholesky
= <numpyro.distributions.constraints._SoftplusLowerCholesky object>¶
Transforms¶
Transform¶
-
class
Transform
[source]¶ Bases:
object
-
domain
= <numpyro.distributions.constraints._Real object>¶
-
codomain
= <numpyro.distributions.constraints._Real object>¶
-
event_dim
¶
-
inv
¶
AbsTransform¶
-
class
AbsTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
-
domain
= <numpyro.distributions.constraints._Real object>¶
-
codomain
= <numpyro.distributions.constraints._GreaterThan object>¶
AffineTransform¶
-
class
AffineTransform
(loc, scale, domain=<numpyro.distributions.constraints._Real object>)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Note
When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.
-
codomain
¶
CholeskyTransform¶
-
class
CholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = cholesky(x)\), where x is a positive definite matrix.
-
domain
= <numpyro.distributions.constraints._PositiveDefinite object>¶
-
codomain
= <numpyro.distributions.constraints._LowerCholesky object>¶
ComposeTransform¶
CorrCholeskyTransform¶
-
class
CorrCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transforms an unconstrained real vector \(x\) with length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:
1. First we convert \(x\) into a lower triangular matrix with the following order:
\[\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\]
2. For each row \(X_i\) of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform \(X_i\) into a unit Euclidean length vector using the following steps:
  - Scales into the interval \((-1, 1)\) domain: \(r_i = \tanh(X_i)\).
  - Transforms into an unsigned domain: \(z_i = r_i^2\).
  - Applies \(s_i = StickBreakingTransform(z_i)\).
  - Transforms back into signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._CorrCholesky object>¶
CorrMatrixCholeskyTransform¶
-
class
CorrMatrixCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.CholeskyTransform
Transform via the mapping \(y = cholesky(x)\), where x is a correlation matrix.
-
domain
= <numpyro.distributions.constraints._CorrMatrix object>¶
-
codomain
= <numpyro.distributions.constraints._CorrCholesky object>¶
ExpTransform¶
InvCholeskyTransform¶
-
class
InvCholeskyTransform
(domain=<numpyro.distributions.constraints._LowerCholesky object>)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = x @ x.T\), where x is a lower triangular matrix with positive diagonal.
-
codomain
¶
LowerCholeskyAffine¶
-
class
LowerCholeskyAffine
(loc, scale_tril)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = loc + scale\_tril\ @\ x\).
Parameters: - loc – a real vector.
- scale_tril – a lower triangular matrix with positive diagonal.
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
LowerCholeskyTransform¶
-
class
LowerCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._LowerCholesky object>¶
OrderedTransform¶
-
class
OrderedTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform a real vector to an ordered vector.
References:
- Stan Reference Manual v2.20, section 10.6, Stan Development Team
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._OrderedVector object>¶
PermuteTransform¶
-
class
PermuteTransform
(permutation)[source]¶ Bases:
numpyro.distributions.transforms.Transform
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
PowerTransform¶
-
class
PowerTransform
(exponent)[source]¶ Bases:
numpyro.distributions.transforms.Transform
-
domain
= <numpyro.distributions.constraints._GreaterThan object>¶
-
codomain
= <numpyro.distributions.constraints._GreaterThan object>¶
SigmoidTransform¶
SoftplusLowerCholeskyTransform¶
-
class
SoftplusLowerCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._SoftplusLowerCholesky object>¶
SoftplusTransform¶
-
class
SoftplusTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform from unconstrained space to positive domain via softplus \(y = \log(1 + \exp(x))\). The inverse is computed as \(x = \log(\exp(y) - 1)\).
-
domain
= <numpyro.distributions.constraints._Real object>¶
-
codomain
= <numpyro.distributions.constraints._SoftplusPositive object>¶
StickBreakingTransform¶
-
class
StickBreakingTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._Simplex object>¶
Flows¶
InverseAutoregressiveTransform¶
-
class
InverseAutoregressiveTransform
(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]¶ Bases:
numpyro.distributions.transforms.Transform
An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,
\[\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\]
where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).
References
- Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
log_abs_det_jacobian
(x, y, intermediates=None)[source]¶ Calculates the elementwise log absolute determinant of the Jacobian of the transform.
Parameters: - x (numpy.ndarray) – the input to the transform
- y (numpy.ndarray) – the output of the transform
BlockNeuralAutoregressiveTransform¶
-
class
BlockNeuralAutoregressiveTransform
(bn_arn)[source]¶ Bases:
numpyro.distributions.transforms.Transform
An implementation of Block Neural Autoregressive flow.
References
- Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
-
domain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
codomain
= <numpyro.distributions.constraints._IndependentConstraint object>¶
-
log_abs_det_jacobian
(x, y, intermediates=None)[source]¶ Calculates the elementwise log absolute determinant of the Jacobian of the transform.
Parameters: - x (numpy.ndarray) – the input to the transform
- y (numpy.ndarray) – the output of the transform