Warning
This page was created from a pull request.
jax.lax.stop_gradient
jax.lax.stop_gradient(x)
Stops gradient computation.
Operationally, stop_gradient is the identity function: it returns its argument x unchanged. However, stop_gradient prevents the flow of gradients during forward- or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them. For example:
>>> import jax
>>> jax.grad(lambda x: x**2)(3.)
array(6., dtype=float32)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
array(0., dtype=float32)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
array(2., dtype=float32)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
array(0., dtype=float32)
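The description says stop_gradient also blocks forward-mode differentiation, but the examples above use only jax.grad (reverse mode). The following is a minimal sketch, assuming a recent JAX install, that checks forward mode with jax.jvp. It also shows that stop_gradient cuts only the path that passes through it, so other uses of x still contribute gradients. The function names f, f_stopped, and f_partial are hypothetical and exist only for illustration.

```python
import jax

def f(x):
    # Ordinary function: d/dx x**2 = 2x
    return x ** 2

def f_stopped(x):
    # stop_gradient returns x unchanged, so the primal value equals f(x),
    # but no tangent or cotangent flows back through it
    return jax.lax.stop_gradient(x) ** 2

def f_partial(x):
    # Only the second factor carries gradient; the first is treated as a constant
    return jax.lax.stop_gradient(x) * x

x = 3.0

# Forward mode: jax.jvp(fun, primals, tangents) returns (primal_out, tangent_out)
primal, tangent = jax.jvp(f, (x,), (1.0,))             # 9.0, 6.0
primal_s, tangent_s = jax.jvp(f_stopped, (x,), (1.0,))  # 9.0, 0.0

# Reverse mode on the partially stopped function:
# d/dx [c * x] with c = x held constant is c = 3.0, not 2x = 6.0
g_partial = jax.grad(f_partial)(x)

print(primal, tangent, primal_s, tangent_s, g_partial)
```

The partial case is where stop_gradient is typically useful in practice: the forward value is unchanged, while one branch of the computation is excluded from differentiation.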