Automatic rendering of NumPyro models
This tutorial shows how to render NumPyro models as graph visualizations of the underlying probabilistic graphical model.
[ ]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[1]:
from jax import nn
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
assert numpyro.__version__.startswith('0.7.1')
A Simple Example
To visualize a model, pass it to numpyro.render_model together with its arguments:
[2]:
def model(data):
    m = numpyro.sample('m', dist.Normal(0, 1))
    sd = numpyro.sample('sd', dist.LogNormal(m, 1))
    with numpyro.plate('N', len(data)):
        numpyro.sample('obs', dist.Normal(m, sd), obs=data)
[3]:
data = jnp.ones(10)
numpyro.render_model(model, model_args=(data,))
The visualization can be saved to a file by providing filename='path' to numpyro.render_model. You can use different formats such as PDF or PNG by changing the filename's suffix. When not saving to a file (filename=None), you can also change the format with graph.format = 'pdf', where graph is the object returned by numpyro.render_model.
[4]:
graph = numpyro.render_model(model, model_args=(data,), filename='model.pdf')
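The format can also be set on the graph object instead of through the filename. The following is a minimal sketch of that route, assuming the graphviz Python package is installed. The hand-built graphviz.Digraph is only a stand-in for the object returned by numpyro.render_model, and the save_graph helper is hypothetical.

```python
import graphviz

# Stand-in for the object returned by numpyro.render_model (assumption:
# built by hand here so the snippet does not need a model).
graph = graphviz.Digraph()
graph.edge('m', 'sd')
graph.edge('m', 'obs')
graph.edge('sd', 'obs')

# Select the output format on the graph object itself.
graph.format = 'png'


def save_graph(graph, stem):
    """Hypothetical helper: write <stem>.<graph.format> and return its path.

    Calling it requires the Graphviz executables (e.g. dot) on the PATH.
    """
    return graph.render(stem, cleanup=True)
```

Calling save_graph(graph, 'model') would then write model.png, provided the Graphviz executables are available.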
Tweaking the visualization
As numpyro.render_model returns an object of type graphviz.dot.Digraph, you can further improve the visualization of this graph. For example, you could use the unflatten preprocessor to improve the layout aspect ratio for more complex models.
[5]:
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of https://www.aclweb.org/anthology/Q18-1040.pdf.
    """
    num_annotators = int(jnp.max(positions)) + 1
    num_classes = int(jnp.max(annotations)) + 1
    num_items, num_positions = annotations.shape
    with numpyro.plate('annotator', num_annotators):
        epsilon = numpyro.sample('epsilon', dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample('theta', dist.Beta(0.5, 0.5))
    with numpyro.plate('item', num_items, dim=-2):
        # NB: using constant logits for discrete uniform prior
        # (NumPyro does not have DiscreteUniform distribution yet)
        c = numpyro.sample('c', dist.Categorical(logits=jnp.zeros(num_classes)))
        with numpyro.plate('position', num_positions):
            s = numpyro.sample('s', dist.Bernoulli(1 - theta[positions]))
            probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions])
            numpyro.sample('y', dist.Categorical(probs), obs=annotations)
positions = jnp.array([1, 1, 1, 2, 3, 4, 5])
annotations = jnp.array([
[1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1,
1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],
[1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2,
1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1,
1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],
[1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1,
1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,
1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],
[1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1,
2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1,
1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],
[1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1,
1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],
[1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1,
1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2,
1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],
[1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1,
1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,
1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],
]).T
# shift the 1-based labels to 0-based indices, since Python indexing starts at 0
positions -= 1
annotations -= 1
mace_graph = numpyro.render_model(mace, model_args=(positions, annotations))
[6]:
# default layout
mace_graph
[7]:
# layout after preprocessing with unflatten
mace_graph.unflatten(stagger=2)
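Other tweaks can go through the usual graphviz attributes. The following is a minimal sketch, assuming the graphviz Python package is installed. The hand-built graphviz.Digraph is a hypothetical stand-in for the graph returned by numpyro.render_model, and the node names are for illustration only. The attributes are set on the Digraph itself, before any call to unflatten.

```python
import graphviz

# Stand-in for the graph returned by numpyro.render_model (assumption:
# a few edges added by hand so the snippet runs without a model).
graph = graphviz.Digraph()
graph.edge('theta', 's')
graph.edge('s', 'y')
graph.edge('c', 'y')
graph.edge('epsilon', 'y')

# Lay the graph out left to right instead of top to bottom.
graph.graph_attr['rankdir'] = 'LR'
# Draw all edges in gray.
graph.edge_attr['color'] = 'gray'

# The DOT source reflects the changes; inspecting it needs no Graphviz executable.
dot_source = graph.source
```

Whether a particular attribute helps depends on the model; rankdir is mostly useful for long chains of dependencies.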
Distribution annotations
To display the distribution of each random variable in the generated plot, pass render_distributions=True when calling numpyro.render_model.
[8]:
def model(data):
    x = numpyro.sample('x', dist.Normal(0, 1))
    y = numpyro.sample('y', dist.LogNormal(x, 1))
    with numpyro.plate('N', len(data)):
        numpyro.sample('z', dist.Normal(x, y), obs=data)
[9]:
data = jnp.ones(10)
numpyro.render_model(model, model_args=(data,), render_distributions=True)