fused_stack_transpose_quant

paddle.incubate.nn.functional.fused_stack_transpose_quant(x: Sequence[Tensor], transpose: bool = True) -> tuple[Tensor, Tensor]

Fused operation that performs stacking, optional transposition, and quantization on a list of bfloat16 tensors.

This API supports both dynamic and static graph modes. In dynamic mode, it calls the corresponding C++ core op directly. In static mode, it appends the op to the static graph.
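The three stages described above (stack, optional transpose, quantize) can be modeled in plain Python to make the data flow concrete. This is a minimal sketch of the semantics, not Paddle's kernel, and it rests on assumptions this page does not state: the [M, N] inputs are stacked by concatenating them along dim 0, `transpose=True` swaps the two axes of the stacked matrix, one scale is kept per 128 x 128 tile (inferred from the example output shapes), and each scale equals the tile's absolute maximum divided by 448, the largest finite float8_e4m3fn value. The final cast to float8_e4m3fn is not modeled. `reference_stack_transpose_quant`, `block`, and `FP8_E4M3_MAX` are hypothetical names.

```python
# Hypothetical pure-Python model of fused_stack_transpose_quant's semantics.
# ASSUMPTIONS (not confirmed by the Paddle docs): the [M, N] inputs are
# stacked by concatenating along dim 0, transpose swaps the two axes of the
# stacked matrix, one scale is kept per block x block tile, and each scale
# equals the tile's absolute maximum divided by FP8_E4M3_MAX.
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def reference_stack_transpose_quant(xs, transpose=True, block=128):
    """Return (out, scale) as nested lists; FP8 rounding is not modeled."""
    # Stack: concatenate the rows of every [M, N] input (assumption).
    stacked = [list(row) for x in xs for row in x]
    # Optional transpose of the stacked 2-D matrix.
    mat = [list(col) for col in zip(*stacked)] if transpose else stacked
    rows, cols = len(mat), len(mat[0])
    assert rows % block == 0 and cols % block == 0, "sketch assumes full tiles"

    out = [[0.0] * cols for _ in range(rows)]
    scale = [[0.0] * (cols // block) for _ in range(rows // block)]
    for bi in range(rows // block):
        for bj in range(cols // block):
            tile = [(r, c)
                    for r in range(bi * block, (bi + 1) * block)
                    for c in range(bj * block, (bj + 1) * block)]
            amax = max(abs(mat[r][c]) for r, c in tile)
            # Guard all-zero tiles so the division below stays defined.
            s = amax / FP8_E4M3_MAX if amax > 0 else 1.0
            scale[bi][bj] = s
            for r, c in tile:
                # Map into [-448, 448]; the real op then stores the result as
                # float8_e4m3fn, a rounding step this sketch does not model.
                out[r][c] = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, mat[r][c] / s))
    return out, scale
```

Passing a small `block` such as 2 lets you try it on tiny nested lists: a single [4, 2] input with `transpose=False` yields a [2, 1] grid of scales, and with `transpose=True` a [1, 2] grid.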

Parameters
  • x (list[Tensor] or tuple[Tensor]) – A list or tuple of bfloat16 tensors, where each tensor has shape [M, N]. All tensors should have the same shape and dtype.

  • transpose (bool, optional) – If True, transposes the stacked tensor before quantization. Default is True.

Returns

  • out (Tensor): The quantized output tensor with dtype float8_e4m3fn, holding the stacked (and, if transpose is True, transposed) input.

  • scale (Tensor): A float32 tensor containing the quantization scales.

Return type

tuple

Raises
  • TypeError – If x is not a list or tuple of bfloat16 tensors.

  • TypeError – If transpose is not a boolean.

  • RuntimeError – If the dynamic op is called directly while not running in dynamic mode.

Examples

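import paddle
import paddle.incubate.nn.functional as F

# Build num_experts bfloat16 inputs, each of shape [seq_len, hidden_size].
x_vec = []
num_experts = 1
seq_len = 2048
hidden_size = 128
for _ in range(num_experts):
    x = paddle.randn([seq_len, hidden_size], dtype='bfloat16')
    x = paddle.clip(x, min=-50, max=50)  # bound the value range before quantization
    x_vec.append(x)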

out, scale = F.fused_stack_transpose_quant(x_vec, transpose=True)

print(out.shape) # [128, 2048]
print(scale.shape) # [1, 16]

out, scale = F.fused_stack_transpose_quant(x_vec, transpose=False)

print(out.shape) # [2048, 128]
print(scale.shape) # [16, 1]
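The shapes printed above are consistent with one float32 scale per 128 x 128 tile of the quantized output. The hypothetical `expected_shapes` helper below (not a Paddle API) reproduces them in plain Python. It assumes that K inputs of shape [M, N] are stacked into a [K*M, N] matrix and that `transpose=True` swaps its two axes; the example only covers K = 1, so the multi-tensor layout is an assumption.

```python
# Hypothetical helper: predicts output shapes under assumptions inferred from
# the example above (stacking along dim 0, one scale per block x block tile).
def expected_shapes(num_tensors, m, n, transpose=True, block=128):
    rows, cols = num_tensors * m, n  # assumed [K*M, N] stacked layout
    if transpose:
        rows, cols = cols, rows  # transpose swaps the two axes
    # Assumes both dimensions are multiples of block (true for the example).
    assert rows % block == 0 and cols % block == 0
    return [rows, cols], [rows // block, cols // block]


print(expected_shapes(1, 2048, 128, transpose=True))   # ([128, 2048], [1, 16])
print(expected_shapes(1, 2048, 128, transpose=False))  # ([2048, 128], [16, 1])
```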