Warning
This page was created from a pull request.
jax.lax.cond
jax.lax.cond(*args, **kwargs)

Conditionally apply true_fun or false_fun.

cond() has equivalent semantics to this Python implementation:

    def cond(pred, true_fun, false_fun, operand):
      if pred:
        return true_fun(operand)
      else:
        return false_fun(operand)
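The sketch below, assuming a recent JAX install, checks that lax.cond agrees with the pure-Python reference for concrete inputs and also works inside jax.jit. Under jit, pred becomes a traced value, so a plain Python "if pred:" would raise an error, while lax.cond stages the branch choice into the compiled computation. The names py_cond, double, negate and jitted are illustrative only and are not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax


def py_cond(pred, true_fun, false_fun, operand):
    # Pure-Python reference semantics, copied from the docstring above.
    if pred:
        return true_fun(operand)
    else:
        return false_fun(operand)


def double(x):
    return x * 2.0


def negate(x):
    return -x


@jax.jit
def jitted(pred, x):
    # Inside jit, `pred` is a tracer: branching on it with a Python `if`
    # would fail, so the choice is expressed with lax.cond instead.
    return lax.cond(pred, double, negate, x)


x = jnp.float32(3.0)
jit_true = jitted(True, x)     # runs double(x)
jit_false = jitted(False, x)   # runs negate(x)
ref_true = py_cond(True, double, negate, x)
```

Both paths produce the same value for the same predicate; only the lax.cond version can be traced.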
pred must be a scalar type.

Functions true_fun/false_fun may not need to refer to an operand to compute their result, but one must still be provided to the cond call and be accepted by both branch functions, e.g.:

    jax.lax.cond(
        get_predicate_value(),
        lambda _: 23,
        lambda _: 42,
        operand=None)
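A runnable version of the example above follows. get_predicate_value is not defined by JAX, so a hypothetical stand-in is supplied here; the operand None is passed positionally, assuming the (pred, true_fun, false_fun, operand) argument order documented on this page.

```python
import jax.numpy as jnp
from jax import lax


def get_predicate_value():
    # Hypothetical helper standing in for the one named in the docstring:
    # returns a boolean scalar array.
    return jnp.array(1.0) > 0.5


# Neither branch uses its argument, but each must still accept one,
# so both are written as one-argument lambdas that ignore it.
result = lax.cond(get_predicate_value(), lambda _: 23, lambda _: 42, None)
```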
Parameters
- pred – Boolean scalar type, indicating which branch function to apply.
- true_fun – Function (A -> B), to be applied if pred is True.
- false_fun – Function (A -> B), to be applied if pred is False.
- operand – Operand (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
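As a sketch of the A -> B contract with a pytree operand, assuming a recent JAX install: the operand A below is a dict, and both branches return a tuple B with the same structure, shapes and dtypes. The names state, scale_up and scale_down are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Operand A: a dict pytree holding a float32[3] array and a float32 scalar.
state = {"v": jnp.array([1.0, 2.0, 3.0]), "scale": jnp.float32(2.0)}


def scale_up(s):
    # Returns B = (float32[3], float32[]).
    return (s["v"] * s["scale"], jnp.sum(s["v"]))


def scale_down(s):
    # Must return the same pytree structure, shapes and dtypes as scale_up.
    return (s["v"] / s["scale"], jnp.float32(0.0))


out_true = lax.cond(True, scale_up, scale_down, state)
out_false = lax.cond(False, scale_up, scale_down, state)
```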
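Returns
Value (B) of either true_fun(operand) or false_fun(operand), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof. Both branch functions must return the same type B, i.e. the same pytree structure with matching shapes and dtypes.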
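Since the result type B has to be the same whichever branch is taken, lax.cond traces both functions and rejects outputs whose shapes or dtypes differ, even when the predicate is a concrete value. The sketch below assumes a recent JAX version in which this mismatch is reported as a TypeError; keep and truncate are hypothetical branch functions.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.0)  # float32[3]


def keep(v):
    return v          # float32[3]


def truncate(v):
    return v[:2]      # float32[2]: a different shape than keep(v)


try:
    lax.cond(True, keep, truncate, x)
    mismatch_rejected = False
except TypeError:
    # Both branches are traced, so the shape mismatch is caught even
    # though only true_fun would be selected at runtime.
    mismatch_rejected = True
```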