Warning
This page was created from a pull request.
jax.lax.gather
jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes)
Gather operator.
Wraps XLA’s Gather operator.
The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer NumPy-style indexing (e.g., x[:, (1,4,7), ...]) rather than using gather directly.
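As a rough illustration of that advice, the sketch below selects columns 1, 4 and 7 of a small matrix twice: once with NumPy-style indexing and once with a direct lax.gather call. The array x, the index array cols, and the particular GatherDimensionNumbers values are assumptions chosen for this example. They are one way to express this selection, not necessarily what jax.numpy indexing produces internally.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical 3 x 8 operand with distinct values so selected columns are easy to check.
x = jnp.arange(24).reshape(3, 8)
cols = jnp.array([1, 4, 7])

# NumPy-style indexing: pick columns 1, 4 and 7 from every row.
via_indexing = x[:, cols]

# The same selection written directly with lax.gather.
dnums = lax.GatherDimensionNumbers(
    offset_dims=(0,),           # output dim 0 holds the slice's kept (row) dimension
    collapsed_slice_dims=(1,),  # operand dim 1 has slice size 1 and is dropped
    start_index_map=(1,),       # each index vector addresses operand dim 1 (columns)
)
via_gather = lax.gather(
    x,
    cols[:, None],              # shape (3, 1): the last dim holds each index vector
    dimension_numbers=dnums,
    slice_sizes=(3, 1),         # a full column of length 3, one column wide
)

print(via_gather.shape, bool((via_indexing == via_gather).all()))
```

The indexing form is shorter and spares you from working out the dimension numbers by hand, which is why it is usually the better choice.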
- Parameters
  - operand (Any) – an array from which slices should be taken
  - start_indices (Any) – the indices at which slices should be taken
  - dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.
  - slice_sizes (Sequence[int]) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).
- Return type
- Returns
An array containing the gather output.
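A minimal sketch of how the parameters combine to shape the output, assuming a small 1-D operand and hypothetical start positions. Each length-1 vector in the last dimension of start_indices gives the start of a window, slice_sizes=(2,) makes each window two elements long, and the leading (2, 2) shape of start_indices becomes the leading shape of the result. All indices are kept in bounds so the example does not depend on out-of-bounds handling.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical 1-D operand.
x = jnp.array([10, 11, 12, 13, 14, 15])

# A 2 x 2 batch of start positions; the trailing dim of size 1 is the index vector.
starts = jnp.array([[[0], [2]], [[3], [4]]])  # shape (2, 2, 1)

dnums = lax.GatherDimensionNumbers(
    offset_dims=(2,),          # the window dim (length 2) is placed last in the output
    collapsed_slice_dims=(),   # nothing is collapsed because the slice size is 2, not 1
    start_index_map=(0,),      # index vector component 0 addresses operand dim 0
)
out = lax.gather(x, starts, dimension_numbers=dnums, slice_sizes=(2,))

# Output shape = batch shape of start_indices (2, 2) followed by the kept slice dims (2,).
# Element-wise: out[a, b, c] == x[starts[a, b, 0] + c]
print(out.shape)
print(out.tolist())
```

Here len(slice_sizes) equals ndim(operand) (both 1), as the parameter description requires.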