Optimizers

The optimizer classes defined here are lightweight wrappers around the corresponding optimizers from jax.experimental.optimizers, with an interface that is better suited to NumPyro's inference algorithms.
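The method signatures below suggest a specific pattern. The optimizer state is a tuple `(i, opt_state)`, and the integer is assumed here to be a step counter. As a result, `update` needs only the gradient and that state. The sketch below mimics this pattern in plain Python:

- `sgd_triple` stands in for a JAX-style `(init, update, get_params)` triple.
- `ToyOptim` stands in for the wrapper.

Both names are hypothetical. The real classes operate on JAX arrays and pytrees rather than on dicts of floats.

```python
from typing import Dict, Tuple

Params = Dict[str, float]


def sgd_triple(step_size: float):
    """Hypothetical stand-in for a JAX-style (init, update, get_params) triple.

    Plain SGD on a dict of floats; the inner optimizer state is the params dict itself.
    """

    def init(params: Params) -> Params:
        return dict(params)

    def update(i: int, g: Params, opt_state: Params) -> Params:
        # The step index i is accepted for interface parity but unused by plain SGD.
        return {k: v - step_size * g[k] for k, v in opt_state.items()}

    def get_params(opt_state: Params) -> Params:
        return opt_state

    return init, update, get_params


class ToyOptim:
    """Hypothetical wrapper: keeps the step count next to the inner optimizer state."""

    def __init__(self, triple):
        self.init_fn, self.update_fn, self.get_params_fn = triple

    def init(self, params: Params) -> Tuple[int, Params]:
        return 0, self.init_fn(params)

    def update(self, g: Params, state: Tuple[int, Params]) -> Tuple[int, Params]:
        i, opt_state = state
        return i + 1, self.update_fn(i, g, opt_state)

    def get_params(self, state: Tuple[int, Params]) -> Params:
        _, opt_state = state
        return self.get_params_fn(opt_state)


# Usage: minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
optim = ToyOptim(sgd_triple(step_size=0.1))
state = optim.init({"x": 0.0})
for _ in range(100):
    x = optim.get_params(state)["x"]
    state = optim.update({"x": 2.0 * (x - 3.0)}, state)
```

Because the step counter travels inside the state, a training loop only has to pass a single object between `update` and `get_params`.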

Adam

class Adam(*args, **kwargs)
Wrapper class for the JAX optimizer: adam()

get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
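The textbook Adam step (Kingma and Ba) that `adam()` is based on can be sketched in plain Python. This is an illustration under assumptions, not the JAX source. Parameters are a dict of floats. `adam_init` and `adam_update` are hypothetical names, and the defaults (`step_size=0.1`, `b1=0.9`, `b2=0.999`, `eps=1e-8`) were chosen for the example.

```python
import math
from typing import Dict, Tuple

# Hypothetical plain-Python sketch of the textbook Adam rule; not the JAX implementation.
Params = Dict[str, float]
# (params, first-moment estimate m, second-moment estimate v)
AdamState = Tuple[Params, Params, Params]


def adam_init(params: Params) -> AdamState:
    zeros = {k: 0.0 for k in params}
    return dict(params), dict(zeros), dict(zeros)


def adam_update(i: int, g: Params, state: AdamState, step_size: float = 0.1,
                b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8) -> AdamState:
    """One Adam step; i is the zero-based step index used for bias correction."""
    x, m, v = state
    new_x, new_m, new_v = {}, {}, {}
    for k in x:
        new_m[k] = (1 - b1) * g[k] + b1 * m[k]        # moving average of gradients
        new_v[k] = (1 - b2) * g[k] ** 2 + b2 * v[k]   # moving average of squared gradients
        m_hat = new_m[k] / (1 - b1 ** (i + 1))         # bias correction
        v_hat = new_v[k] / (1 - b2 ** (i + 1))
        new_x[k] = x[k] - step_size * m_hat / (math.sqrt(v_hat) + eps)
    return new_x, new_m, new_v
```

Because of bias correction, the very first step moves each parameter by about `step_size`, against the sign of the gradient, whatever the gradient's magnitude. For example, `adam_update(0, {"x": 2.0}, adam_init({"x": 0.0}))` gives `x ≈ -0.1`.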

Adagrad

class Adagrad(*args, **kwargs)
Wrapper class for the JAX optimizer: adagrad()

get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
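The textbook Adagrad rule scales each step by the inverse square root of the accumulated squared gradients. Below is a plain-Python sketch of it. The names and the default `step_size` are assumptions for illustration. JAX's `adagrad()` may include additional terms (such as momentum), so treat this as the core idea rather than the exact update.

```python
import math
from typing import Dict, Tuple

# Hypothetical plain-Python sketch of textbook Adagrad; not the JAX implementation.
Params = Dict[str, float]
AdagradState = Tuple[Params, Params]  # (params, running sum of squared gradients)


def adagrad_init(params: Params) -> AdagradState:
    return dict(params), {k: 0.0 for k in params}


def adagrad_update(i: int, g: Params, state: AdagradState,
                   step_size: float = 0.1) -> AdagradState:
    # i is unused here; it is kept so the signature matches the other sketches.
    x, g_sq = state
    new_x, new_g_sq = {}, {}
    for k in x:
        new_g_sq[k] = g_sq[k] + g[k] ** 2
        # Avoid dividing by zero for coordinates that have only seen zero gradients.
        scale = 1.0 / math.sqrt(new_g_sq[k]) if new_g_sq[k] > 0 else 0.0
        new_x[k] = x[k] - step_size * g[k] * scale
    return new_x, new_g_sq
```

Each coordinate's effective learning rate can only shrink over time. For example, a gradient of 3.0 produces a first step of 0.1. A following gradient of 4.0 then gives a step of 0.1 · 4 / 5 = 0.08.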

ClippedAdam

class ClippedAdam(*args, clip_norm=10.0, **kwargs)
Adam optimizer with gradient clipping.
Parameters: clip_norm (float) – all gradient values are clipped element-wise to the interval [-clip_norm, clip_norm].
Reference: Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba, https://arxiv.org/abs/1412.6980

get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
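The clipping controlled by `clip_norm` can be sketched as an element-wise clamp applied to the gradient before an ordinary Adam step. This plain-Python version is illustrative only. The helpers `clip` and `clipped_adam_update` are hypothetical, parameters are a dict of floats, and the state layout `(i, (params, m, v))` is an assumption modeled on the `Tuple[int, _OptState]` signatures above.

```python
import math
from typing import Dict, Tuple

# Hypothetical plain-Python sketch; not the NumPyro/JAX implementation.
Params = Dict[str, float]
State = Tuple[int, Tuple[Params, Params, Params]]  # (step count, (params, m, v))


def clip(g: Params, clip_norm: float) -> Params:
    # Element-wise clamp into [-clip_norm, clip_norm], as described for clip_norm.
    return {k: max(-clip_norm, min(clip_norm, v)) for k, v in g.items()}


def clipped_adam_update(g: Params, state: State, clip_norm: float = 10.0,
                        step_size: float = 0.1, b1: float = 0.9,
                        b2: float = 0.999, eps: float = 1e-8) -> State:
    i, (x, m, v) = state
    g = clip(g, clip_norm)  # clip first, then take a standard Adam step
    new_x, new_m, new_v = {}, {}, {}
    for k in x:
        new_m[k] = (1 - b1) * g[k] + b1 * m[k]
        new_v[k] = (1 - b2) * g[k] ** 2 + b2 * v[k]
        m_hat = new_m[k] / (1 - b1 ** (i + 1))
        v_hat = new_v[k] / (1 - b2 ** (i + 1))
        new_x[k] = x[k] - step_size * m_hat / (math.sqrt(v_hat) + eps)
    return i + 1, (new_x, new_m, new_v)
```

As described above, `clip_norm` bounds each gradient entry individually. It does not rescale the whole gradient by its overall norm. A huge outlier such as 1e6 therefore enters the moment estimates as exactly `clip_norm`.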

Momentum

class Momentum(*args, **kwargs)
Wrapper class for the JAX optimizer: momentum()

get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
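Below is a plain-Python sketch of classical (heavy-ball) momentum. It assumes the usual formulation: the velocity decays by a factor `mass` and accumulates each new gradient, and the parameters then move along the velocity. `momentum_init`, `momentum_update` and the default values are hypothetical.

```python
from typing import Dict, Tuple

# Hypothetical plain-Python sketch of heavy-ball momentum; not the JAX implementation.
Params = Dict[str, float]
MomentumState = Tuple[Params, Params]  # (params, velocity)


def momentum_init(params: Params) -> MomentumState:
    return dict(params), {k: 0.0 for k in params}


def momentum_update(i: int, g: Params, state: MomentumState,
                    step_size: float = 0.1, mass: float = 0.9) -> MomentumState:
    # i is unused here; it is kept so the signature matches the other sketches.
    x, velocity = state
    new_velocity = {k: mass * velocity[k] + g[k] for k in x}  # decay, then add gradient
    new_x = {k: x[k] - step_size * new_velocity[k] for k in x}
    return new_x, new_velocity
```

With a constant gradient of 1.0, successive steps grow from 0.1 to 0.19. They keep growing toward `step_size / (1 - mass) = 1.0`, which is how momentum speeds up progress along consistent directions.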

RMSProp

class RMSProp(*args, **kwargs)
Wrapper class for the JAX optimizer: rmsprop()

get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
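RMSProp divides each step by the root of an exponential moving average of squared gradients. The sketch below assumes a common formulation, with decay factor `gamma` and a small `eps` inside the square root. All names and defaults are illustrative assumptions.

```python
import math
from typing import Dict, Tuple

# Hypothetical plain-Python sketch of RMSProp; not the JAX implementation.
Params = Dict[str, float]
RMSPropState = Tuple[Params, Params]  # (params, moving average of squared gradients)


def rmsprop_init(params: Params) -> RMSPropState:
    return dict(params), {k: 0.0 for k in params}


def rmsprop_update(i: int, g: Params, state: RMSPropState, step_size: float = 0.1,
                   gamma: float = 0.9, eps: float = 1e-8) -> RMSPropState:
    # i is unused here; it is kept so the signature matches the other sketches.
    x, avg_sq = state
    new_avg = {k: gamma * avg_sq[k] + (1 - gamma) * g[k] ** 2 for k in x}
    # Normalize by the root of the moving average; eps keeps the root away from zero.
    new_x = {k: x[k] - step_size * g[k] / math.sqrt(new_avg[k] + eps) for k in x}
    return new_x, new_avg
```

As an example, take the first step from a zero state with a gradient of 2.0. The moving average becomes 0.4, and the step is 0.2 / √0.4 ≈ 0.316.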

RMSPropMomentum

class RMSPropMomentum(*args, **kwargs)
Wrapper class for the JAX optimizer: rmsprop_momentum()

get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
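One common way to combine RMSProp with momentum works in two stages. RMS-normalized steps accumulate in a momentum buffer, and the parameters are then moved by that buffer. The plain-Python sketch below assumes that formulation. The function names and the defaults for `gamma`, `eps` and `momentum` are hypothetical, and `rmsprop_momentum()` in JAX may differ in details.

```python
import math
from typing import Dict, Tuple

# Hypothetical plain-Python sketch of RMSProp with momentum; not the JAX implementation.
Params = Dict[str, float]
# (params, moving average of squared gradients, momentum buffer)
State = Tuple[Params, Params, Params]


def rmsprop_momentum_init(params: Params) -> State:
    zeros = {k: 0.0 for k in params}
    return dict(params), dict(zeros), dict(zeros)


def rmsprop_momentum_update(i: int, g: Params, state: State, step_size: float = 0.1,
                            gamma: float = 0.9, eps: float = 1e-8,
                            momentum: float = 0.9) -> State:
    # i is unused here; it is kept so the signature matches the other sketches.
    x, avg_sq, mom = state
    new_avg = {k: gamma * avg_sq[k] + (1 - gamma) * g[k] ** 2 for k in x}
    # The buffer decays by `momentum` and accumulates the RMS-normalized step.
    new_mom = {k: momentum * mom[k] + step_size * g[k] / math.sqrt(new_avg[k] + eps)
               for k in x}
    # The parameters move by the whole buffer.
    new_x = {k: x[k] - new_mom[k] for k in x}
    return new_x, new_avg, new_mom
```

On the first step this matches plain RMSProp. On later steps, the buffer carries part of each earlier step forward.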

SGD

class SGD(*args, **kwargs)
Wrapper class for the JAX optimizer: sgd()

get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
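Plain SGD is the simplest update. The sketch below also shows one plausible use of the step index `i` carried in the state: driving a step-size schedule. Letting `step_size` be either a constant or a function of `i` is an assumption made for illustration. `as_schedule` and `sgd_update` are hypothetical helpers.

```python
from typing import Callable, Dict, Union

# Hypothetical plain-Python sketch of SGD with an optional schedule; not the JAX API.
Params = Dict[str, float]
StepSize = Union[float, Callable[[int], float]]


def as_schedule(step_size: StepSize) -> Callable[[int], float]:
    # Treat a plain number as a constant schedule; pass callables through unchanged.
    if callable(step_size):
        return step_size
    return lambda i: step_size


def sgd_update(i: int, g: Params, x: Params, step_size: StepSize = 0.1) -> Params:
    lr = as_schedule(step_size)(i)  # step size for step i
    return {k: x[k] - lr * g[k] for k in x}
```

For example, `sgd_update(3, {"x": 2.0}, {"x": 1.0}, step_size=lambda i: 1.0 / (i + 1))` uses a step size of 0.25 and returns `{"x": 0.5}`.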

SM3

class SM3(*args, **kwargs)
Wrapper class for the JAX optimizer: sm3()

get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
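SM3 comes from "Memory-Efficient Adaptive Optimization" by Anil et al. For an n × m matrix it saves memory by keeping one accumulator per row and one per column, instead of one per entry: n + m values rather than n · m. The 2D-only plain-Python sketch below shows that idea. It omits momentum and arbitrary-rank tensors, which JAX's `sm3()` may handle. The names and the default `step_size` are hypothetical.

```python
import math
from typing import List, Tuple

# Hypothetical 2D-only sketch of the SM3 idea; not the JAX implementation.
Matrix = List[List[float]]
# (params, row accumulators, column accumulators)
SM3State = Tuple[Matrix, List[float], List[float]]


def sm3_init(x: Matrix) -> SM3State:
    n_rows, n_cols = len(x), len(x[0])
    return [row[:] for row in x], [0.0] * n_rows, [0.0] * n_cols


def sm3_update(i: int, g: Matrix, state: SM3State,
               step_size: float = 0.1) -> SM3State:
    # i is unused here; it is kept so the signature matches the other sketches.
    x, row_acc, col_acc = state
    n_rows, n_cols = len(x), len(x[0])
    new_x = [row[:] for row in x]
    new_row = [0.0] * n_rows
    new_col = [0.0] * n_cols
    for r in range(n_rows):
        for c in range(n_cols):
            # Estimate this entry's accumulator from its row and column accumulators.
            acc = min(row_acc[r], col_acc[c]) + g[r][c] ** 2
            if acc > 0:
                new_x[r][c] -= step_size * g[r][c] / math.sqrt(acc)
            # Each row/column accumulator keeps the max over the entries it covers.
            new_row[r] = max(new_row[r], acc)
            new_col[c] = max(new_col[c], acc)
    return new_x, new_row, new_col
```

Each entry's accumulator is estimated as the minimum of its row and column accumulators plus the new squared gradient. The row and column accumulators are then refreshed with the maximum over the entries they cover.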