Pyro Primitives
param

param(name, init_value=None, **kwargs)

Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().

Parameters:
- name (str) – name of the site.
- init_value (numpy.ndarray) – initial value specified by the user. Note that the onus of using this to initialize the optimizer is on the user or the inference algorithm, since there is no global parameter store in NumPyro.

Returns: the value of the parameter. Unless wrapped inside a handler like substitute, this simply returns the initial value.
sample

sample(name, fn, obs=None, sample_shape=())

Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.

Parameters:
- name (str) – name of the sample site.
- fn – stochastic Python callable to sample from.
- obs (numpy.ndarray) – observed value; when given, the site is conditioned on this value rather than sampled.
- sample_shape – shape of the samples to be drawn.

Returns: a sample from the stochastic function fn.
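module

module(name, nn, input_shape=None)

Declare a stax-style neural network inside a model so that its parameters are registered for optimization via param() statements.

Parameters:
- name (str) – name of the module.
- nn – stax-style (init_fn, apply_fn) pair describing the network.
- input_shape – shape of the network input, used to initialize the parameters.

Returns: an apply_fn with bound parameters that takes an array as input and returns the output array produced by the neural network.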