index_select

paddle.index_select(x: Tensor, index: Tensor, axis: int = 0, name: str | None = None, *, out: Tensor | None = None) -> Tensor

Returns a new tensor that selects entries of the input tensor x along dimension axis, using the indices stored in the 1-D Tensor index. The returned tensor has the same number of dimensions as x: its axis-th dimension has a length equal to the length of index, and every other dimension has the same size as in x.
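As a minimal sketch of these semantics, the NumPy function below builds the result one slice at a time: for each entry of index it takes the matching slice of x along axis, then stacks the slices back along that same axis. The function name index_select_reference is hypothetical, and this is a reference model of the behavior described above, not Paddle's kernel.

```python
import numpy as np


def index_select_reference(x: np.ndarray, index: np.ndarray, axis: int = 0) -> np.ndarray:
    """Hypothetical NumPy reference model of index_select (not Paddle's implementation)."""
    index = np.asarray(index)
    if index.ndim != 1:
        raise ValueError("index must be a 1-D array")

    if index.size == 0:
        # No indices selected: the axis dimension has length 0, other dims unchanged.
        shape = list(x.shape)
        shape[axis] = 0
        return np.empty(shape, dtype=x.dtype)

    # One slice per entry of `index`; a scalar index to np.take drops that axis.
    slices = [np.take(x, int(i), axis=axis) for i in index]

    # Re-insert the axis by stacking: its length becomes len(index),
    # and the dtype of x is preserved.
    return np.stack(slices, axis=axis)


x = np.arange(24).reshape(2, 3, 4)
out = index_select_reference(x, np.array([2, 0]), axis=1)
print(out.shape)  # (2, 2, 4): dimension 1 now has len(index) == 2 entries
```

Indices may repeat (as in the examples below), in which case the same slice simply appears more than once in the output.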

Note

Alias and Order Support:
  1. The parameter name input can be used as an alias for x.
  2. The parameter name dim can be used as an alias for axis.
  3. This API also supports the PyTorch argument order (input, dim, index) for positional arguments, which is converted to the Paddle order (x, index, axis). For example, paddle.index_select(input=x, dim=1, index=idx) is equivalent to paddle.index_select(x=x, axis=1, index=idx), and paddle.index_select(x, 1, idx) is equivalent to paddle.index_select(x, idx, axis=1).
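The alias and order rules in this note can be illustrated with a small pure-Python sketch that resolves a call's arguments to the Paddle names (x, index, axis). The helper normalize_index_select_args is hypothetical and is not how Paddle implements the conversion; it assumes index is never passed as a plain Python int, so an int in the second positional slot can only be a dimension.

```python
from typing import Any, Dict


def normalize_index_select_args(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: resolve arguments to Paddle's (x, index, axis).

    Illustrates the alias/order rules only; not Paddle's implementation.
    """
    # Rules 1 and 2: `input` and `dim` are keyword aliases for `x` and `axis`.
    for alias, name in (("input", "x"), ("dim", "axis")):
        if alias in kwargs:
            if name in kwargs:
                raise TypeError(f"got both '{alias}' and '{name}'")
            kwargs[name] = kwargs.pop(alias)

    if len(args) > 3:
        raise TypeError("at most 3 positional arguments are accepted")

    # Rule 3: an int in the second positional slot signals the PyTorch
    # order (input, dim, index); otherwise assume Paddle order (x, index, axis).
    order = ("x", "index", "axis")
    if len(args) >= 2 and isinstance(args[1], int):
        order = ("x", "axis", "index")

    for name, value in zip(order, args):
        if name in kwargs:
            raise TypeError(f"got multiple values for argument '{name}'")
        kwargs[name] = value

    kwargs.setdefault("axis", 0)  # documented default for axis
    return kwargs


# PyTorch-style positional call resolves to Paddle's keyword form.
print(normalize_index_select_args("X", 1, "IDX"))
```

Dispatching on the type of the second positional argument is enough here only because index is documented as a Tensor; a real implementation may use different checks.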

Parameters
  • x (Tensor) – The input Tensor. Its data type can be one of float16, float32, float64, int32, int64, complex64 or complex128. alias: input.

  • index (Tensor) – A 1-D Tensor containing the indices to select. Its data type must be int32 or int64.

  • axis (int, optional) – The dimension along which to select. Default: 0. alias: dim.

  • name (str|None, optional) – For details, please refer to api_guide_Name. Generally, no setting is required. Default: None.
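The dtype and shape constraints in the parameter list can be summarized as a small validation sketch. The function check_index_select_inputs is hypothetical and works on NumPy arrays standing in for Paddle Tensors; Paddle's own checks, exception types and error messages may differ.

```python
import numpy as np

# Data types listed for x and index in the parameter descriptions above.
X_DTYPES = {np.dtype(n) for n in ("float16", "float32", "float64", "int32",
                                   "int64", "complex64", "complex128")}
INDEX_DTYPES = {np.dtype("int32"), np.dtype("int64")}


def check_index_select_inputs(x: np.ndarray, index: np.ndarray) -> None:
    """Hypothetical check of the documented constraints (not Paddle's code)."""
    if x.dtype not in X_DTYPES:
        raise TypeError(f"unsupported dtype for x: {x.dtype}")
    if index.dtype not in INDEX_DTYPES:
        raise TypeError(f"index must be int32 or int64, got {index.dtype}")
    if index.ndim != 1:
        raise ValueError(f"index must be 1-D, got {index.ndim}-D")


# Passes: float32 input with a 1-D int32 index.
check_index_select_inputs(np.zeros((3, 4), dtype=np.float32),
                          np.array([0, 1, 1], dtype=np.int32))
```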

Keyword Arguments

out (Tensor|None, optional) – The output tensor. Default: None.

Returns

Tensor, a Tensor with the same data type as x.

Examples

>>> import paddle

>>> x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
...                       [5.0, 6.0, 7.0, 8.0],
...                       [9.0, 10.0, 11.0, 12.0]])
>>> index = paddle.to_tensor([0, 1, 1], dtype='int32')
>>> out_z1 = paddle.index_select(x=x, index=index)
>>> print(out_z1.numpy())
[[1. 2. 3. 4.]
 [5. 6. 7. 8.]
 [5. 6. 7. 8.]]
>>> out_z2 = paddle.index_select(x=x, index=index, axis=1)
>>> print(out_z2.numpy())
[[ 1.  2.  2.]
 [ 5.  6.  6.]
 [ 9. 10. 10.]]
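For readers without Paddle installed, the two selections in the example can be reproduced with NumPy's np.take, which picks entries along a given axis in the same way for in-range indices. This is a cross-check sketch using NumPy arrays in place of Paddle Tensors; it should yield the same values as out_z1 and out_z2 above.

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0, 4.0],
              [5.0, 6.0, 7.0, 8.0],
              [9.0, 10.0, 11.0, 12.0]], dtype=np.float32)
index = np.array([0, 1, 1], dtype=np.int32)

out_z1 = np.take(x, index, axis=0)  # rows 0, 1, 1 (default axis in index_select)
out_z2 = np.take(x, index, axis=1)  # columns 0, 1, 1

print(out_z1)
print(out_z2)
```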