take_along_axis

paddle.take_along_axis(arr: Tensor, indices: Tensor, axis: int, broadcast: bool = True, *, out: Tensor | None = None) -> Tensor

Takes values from the input tensor along the designated axis, at the positions given by the indices tensor.

Note

Alias Support: The parameter name input can be used as an alias for arr, and dim can be used as an alias for axis. For example, take_along_axis(input=tensor_arr, dim=1, ...) is equivalent to take_along_axis(arr=tensor_arr, axis=1, ...).

Parameters
  • arr (Tensor) – The input Tensor. Supported data types are bfloat16, float16, float32, float64, int32, int64, uint8. alias: input.

  • indices (Tensor) – Indices to take along each 1-D slice of arr. It must have the same number of dimensions as arr and must be broadcastable against arr. Supported data types are int32 and int64.

  • axis (int) – The axis to take 1d slices along. alias: dim.

  • broadcast (bool, optional) – Whether to broadcast indices against arr. Default: True.

  • out (Tensor, optional) – The output Tensor. If set, the output will be written to this Tensor.

Returns

Tensor, the indexed elements, with the same dtype as arr.

Examples

>>> import paddle

>>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = paddle.to_tensor([[0]])
>>> axis = 0
>>> result = paddle.take_along_axis(x, index, axis)
>>> print(result)
Tensor(shape=[1, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
[[1, 2, 3]])
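
To make the lookup rule concrete, the following is a minimal pure-Python sketch of the 2-D case on nested lists. The helper take_along_axis_2d is hypothetical and is not part of paddle; it assumes that only indices is broadcast (a size-1 dimension is reused for every position) and that arr is left as-is. It is meant as an illustration of the semantics described above, not as the library's implementation.

```python
def take_along_axis_2d(arr, indices, axis):
    """Hypothetical pure-Python model of the 2-D lookup rule (not a paddle API)."""
    n_rows, n_cols = len(arr), len(arr[0])
    i_rows, i_cols = len(indices), len(indices[0])

    def pick(i, j):
        # Broadcasting assumption: a size-1 dimension of indices
        # is reused for every position along that dimension.
        return indices[i if i_rows > 1 else 0][j if i_cols > 1 else 0]

    if axis == 0:
        if i_cols not in (1, n_cols):
            raise ValueError("indices cannot be broadcast against arr")
        # out[i][j] = arr[indices[i][j]][j]
        return [[arr[pick(i, j)][j] for j in range(n_cols)] for i in range(i_rows)]
    if axis == 1:
        if i_rows not in (1, n_rows):
            raise ValueError("indices cannot be broadcast against arr")
        # out[i][j] = arr[i][indices[i][j]]
        return [[arr[i][pick(i, j)] for j in range(i_cols)] for i in range(n_rows)]
    raise ValueError("this sketch only handles axis 0 or 1")


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Same inputs as the example above: one index row, broadcast across all columns.
along_rows = take_along_axis_2d(x, [[0]], axis=0)

# One pair of column indices per row, taken along axis 1.
along_cols = take_along_axis_2d(x, [[0, 2], [1, 0], [2, 1]], axis=1)

print(along_rows)
print(along_cols)
```

Along axis 0 the index selects a row for each column, and along axis 1 it selects a column for each row; the output takes its size along axis from indices and its remaining size from arr.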