split

paddle.compat.split(tensor: Tensor, split_size_or_sections: int | Sequence[int], dim: int = 0) → tuple[Tensor, ...]

(PyTorch Compatible API) Split the input tensor into multiple sub-Tensors.

Parameters
  • tensor (Tensor) – An N-D Tensor. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.

  • split_size_or_sections (int|list|tuple) – If split_size_or_sections is an integer, tensor is split into chunks of that size along dim; the last chunk is smaller if the size of tensor along dim is not divisible by split_size_or_sections. If split_size_or_sections is a list or tuple, tensor is split into len(split_size_or_sections) chunks whose sizes along dim are given by its elements. Negative values are not allowed. For example, for a dim of size 9, [2, 3, -1] is not interpreted as [2, 3, 4]; it is rejected and an exception is raised.

  • dim (int|Tensor, optional) – The dimension along which to split. It can be an integer or a 0-D Tensor with shape [] and data type int32 or int64. If \(dim < 0\), the dimension to split along is \(rank(tensor) + dim\). Default is 0.
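
The chunk-size rules described in the parameters above can be sketched in plain Python, without Paddle. The helper name expected_chunk_shapes is hypothetical and not part of Paddle. The use of ValueError, the rejection of a zero split size, and the requirement that a list of sizes sums exactly to the size of dim are assumptions of this sketch (the last one carried over from torch.split), not documented behavior of paddle.compat.split.

```python
def expected_chunk_shapes(shape, split_size_or_sections, dim=0):
    """Hypothetical helper: predict the shapes a torch.split-style call returns."""
    shape = list(shape)
    rank = len(shape)
    if not -rank <= dim < rank:
        raise ValueError(f"dim {dim} is out of range for rank {rank}")
    # A negative dim counts from the end: rank + dim.
    axis = dim + rank if dim < 0 else dim
    length = shape[axis]

    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        if size <= 0:
            # Assumption of this sketch: non-positive sizes are rejected.
            raise ValueError("split size must be positive")
        # Equal chunks of `size`; the last chunk holds the remainder.
        sizes = [min(size, length - start) for start in range(0, length, size)]
    else:
        sizes = list(split_size_or_sections)
        if any(s < 0 for s in sizes):
            # -1 is not inferred; negative entries are rejected.
            raise ValueError("negative section sizes are not allowed")
        if sum(sizes) != length:
            # Assumption taken from torch.split: sizes must sum to the dim size.
            raise ValueError("section sizes must sum to the size of dim")

    return [shape[:axis] + [s] + shape[axis + 1:] for s in sizes]


print(expected_chunk_shapes([3, 8, 5], 3, dim=1))
print(expected_chunk_shapes([3, 8, 5], [1, 2, 5], dim=-2))
```

For a tensor of shape [3, 8, 5], the first call yields three shapes with sizes 3, 3 and 2 along dimension 1, and the second yields sizes 1, 2 and 5 along the same dimension, since dim=-2 resolves to 1 for a rank-3 tensor.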

Returns

tuple(Tensor), the tuple of segmented Tensors.

Note

This is a PyTorch-compatible API that follows the function signature and behavior of torch.split. To use Paddle's original split, see paddle.split.

Examples

>>> import paddle

>>> # x is a Tensor of shape [3, 8, 5]
>>> x = paddle.rand([3, 8, 5])

>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=1)
>>> print(out0.shape)
[3, 3, 5]
>>> print(out1.shape)
[3, 3, 5]
>>> print(out2.shape)
[3, 2, 5]

>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=[1, 2, 5], dim=1)
>>> print(out0.shape)
[3, 1, 5]
>>> print(out1.shape)
[3, 2, 5]
>>> print(out2.shape)
[3, 5, 5]

>>> # dim is negative, the real dim is (rank(x) + dim)=1
>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=-2)
>>> print(out0.shape)
[3, 3, 5]
>>> print(out1.shape)
[3, 3, 5]
>>> print(out2.shape)
[3, 2, 5]