jax.lax.stop_gradient
jax.lax.stop_gradient(x)
Stops gradient computation.
Operationally, stop_gradient is the identity function: it returns its argument x unchanged. However, stop_gradient prevents gradients from flowing through it during forward- or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them. For example:
>>> import jax
>>> jax.grad(lambda x: x**2)(3.)
array(6., dtype=float32)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
array(0., dtype=float32)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
array(2., dtype=float32)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
array(0., dtype=float32)
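The examples above only exercise reverse mode (jax.grad) with the whole input stopped. The sketch below, which assumes a recent JAX release, also checks forward mode through jax.jvp and shows what happens when stop_gradient wraps only one factor of an expression. The function f is hypothetical and exists only for illustration: the forward value is unchanged, while the stopped factor is treated as a constant by both differentiation modes.

```python
import jax


def f(x):
    # Hypothetical example function: the second factor is wrapped in
    # stop_gradient, so differentiation treats it as a constant.
    return x * jax.lax.stop_gradient(x)


x = 3.0

# Forward pass: stop_gradient is the identity, so f(3.) == 3. * 3. == 9.
value = f(x)

# Reverse mode: d/dx [x * c] with c = stop_gradient(x) is c == 3.,
# not the 2 * x == 6. obtained without stop_gradient.
reverse = jax.grad(f)(x)

# Forward mode: the tangent through stop_gradient is zero, so the JVP
# only picks up the contribution of the unstopped factor (1. * 3. == 3.).
primal_out, tangent_out = jax.jvp(f, (x,), (1.0,))

# A fully stopped input yields a zero tangent under forward mode too.
_, zero_tangent = jax.jvp(lambda y: jax.lax.stop_gradient(y) ** 2, (x,), (1.0,))

print(value, reverse, primal_out, tangent_out, zero_tangent)
```

Because both modes see a zero derivative for the stopped factor, wrapping only part of an expression in stop_gradient is a way to keep a quantity in the forward computation while excluding it from differentiation.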