fused_stack_transpose_quant

paddle.incubate.nn.functional.fused_stack_transpose_quant(x: Sequence[Tensor], transpose: bool = True) → tuple[Tensor, Tensor] [source]

Fused operation that stacks a list of bfloat16 tensors, optionally transposes the result, and quantizes it to float8_e4m3fn.

Parameters
  • x (list[Tensor] or tuple[Tensor]) – A list or tuple of bfloat16 tensors, each of shape [M, K]. All tensors must have the same shape and dtype.

  • transpose (bool, optional) – If True, transposes the stacked tensor before quantization. Default is True.

Returns

  • out (Tensor): The quantized output tensor with dtype float8_e4m3fn.

  • scale (Tensor): A float32 tensor holding the quantization scale factors.

Return type

tuple[Tensor, Tensor]

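Conceptually, the fused kernel combines three unfused steps: stacking, transposition, and blockwise float8 quantization. The sketch below illustrates this with NumPy; the 128x128 quantization block size and the e4m3 maximum value of 448 are assumptions inferred from the example output shapes on this page, not guarantees about the kernel's internals.

```python
import numpy as np

def stack_transpose_quant_ref(x_list, block=128, fp8_max=448.0):
    """Unfused reference sketch: stack -> transpose -> blockwise quantize.

    The 128x128 block size and fp8_max=448.0 are assumptions for
    illustration; the fused kernel's actual tiling may differ.
    """
    # Stack E tensors of shape [M, K] along the row axis: [E*M, K]
    stacked = np.concatenate(x_list, axis=0).astype(np.float32)
    # Optional transpose: [K, E*M]
    t = stacked.T
    rows, cols = t.shape
    scale = np.zeros((rows // block, cols // block), dtype=np.float32)
    out = np.zeros_like(t)
    for i in range(rows // block):
        for j in range(cols // block):
            blk = t[i * block:(i + 1) * block, j * block:(j + 1) * block]
            # Per-block scale so the block's max magnitude maps to fp8_max
            s = np.abs(blk).max() / fp8_max
            scale[i, j] = s
            out[i * block:(i + 1) * block, j * block:(j + 1) * block] = (
                blk / max(s, 1e-12)
            )
    return out, scale

x_vec = [np.random.randn(2048, 128)]
out, scale = stack_transpose_quant_ref(x_vec)
print(out.shape, scale.shape)  # (128, 2048) (1, 16)
```

Note that the reference keeps the quantized values in float32 for clarity; the fused operator emits float8_e4m3fn directly and runs in a single kernel launch.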
Examples

>>> import paddle
>>> import paddle.incubate.nn.functional as F
>>> paddle.set_device('gpu')

>>> x_vec = []
>>> num_experts = 1
>>> seq_len = 2048
>>> hidden_size = 128
>>> for _ in range(num_experts):
...     x = paddle.randn([seq_len, hidden_size], dtype='bfloat16')
...     x = paddle.clip(x, min=-50, max=50)
...     x_vec.append(x)

>>> out, scale = F.fused_stack_transpose_quant(x_vec, transpose=True)
>>> print(out.shape)
[128, 2048]
>>> print(scale.shape)
[1, 16]