Warning
This page was created from a pull request.
jax.numpy.piecewise
jax.numpy.piecewise(x, condlist, funclist, *args, **kw)
Evaluate a piecewise-defined function.
LAX-backend implementation of piecewise(). Unlike np.piecewise, jax.numpy.piecewise() requires the functions in funclist to be traceable by JAX (that is, built from operations JAX can trace, such as jax.numpy functions), because it is implemented via jax.lax.switch(). See the jax.lax.switch() documentation for more information.
Original docstring below.
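A minimal sketch of the traceability requirement, assuming a recent JAX release: both entries of funclist are built only from jax.numpy operations, so jnp.piecewise can run inside a jax.jit-compiled function. The function name `ramp` is hypothetical and only used for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def ramp(x):
    # Both branches use only jax.numpy operations, so JAX can trace them
    # when jnp.piecewise dispatches through jax.lax.switch.
    return jnp.piecewise(
        x,
        [x < 0, x >= 0],
        [lambda t: jnp.zeros_like(t), lambda t: jnp.minimum(t, 1.0)],
    )


out = ramp(jnp.array([-1.0, 0.5, 2.0]))
# values: 0.0, 0.5, 1.0
```

A branch that converted its argument with float(t), or tested it with a Python `if`, would fail here, because the argument is an abstract tracer rather than a concrete number.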
Given a set of conditions and corresponding functions, evaluate each function on the input data wherever its condition is true.
- Parameters
x (ndarray or scalar) – The input domain.
condlist (list of bool arrays or bool scalars) – Each boolean array corresponds to a function in funclist. Wherever condlist[i] is True, funclist[i](x) is used as the output value.
funclist (list of callables, f(x, *args, **kw), or scalars) – Each function is evaluated over x wherever its corresponding condition is True. It should take a 1-d array as input and give a 1-d array or a scalar value as output. If a scalar is provided instead of a callable, a constant function (lambda x: scalar) is assumed.
args (tuple, optional) – Any further arguments given to piecewise are passed to the functions upon execution, i.e., if called as piecewise(..., ..., 1, 'a'), then each function is called as f(x, 1, 'a').
kw (dict, optional) – Keyword arguments used in calling piecewise are passed to the functions upon execution, i.e., if called as piecewise(..., ..., alpha=1), then each function is called as f(x, alpha=1).
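A short sketch of the argument forwarding described for args and kw, assuming jnp.piecewise passes extra positional and keyword arguments through to every callable as listed above. Numeric arguments are used because, under JAX, these values are traced like any other array input; the keyword name `offset` is hypothetical.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 1.0, 2.0])

# Extra positional argument: each callable is invoked as f(x, 3.0).
scaled = jnp.piecewise(
    x,
    [x < 0, x >= 0],
    [lambda t, s: -s * t, lambda t, s: s * t],
    3.0,
)

# Extra keyword argument: each callable is invoked as f(x, offset=10.0).
shifted = jnp.piecewise(
    x,
    [x < 0, x >= 0],
    [lambda t, offset: t - offset, lambda t, offset: t + offset],
    offset=10.0,
)
# scaled values:  6.0, 3.0, 3.0, 6.0
# shifted values: -12.0, -11.0, 11.0, 12.0
```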
- Returns
out – The output is the same shape and type as x and is found by calling the functions in funclist on the appropriate portions of x, as defined by the boolean arrays in condlist. Portions not covered by any condition have a default value of 0.
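A small sketch of the return-value rules stated above, assuming a recent JAX release: elements not covered by any condition fall back to 0, and the result keeps the shape and dtype of x (including an integer dtype).

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 1.0, 2.0])

# Only one condition is given; wherever it is False the output is 0.
out = jnp.piecewise(x, [x > 0.5], [lambda t: 10 * t])
# values: 0.0, 0.0, 10.0, 20.0

# Integer input: the result has the same integer dtype as xi.
xi = jnp.array([1, 2, 3])
outi = jnp.piecewise(xi, [xi > 1], [lambda t: t * 2])
# values: 0, 4, 6
```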
- Return type
Notes
This is similar to choose or select, except that functions are evaluated on elements of x that satisfy the corresponding condition from condlist.
The result is:
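$$
\mathrm{out} =
\begin{cases}
\texttt{funclist[0]}(x[\texttt{condlist[0]}]) \\
\texttt{funclist[1]}(x[\texttt{condlist[1]}]) \\
\quad\vdots \\
\texttt{funclist[n2]}(x[\texttt{condlist[n2]}])
\end{cases}
$$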
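The formula in the Notes can be read as a sequence of masked updates starting from the default value 0. The helper `piecewise_reference` below is hypothetical and not part of JAX; it assumes elementwise functions and one function per condition, and it lets a later condition overwrite an earlier one where both hold (mirroring NumPy's sequential assignment). It also evaluates each function on all of x, which is harmless for elementwise functions. The comparison with jnp.select illustrates the "similar to choose or select" remark; the two agree here because the conditions do not overlap.

```python
import jax.numpy as jnp


def piecewise_reference(x, condlist, funclist):
    """Hypothetical reference: masked updates following the Notes formula."""
    out = jnp.zeros_like(x)  # elements with no true condition keep 0
    for cond, f in zip(condlist, funclist):
        vals = f(x) if callable(f) else jnp.full_like(x, f)
        # Where cond holds, take f(x); later conditions overwrite earlier ones.
        out = jnp.where(cond, vals, out)
    return out


x = jnp.array([-2.0, -0.5, 0.5, 2.0])
condlist = [x < 0, (x >= 0) & (x < 1)]
funclist = [lambda t: -t, lambda t: t ** 2]

ref = piecewise_reference(x, condlist, funclist)
pw = jnp.piecewise(x, condlist, funclist)
# With non-overlapping conditions, select picks the same values.
sel = jnp.select(condlist, [f(x) for f in funclist], default=0)
# all three: 2.0, 0.5, 0.25, 0.0
```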
Examples
Define the signum function, which is -1 for x < 0 and +1 for x >= 0.
>>> x = np.linspace(-2.5, 2.5, 6)
>>> np.piecewise(x, [x < 0, x >= 0], [-1, 1])
array([-1., -1., -1.,  1.,  1.,  1.])
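The same example with jax.numpy, as a sketch assuming a recent JAX release. The scalar entries -1 and 1 act as constant functions, and the result takes the floating dtype of x (float32 by default in JAX).

```python
import jax.numpy as jnp

x = jnp.linspace(-2.5, 2.5, 6)
# Scalars in funclist are treated as constant functions.
sigma = jnp.piecewise(x, [x < 0, x >= 0], [-1, 1])
# values: -1.0, -1.0, -1.0, 1.0, 1.0, 1.0
```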
Define the absolute value, which is -x for x < 0 and x for x >= 0.
>>> np.piecewise(x, [x < 0, x >= 0], [lambda x: -x, lambda x: x])
array([2.5, 1.5, 0.5, 0.5, 1.5, 2.5])
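Because the branches are traceable, the absolute-value example can also be differentiated with jax.grad. This is a sketch assuming a recent JAX release; `piecewise_abs` is a hypothetical name. The gradient is that of whichever branch is selected for the input.

```python
import jax
import jax.numpy as jnp


def piecewise_abs(v):
    return jnp.piecewise(v, [v < 0, v >= 0], [lambda t: -t, lambda t: t])


# Differentiates through the branch selected for each scalar input.
d_abs = jax.grad(piecewise_abs)
neg_slope = d_abs(-1.5)  # value: -1.0
pos_slope = d_abs(2.0)   # value: 1.0
```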
Apply the same function to a scalar value.
>>> y = -2
>>> np.piecewise(y, [y < 0, y >= 0], [lambda x: -x, lambda x: x])
array(2)
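With jax.numpy, the scalar example returns a zero-dimensional array whose integer dtype follows the input. This is a sketch assuming a recent JAX release; the exact integer width (int32 or int64) depends on whether 64-bit mode is enabled.

```python
import jax.numpy as jnp

y = -2
# The conditions are plain Python booleans here; jnp.piecewise accepts them.
out = jnp.piecewise(y, [y < 0, y >= 0], [lambda t: -t, lambda t: t])
# out is a 0-d integer array holding 2
```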