fused_weighted_swiglu_act_quant
- paddle.incubate.nn.functional.fused_weighted_swiglu_act_quant(x: Tensor, prob: Optional[Tensor] = None, using_pow2_scaling: bool = False, name: Optional[str] = None) → tuple[paddle.Tensor, paddle.Tensor] [source]
-
Applies fused weighted SwiGLU activation followed by quantization to float8_e4m3fn format.
Note
This function combines four operations into a single optimized CUDA kernel:
1. SwiGLU activation: SwiGLU(x1, x2) = SiLU(x1) * x2 = (x1 * sigmoid(x1)) * x2
2. Probability weighting: multiply by optional probability factors
3. Activation computation: compute final activation values in float32 precision
4. Quantization: convert results to float8_e4m3fn with computed scaling factors
The input tensor is split into two halves along the last dimension:
- Left half [0, cols/2): first input to SwiGLU (gate values)
- Right half [cols/2, cols): second input to SwiGLU (activation values)
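For illustration, the four fused steps can be written as an unfused reference in plain Paddle ops. This is a minimal sketch, not the fused kernel: the helper name, the float8_e4m3fn clipping value of 448.0, the assumption that cols/2 is a multiple of 128, and the availability of astype('float8_e4m3fn') in your Paddle build are all assumptions made here.

import paddle
import paddle.nn.functional as F

def swiglu_act_quant_reference(x, prob=None, block_size=128):
    # Unfused reference: SwiGLU -> optional prob weighting -> blockwise fp8 quantization.
    cols = x.shape[-1]
    x2d = x.reshape([-1, cols])                               # collapse batch dims, as the fused op does
    gate = x2d[:, : cols // 2].astype('float32')              # left half: gate values
    act = x2d[:, cols // 2:].astype('float32')                # right half: activation values
    out = F.silu(gate) * act                                  # SwiGLU computed in float32
    if prob is not None:
        out = out * prob.reshape([-1, 1]).astype('float32')   # per-row probability weighting
    # Blockwise scaling: one scale per 128 output elements; cols/2 is assumed
    # to be a multiple of block_size to keep the sketch short.
    fp8_max = 448.0                                           # largest finite float8_e4m3fn value
    rows, half = out.shape
    blocks = out.reshape([rows, half // block_size, block_size])
    amax = blocks.abs().max(axis=-1, keepdim=True).clip(min=1e-4)
    scale = fp8_max / amax                                    # chosen so that dequantization is q / scale
    q = (blocks * scale).clip(-fp8_max, fp8_max).reshape([rows, half])
    return q.astype('float8_e4m3fn'), scale.squeeze(-1)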
- Parameters
-
x (Tensor) – Input tensor with dtype bfloat16 and shape […, cols], where cols must be even. The tensor is interpreted as two concatenated matrices: gate values [0:cols/2] and activation values [cols/2:cols]. Typical shapes: [batch_size, sequence_length, hidden_dim] or [tokens, expert_dim] in MoE scenarios.
prob (Tensor, optional) – Probability weighting tensor with dtype float32 and shape matching x’s batch dimensions […]. Each value multiplies the corresponding row’s activation output.
using_pow2_scaling (bool, optional) – Whether to use power-of-2 quantization scaling for hardware efficiency. Default is False.
name (str|None, optional) – Name for the operation. Default is None.
- Returns
-
out (Tensor). Quantized activation output with dtype float8_e4m3fn and shape […, cols/2]. Contains the quantized SwiGLU results.
scale (Tensor). Dequantization scales with dtype float32 and shape […, (cols/2 + 127) // 128]. Each scale corresponds to a 128-element block in the output tensor. To dequantize: original_value = quantized_value / scale (see the sketch below).
- Return type
-
tuple
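Since each returned scale covers a 128-element block of the output, dequantization expands the scales back over their blocks and divides. A minimal sketch follows; the helper name and the repeat-based expansion are assumptions, and the trailing slice handles the case where cols/2 is not a multiple of 128.

import paddle

def dequantize_blockwise(out_fp8, scale, block_size=128):
    # Recover approximate float32 values: original_value = quantized_value / scale,
    # with each scale broadcast over its 128-element block.
    values = out_fp8.astype('float32')
    rows, half = values.shape
    expanded = paddle.repeat_interleave(scale, block_size, axis=-1)[:, :half]
    return values / expanded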
Examples
>>> import paddle
>>> import paddle.incubate.nn.functional as F
>>> paddle.set_device('gpu')
>>> batch_size, seq_len, expert_dim = 32, 128, 2048
>>> x = paddle.randn([batch_size, seq_len, expert_dim], dtype='bfloat16')
>>> quantized_out, scales = F.fused_weighted_swiglu_act_quant(x)
>>> print(x.shape)
[32, 128, 2048]
>>> print(quantized_out.shape)
[4096, 1024]
>>> print(scales.shape)
[4096, 8]
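A further example passes the optional prob weights for a flattened [tokens, expert_dim] layout, as in MoE routing. The shapes and the one-scale-per-row prob layout shown here are assumptions based on the parameter descriptions above, not outputs verified against the kernel.

>>> tokens, expert_dim = 4096, 2048
>>> x = paddle.randn([tokens, expert_dim], dtype='bfloat16')
>>> prob = paddle.rand([tokens], dtype='float32')
>>> quantized_out, scales = F.fused_weighted_swiglu_act_quant(x, prob=prob, using_pow2_scaling=True)
>>> print(quantized_out.shape)
[4096, 1024]
>>> print(scales.shape)
[4096, 8]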