fused_transpose_split_quant

paddle.incubate.nn.functional.fused_transpose_split_quant(x, input_scales, tokens_per_expert, pow_2_scales=False) [source]

Applies fused transpose, split, and quantization operation for Mixture of Experts (MoE) models.

Note

This function performs three operations in a single optimized CUDA kernel:

1. Quantizes the input from bfloat16 to float8_e4m3fn format using column-wise scaling.
2. Transposes the matrix from [M, K] to [K, M] layout.
3. Splits the transposed data across multiple experts based on the token distribution.
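
For intuition, the following is a minimal, unfused sketch of roughly what the kernel computes, assuming standard absmax FP8 quantization over 1x128 column-wise groups and ignoring input_scales (passed as None in the example below). The exact formula inside the kernel may differ, and casting to float8_e4m3fn requires a Paddle build with FP8 support.

>>> # Illustrative, unfused reference of quantize -> transpose -> split.
>>> import paddle
>>> def reference_transpose_split_quant(x, tokens_per_expert):
...     FP8_MAX = 448.0                                    # largest finite float8_e4m3fn value
...     M, K = x.shape
...     xg = x.astype('float32').reshape([M // 128, 128, K])
...     amax = xg.abs().max(axis=1, keepdim=True)          # [M/128, 1, K] group absmax
...     scale = FP8_MAX / paddle.clip(amax, min=1e-12)     # quantization scale per 1x128 group
...     # pow_2_scales=True would additionally round these scales to powers of two (omitted here)
...     q = (xg * scale).reshape([M, K]).T                 # scaled values, transposed to [K, M]
...     dequant = (1.0 / scale).reshape([M // 128, K])     # dequantization scales
...     outs, scales, start = [], [], 0
...     for n in tokens_per_expert:
...         outs.append(q[:, start:start + n].astype('float8_e4m3fn'))
...         scales.append(dequant[start // 128:(start + n) // 128])
...         start += n
...     return outs, scales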

Parameters
  • x (Tensor) – Input tensor of shape [M, K] with dtype bfloat16, where M is the total number of tokens and K is the feature dimension. M must be divisible by 128 for optimal performance.

  • tokens_per_expert (List[int]) – List containing the number of tokens assigned to each expert. Each value should be a multiple of 128 for optimal performance, the values should sum to M (the total number of tokens), and a value may be 0 for an unused expert. See the sketch after this parameter list for how this list is typically constructed.

  • pow_2_scales (bool, optional) – Whether to constrain quantization scales to powers of 2 for better hardware efficiency. If True, scales will be rounded to the nearest power of 2. Default: False.
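
The tokens_per_expert argument is typically derived from the router's expert assignment. A minimal sketch, assuming a hypothetical per-token expert_ids tensor (not part of this API); if counts are rounded up to multiples of 128, the caller must pad x so its rows still sum to sum(tokens_per_expert):

>>> import paddle
>>> # Hypothetical routing result: one expert index per token (illustrative only).
>>> expert_ids = paddle.randint(0, 3, shape=[384])
>>> counts = paddle.bincount(expert_ids, minlength=3)
>>> # Round each count up to a multiple of 128, the kernel's preferred granularity.
>>> tokens_per_expert = [int((c + 127) // 128 * 128) for c in counts.tolist()]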

Returns

  • outs (List[Tensor]). List of quantized and transposed output tensors, one per expert. Each tensor has shape [K, tokens_per_expert[i]] and dtype float8_e4m3fn. Empty tensors are included for experts with 0 tokens.

  • scales (List[Tensor]). List of dequantization scale tensors, one per expert. Each tensor has shape [tokens_per_expert[i] // 128, K] and dtype float32. These are the reciprocals of the quantization scales; see the dequantization sketch after the example below.

Return type

tuple

Examples

>>> import paddle
>>> import paddle.incubate.nn.functional as F
>>> paddle.set_device('gpu')

>>> x = paddle.randn([384, 512], dtype='bfloat16')
>>> x = paddle.clip(x, min=-50, max=50)
>>> tokens_per_expert = [128, 128, 128]
>>> outs, scales = F.fused_transpose_split_quant(x, None, tokens_per_expert, pow_2_scales=True)
>>> print(outs[0].shape)
[512, 128]
>>> print(scales[0].shape)
[1, 512]
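
Continuing the example above, the scales can be broadcast back over the transposed output to recover an approximation of the original values. A minimal sketch, assuming the scale layout [tokens_per_expert[i] // 128, K] shown above:

>>> # Expand each 1x128 group scale over its 128 tokens, align with the [K, M_i]
>>> # layout of outs[0], and dequantize back to float32.
>>> s0 = scales[0].repeat_interleave(128, axis=0).T   # [512, 128]
>>> deq0 = outs[0].astype('float32') * s0             # roughly x[:128, :].T if tokens are split contiguously
>>> print(deq0.shape)
[512, 128]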