jax.lax.fori_loop
jax.lax.fori_loop(lower, upper, body_fun, init_val)
Loop from lower to upper by reduction to jax.lax.while_loop().
The type signature in brief is:
    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
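As a concrete instance of that signature, the sketch below passes two Python integers as the Int bounds, a body of type (int, a) -> a, and a float32 scalar as the carry of type a. It assumes a recent JAX install; the names body_fun and result are illustrative only. The loop sums the indices 0 through 9.

```python
import jax
import jax.numpy as jnp


def body_fun(i, total):
    # i is the traced integer loop index; cast it so the carry stays float32.
    return total + i.astype(total.dtype)


# lower=0 and upper=10 play the role of the two Int arguments;
# jnp.float32(0.0) is the initial carry of type a.
result = jax.lax.fori_loop(0, 10, body_fun, jnp.float32(0.0))
# result is 45.0 (0 + 1 + ... + 9) and still has dtype float32.
```

The explicit cast keeps the carry's dtype identical on every iteration, which is the constraint this page describes for the loop-carried value.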
The semantics of fori_loop are given by this Python implementation:

    def fori_loop(lower, upper, body_fun, init_val):
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val

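To make the reduction to jax.lax.while_loop() concrete, the Python implementation above can be rewritten so that the loop index travels in the while_loop carry next to the user's value. This is a minimal sketch under that assumption, not JAX's own implementation, which may treat bounds, dtypes, and autodiff differently. fori_loop_via_while and square_sum_body are hypothetical names, and the sketch assumes lower and upper fit in int32.

```python
import jax
import jax.numpy as jnp


def fori_loop_via_while(lower, upper, body_fun, init_val):
    # Hypothetical sketch: express fori_loop as a while_loop over the pair (i, val).
    def cond_fun(carry):
        i, _ = carry
        # Keep iterating while the index is below the exclusive upper bound.
        return i < upper

    def while_body(carry):
        i, val = carry
        # Apply the user's body at index i, then advance the index by one.
        return i + 1, body_fun(i, val)

    # Start the index at `lower` as an int32 array so its dtype is fixed.
    init_carry = (jnp.asarray(lower, dtype=jnp.int32), init_val)
    _, final_val = jax.lax.while_loop(cond_fun, while_body, init_carry)
    return final_val


def square_sum_body(i, val):
    # Add i * i to a float32 accumulator.
    return val + (i * i).astype(jnp.float32)


ours = fori_loop_via_while(1, 4, square_sum_body, jnp.float32(0.0))
ref = jax.lax.fori_loop(1, 4, square_sum_body, jnp.float32(0.0))
# Both values are 14.0 (1 + 4 + 9).
```

Because the index lives in the while_loop carry, it is subject to the same fixed-type rule as val; converting lower to an int32 array up front gives the index a definite dtype from the first iteration.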
Unlike that Python version, fori_loop is implemented in terms of a call to jax.lax.while_loop(). See the jax.lax.while_loop() documentation for more information.
Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations; it is not enough for successive values to be compatible under NumPy rank/shape broadcasting or dtype promotion rules, for example. In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple, list, or dict container with a fixed structure whose leaves are arrays of fixed shape and dtype).
Parameters
lower – an integer representing the loop index lower bound (inclusive).
upper – an integer representing the loop index upper bound (exclusive).
body_fun – function of type (int, a) -> a.
init_val – initial loop carry value of type a.
Returns
Loop value from the final iteration, of type a.
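The fixed-shape rule applies leaf by leaf when the carry a is a container. In the sketch below the carry is a tuple of a float32 vector and an int32 counter, and each leaf keeps its shape and dtype on every iteration. It assumes a recent JAX release; body_fun and growing_body are hypothetical names. A body that doubles the vector's length changes the carry's type, so the loop is rejected while it is traced; recent JAX versions report this as a TypeError.

```python
import jax
import jax.numpy as jnp


def body_fun(i, carry):
    # Unpack the tuple carry; both leaves must keep their shape and dtype.
    vec, steps = carry
    vec = vec * 2.0      # stays shape (3,), dtype float32
    steps = steps + 1    # stays a scalar int32
    return vec, steps


init_val = (jnp.ones(3, dtype=jnp.float32), jnp.int32(0))
vec, steps = jax.lax.fori_loop(0, 4, body_fun, init_val)
# vec holds 16.0 in every entry (1 * 2**4) and steps is 4.


def growing_body(i, val):
    # Returns an array twice as long as its input, so the carry shape changes.
    return jnp.concatenate([val, val])


try:
    jax.lax.fori_loop(0, 4, growing_body, jnp.ones(3, dtype=jnp.float32))
    rejected = False
except TypeError:
    rejected = True
```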