jax.nn.initializers package

Common neural network layer initializers, consistent with definitions used in Keras and Sonnet.
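The initializers listed on this page come in two kinds. zeros and ones are initializers themselves and are called directly with a PRNG key and a shape; the key is accepted only so that every initializer shares one signature. uniform, normal and the variance-scaling family are factories: calling them with their configuration returns an initializer with that same (key, shape[, dtype]) signature. The following is a minimal sketch of both calling styles. It assumes a JAX version whose jax.nn.initializers matches the signatures on this page; the key, shapes, scale and stddev values are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)
shape = (3, 4)

# zeros and ones are initializers: pass a PRNG key and a shape directly.
# The key is ignored, but it keeps the signature uniform across initializers.
w_zeros = initializers.zeros(key, shape)
w_ones = initializers.ones(key, shape, jnp.float32)  # optional dtype argument

# uniform and normal are factories: calling them returns an initializer
# that is then called with (key, shape[, dtype]).
init_u = initializers.uniform(scale=0.1)   # samples scaled into [0, scale)
w_u = init_u(key, shape)

init_n = initializers.normal(stddev=0.02)  # zero-mean normal with this stddev
w_n = init_n(key, shape)

print(w_zeros.shape, w_ones.dtype, float(w_u.max()), float(w_n.std()))
```

Keeping a single (key, shape[, dtype]) signature lets layer code accept any of these as an interchangeable init argument without knowing whether it is random.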

Initializers

zeros(key, shape[, dtype])

ones(key, shape[, dtype])

uniform([scale, dtype])

normal([stddev, dtype])

variance_scaling(scale, mode, distribution[, in_axis, out_axis, dtype])

glorot_uniform([in_axis, out_axis, dtype])

glorot_normal([in_axis, out_axis, dtype])

lecun_uniform([in_axis, out_axis, dtype])

lecun_normal([in_axis, out_axis, dtype])

he_uniform([in_axis, out_axis, dtype])

he_normal([in_axis, out_axis, dtype])
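The six named initializers above can be read as special cases of variance_scaling: the weights get variance scale / n, where n is fan_in, fan_out, or their average depending on mode. Glorot uses scale 1 with fan_avg, LeCun uses scale 1 with fan_in, and He uses scale 2 with fan_in. The _uniform variants sample from [-sqrt(3 * scale / n), sqrt(3 * scale / n)], and the _normal variants sample from a normal truncated at two standard deviations and rescaled so that the variance is still scale / n. The sketch below checks this mapping at runtime rather than asserting it. For each name, it compares the output against variance_scaling with the same key, then compares the empirical variance with scale / n. It also shows how in_axis and out_axis locate the fans in a convolution kernel. The key, the shapes and the HWIO/OIHW layouts are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(42)

# Dense kernel laid out as (fan_in, fan_out); the default in_axis=-2 and
# out_axis=-1 read the fans from these two axes.
fan_in, fan_out = 512, 256
shape = (fan_in, fan_out)
fan_avg = (fan_in + fan_out) / 2

# name -> (named initializer, (scale, mode, distribution), denominator n)
families = {
    "glorot_uniform": (initializers.glorot_uniform(), (1.0, "fan_avg", "uniform"), fan_avg),
    "glorot_normal": (initializers.glorot_normal(), (1.0, "fan_avg", "truncated_normal"), fan_avg),
    "lecun_uniform": (initializers.lecun_uniform(), (1.0, "fan_in", "uniform"), fan_in),
    "lecun_normal": (initializers.lecun_normal(), (1.0, "fan_in", "truncated_normal"), fan_in),
    "he_uniform": (initializers.he_uniform(), (2.0, "fan_in", "uniform"), fan_in),
    "he_normal": (initializers.he_normal(), (2.0, "fan_in", "truncated_normal"), fan_in),
}

results = {}
for name, (init, (scale, mode, dist), n) in families.items():
    w = init(key, shape)
    ref = initializers.variance_scaling(scale, mode, dist)(key, shape)
    same = bool(jnp.array_equal(w, ref))  # same key and parameters -> same sample?
    target = scale / n                    # intended Var(w) = scale / n
    results[name] = (same, float(jnp.var(w)), target)
    print(f"{name:15s} same={same} var={results[name][1]:.3e} target={target:.3e}")

# Convolution kernels: every axis other than in_axis and out_axis counts as the
# receptive field, so fan_in = 3 * 3 * 16 = 144 for both layouts below.
hwio = initializers.he_normal()(key, (3, 3, 16, 32))                       # HWIO: defaults fit
oihw = initializers.he_normal(in_axis=1, out_axis=0)(key, (32, 16, 3, 3))  # OIHW: name the axes
conv_target = 2.0 / (3 * 3 * 16)
print(float(jnp.var(hwio)), float(jnp.var(oihw)), conv_target)
```

If a given JAX release wires a named initializer differently, the same= column reports it instead of hiding the difference; the variance comparison is statistical, so it only holds approximately.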