jax.experimental.optimizers module
Optimizers for use with JAX.
This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or arbitrarily-nested tuple/list/dicts of ndarrays.
An optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions, where the component functions have these signatures:
init_fun(params)
Args:
params: pytree representing the initial parameters.
Returns:
A pytree representing the initial optimizer state, which includes the
initial parameters and may also include auxiliary values like initial
momentum. The optimizer state pytree structure generally differs from that
of `params`.
update_fun(step, grads, opt_state)
Args:
step: integer representing the step index.
grads: a pytree with the same structure as `get_params(opt_state)`
representing the gradients to be used in updating the optimizer state.
opt_state: a pytree representing the optimizer state to be updated.
Returns:
A pytree with the same structure as the `opt_state` argument representing
the updated optimizer state.
get_params(opt_state)
Args:
opt_state: pytree representing an optimizer state.
Returns:
A pytree representing the parameters extracted from `opt_state`, such that
the invariant `params == get_params(init_fun(params))` holds true.
Notice that an optimizer implementation has a lot of flexibility in the form of opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to the JAX transforms defined in api.py) and it has to be consumable by update_fun and get_params.
Example Usage:

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

def step(i, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(i, grads, opt_state)
  return value, opt_state

for i in range(num_steps):
  value, opt_state = step(i, opt_state)
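Because opt_state is a pytree of JaxTypes, the step function above can be staged out with jax.jit unchanged. A minimal sketch, where loss_fn, num_steps, and the names from the example are placeholders defined elsewhere:

import jax

@jax.jit
def jitted_step(i, opt_state):
  # Same body as the step above, now compiled end to end.
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(i, grads, opt_state)
  return value, opt_state

for i in range(num_steps):
  value, opt_state = jitted_step(i, opt_state)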
class jax.experimental.optimizers.JoinPoint(subtree)[source]
Bases: object
Marks the boundary between two joined (nested) pytrees.
class jax.experimental.optimizers.Optimizer(init_fn, update_fn, params_fn)[source]
Bases: tuple
property init_fn
Alias for field number 0
property params_fn
Alias for field number 2
property update_fn
Alias for field number 1
class jax.experimental.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)
Bases: tuple
property packed_state
Alias for field number 0
property subtree_defs
Alias for field number 2
property tree_def
Alias for field number 1
jax.experimental.optimizers.adagrad(step_size, momentum=0.9)[source]
Construct optimizer triple for Adagrad.
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization: http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
- Parameters
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
momentum – optional, a positive scalar value for momentum
- Returns
An (init_fun, update_fun, get_params) triple.
jax.experimental.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]
Construct optimizer triple for Adam.
- Parameters
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).
b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).
eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).
- Returns
An (init_fun, update_fun, get_params) triple.
jax.experimental.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]
Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).
- Parameters
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).
b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).
eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).
- Returns
An (init_fun, update_fun, get_params) triple.
jax.experimental.optimizers.clip_grads(grad_tree, max_norm)[source]
Clip gradients stored as a pytree of arrays to maximum norm max_norm.
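In a training loop, clipping is typically applied to the gradient pytree just before the optimizer update. A small sketch, where loss_fn, i, opt_state, opt_update, and get_params stand in for the names from the usage example above and max_norm=1.0 is an arbitrary choice:

import jax
from jax.experimental import optimizers

grads = jax.grad(loss_fn)(get_params(opt_state))
# Rescale the whole gradient pytree so its global l2 norm is at most 1.0.
grads = optimizers.clip_grads(grads, max_norm=1.0)
opt_state = opt_update(i, grads, opt_state)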
jax.experimental.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)[source]
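This returns a step size schedule: a callable mapping the step index to a positive scalar. Any such schedule can be passed as the step_size argument of the optimizer constructors in this module. A brief sketch with illustrative decay values:

from jax.experimental import optimizers

# Hypothetical schedule: with staircase=True the 1e-2 base step size is held
# constant within each 1000-step window and decays at each multiple of 1000.
schedule = optimizers.inverse_time_decay(1e-2, decay_steps=1000, decay_rate=0.5, staircase=True)
opt_init, opt_update, get_params = optimizers.adam(schedule)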
jax.experimental.optimizers.l2_norm(tree)[source]
Compute the l2 norm of a pytree of arrays. Useful for weight decay.
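For example, an l2 penalty can be added to a user-defined loss before differentiating. A sketch, where loss_fn and the 1e-4 coefficient are placeholders and squaring the norm is one common convention:

from jax.experimental import optimizers

def loss_with_decay(params, weight_decay=1e-4):
  # l2_norm returns the global l2 norm over all array leaves of the params pytree.
  return loss_fn(params) + weight_decay * optimizers.l2_norm(params) ** 2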
jax.experimental.optimizers.momentum(step_size, mass)[source]
Construct optimizer triple for SGD with momentum.
jax.experimental.optimizers.nesterov(step_size, mass)[source]
Construct optimizer triple for SGD with Nesterov momentum.
jax.experimental.optimizers.optimizer(opt_maker)[source]
Decorator to make an optimizer defined for arrays generalize to containers.
With this decorator, you can write init, update, and get_params functions that each operate only on single arrays, and convert them to corresponding functions that operate on pytrees of parameters. See the optimizers defined in optimizers.py for examples.
- Parameters
opt_maker (Callable[..., Tuple[Callable[[Any], Any], Callable[[int, Any, Any], Any], Callable[[Any], Any]]]) – a function that returns an (init_fun, update_fun, get_params) triple of functions that might only work with ndarrays, as per

init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray
- Return type
- Returns
An (init_fun, update_fun, get_params) triple of functions that work on arbitrary pytrees, as per

init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: OptimizerState -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray

The OptimizerState pytree type used by the returned functions is isomorphic to ParameterPytree (OptStatePytree ndarray), but may store the state instead as e.g. a partially-flattened data structure for performance.
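As a concrete sketch, this is roughly how the module's own sgd optimizer is written: init, update, and get_params each see a single ndarray, and the decorator lifts them to arbitrary parameter pytrees. Here my_sgd is an illustrative name, and make_schedule is the module helper that turns a scalar or callable step size into a schedule:

from jax.experimental import optimizers

@optimizers.optimizer
def my_sgd(step_size):
  step_size = optimizers.make_schedule(step_size)
  def init(x0):
    # The per-array optimizer state is just the parameter array itself.
    return x0
  def update(i, g, x):
    # Plain gradient step using the (possibly scheduled) step size.
    return x - step_size(i) * g
  def get_params(x):
    return x
  return init, update, get_params

# The returned triple now works on any pytree of ndarrays.
opt_init, opt_update, get_params = my_sgd(1e-3)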
jax.experimental.optimizers.pack_optimizer_state(marked_pytree)[source]
Converts a marked pytree to an OptimizerState.
The inverse of unpack_optimizer_state. Converts a marked pytree with the leaves of the outer pytree represented as JoinPoints back into an OptimizerState. This function is intended to be useful when deserializing optimizer states.
- Parameters
marked_pytree – A pytree containing JoinPoint leaves that hold more pytrees.
- Returns
An equivalent OptimizerState to the input argument.
jax.experimental.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)[source]
jax.experimental.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)[source]
Construct optimizer triple for RMSProp.
- Parameters
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
gamma – Decay parameter.
eps – Epsilon parameter.
- Returns
An (init_fun, update_fun, get_params) triple.
jax.experimental.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)[source]
Construct optimizer triple for RMSProp with momentum.
This optimizer is separate from the rmsprop optimizer because it needs to keep track of additional parameters.
- Parameters
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
gamma – Decay parameter.
eps – Epsilon parameter.
momentum – Momentum parameter.
- Returns
An (init_fun, update_fun, get_params) triple.
jax.experimental.optimizers.sgd(step_size)[source]
Construct optimizer triple for stochastic gradient descent.
- Parameters
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- Returns
An (init_fun, update_fun, get_params) triple.
jax.experimental.optimizers.sm3(step_size, momentum=0.9)[source]
Construct optimizer triple for SM3.
Memory-Efficient Adaptive Optimization for Large-Scale Learning. https://arxiv.org/abs/1901.11150
- Parameters
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
momentum – optional, a positive scalar value for momentum
- Returns
An (init_fun, update_fun, get_params) triple.
jax.experimental.optimizers.unpack_optimizer_state(opt_state)[source]
Converts an OptimizerState to a marked pytree.
Converts an OptimizerState to a marked pytree with the leaves of the outer pytree represented as JoinPoints to avoid losing information. This function is intended to be useful when serializing optimizer states.
- Parameters
opt_state – An OptimizerState
- Returns
A pytree with JoinPoint leaves that contain a second level of pytrees.
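A round-trip serialization sketch using pickle, assuming the array leaves are picklable; the params dict and the Adam hyperparameter below are only illustrative:

import pickle

import jax.numpy as jnp
from jax.experimental import optimizers

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init(params)

# Serialize: first convert the OptimizerState into a marked pytree whose outer
# leaves are JoinPoints, which is an ordinary pytree safe to write out.
marked = optimizers.unpack_optimizer_state(opt_state)
blob = pickle.dumps(marked)

# Deserialize: rebuild an equivalent OptimizerState from the marked pytree.
restored = optimizers.pack_optimizer_state(pickle.loads(blob))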