fused_weighted_swiglu_act_quant

paddle.incubate.nn.functional.fused_weighted_swiglu_act_quant ( x: Tensor, prob: Optional[Tensor] = None, using_pow2_scaling: bool = False, name: Optional[str] = None ) → tuple[paddle.Tensor, paddle.Tensor] [source]

Applies fused weighted SwiGLU activation followed by quantization to float8_e4m3fn format.

Note

This function combines four operations into a single optimized CUDA kernel:

1. SwiGLU activation: SwiGLU(x1, x2) = SiLU(x1) * x2 = (x1 * sigmoid(x1)) * x2
2. Probability weighting: multiply by optional probability factors
3. Activation computation: compute final activation values in float32 precision
4. Quantization: convert results to float8_e4m3fn with computed scaling factors

The input tensor is split into two halves along the last dimension:

- Left half [0, cols/2): first input to SwiGLU (gate values)
- Right half [cols/2, cols): second input to SwiGLU (activation values)
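
For reference, the unfused math can be sketched with plain Paddle ops. The helper below is a hypothetical illustration, not part of the API: it keeps the quantized values in float32, assumes cols/2 is a multiple of 128, and follows the scale convention described under Returns, whereas the fused kernel performs all four steps in one CUDA pass and emits float8_e4m3fn directly.

>>> import paddle
>>> def swiglu_act_quant_reference(x, prob=None, block=128):
...     # split the last dimension into gate values and activation values
...     cols = x.shape[-1]
...     gate, act = x[..., : cols // 2], x[..., cols // 2:]
...     # SwiGLU in float32: SiLU(gate) * act, optionally weighted per row by prob
...     out = paddle.nn.functional.silu(gate.astype('float32')) * act.astype('float32')
...     if prob is not None:
...         out = out * prob.unsqueeze(-1)
...     # flatten batch dims and quantize each 128-element block (assumes cols/2 % 128 == 0)
...     blocks = out.reshape([-1, (cols // 2) // block, block])
...     amax = blocks.abs().max(axis=-1, keepdim=True)
...     # choose a scale so each block fits the float8_e4m3fn range (max magnitude 448);
...     # per the Returns section, dequantization divides by this scale
...     scale = 448.0 / paddle.clip(amax, min=1e-4)
...     quant = blocks * scale  # the fused kernel stores this as float8_e4m3fn
...     return quant.reshape([-1, cols // 2]), scale.squeeze(-1)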

Parameters
  • x (Tensor) – Input tensor with dtype bfloat16 and shape […, cols], where cols must be even. The tensor is interpreted as two concatenated matrices: gate values [0:cols/2] and activation values [cols/2:cols]. Typical shapes: [batch_size, sequence_length, hidden_dim] or [tokens, expert_dim] in MoE scenarios.

  • prob (Tensor, optional) – Probability weighting tensor with dtype float32 and shape matching x’s batch dimensions […]. Each value multiplies the corresponding row’s activation output.

  • using_pow2_scaling (bool, optional) – Whether to use power-of-2 quantization scaling for hardware efficiency. Default: False.

  • name (str, optional) – Name for the operation. Default: None.

Returns

  • out (Tensor). Quantized activation output with dtype float8_e4m3fn and shape […, cols/2]. Contains the quantized SwiGLU results.

  • scale (Tensor). Dequantization scales with dtype float32 and shape […, (cols/2 + 127) // 128]. Each scale corresponds to a 128-element block in the output tensor. To dequantize: original_value = quantized_value / scale.

Return type

tuple
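
To map the two outputs back to approximate float values, each per-block scale is broadcast over its 128 output elements and the quantized values are divided by it, as documented above. The helper below is a hypothetical sketch over the flattened 2-D outputs, not part of the API.

>>> import paddle
>>> def dequantize_reference(out, scale, block=128):
...     # cast the float8_e4m3fn payload up to float32
...     vals = out.astype('float32')
...     # repeat each per-block scale across its 128 output elements, then divide as documented
...     expanded = paddle.repeat_interleave(scale, block, axis=-1)[:, : vals.shape[-1]]
...     return vals / expanded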

Examples

>>> import paddle
>>> import paddle.incubate.nn.functional as F
>>> paddle.set_device('gpu')

>>> batch_size, seq_len, expert_dim = 32, 128, 2048
>>> x = paddle.randn([batch_size, seq_len, expert_dim], dtype='bfloat16')
>>> quantized_out, scales = F.fused_weighted_swiglu_act_quant(x)
>>> print(x.shape)
[32, 128, 2048]
>>> print(quantized_out.shape)
[4096, 1024]
>>> print(scales.shape)
[4096, 8]
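
Passing the optional probability weights and power-of-2 scaling follows the same pattern; the call below is a sketch that assumes prob matches x's batch dimensions as described under Parameters.

>>> prob = paddle.rand([batch_size, seq_len], dtype='float32')
>>> weighted_out, weighted_scales = F.fused_weighted_swiglu_act_quant(
...     x, prob=prob, using_pow2_scaling=True
... )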