fused_swiglu_weighted_bwd

paddle.incubate.nn.functional.fused_swiglu_weighted_bwd(o1: Tensor, do2_s: Tensor, unzipped_probs: Tensor, name: Optional[str] = None) → tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]

Computes the gradients of the fused, probability-weighted SwiGLU activation in the backward pass.

Note

This function performs the backward pass of the SwiGLU (Swish-Gated Linear Unit) activation with probability weighting. It computes gradients with respect to both the input activations and the probability weights, and it recomputes the forward output instead of reading stored activations, which saves memory. The kernel automatically selects between a vectorized and a standard implementation based on the input dimensions.
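The steps described in the note can be sketched as a float32 NumPy reference. This is not Paddle's kernel: it assumes SwiGLU(x1, x2) = SiLU(x1) · x2 (the convention assumed here to match paddle.incubate.nn.functional.swiglu), that the probabilities carry a trailing dimension of size 1 as in the Examples section, and that probs_grad keeps that trailing dimension. The name swiglu_weighted_bwd_reference is hypothetical, and the sketch skips the vectorized/standard kernel selection and bfloat16 storage.

```python
import numpy as np


def swiglu_weighted_bwd_reference(o1, do2_s, probs):
    """Hypothetical float32 reference for the weighted SwiGLU backward pass.

    Assumes SwiGLU(x1, x2) = silu(x1) * x2 and probs of shape [..., 1].
    Returns (do1, probs_grad, o2_s) in the same order as the fused op.
    """
    o1 = np.asarray(o1, dtype=np.float32)
    g = np.asarray(do2_s, dtype=np.float32)      # upstream gradient dL/d(o2_s)
    p = np.asarray(probs, dtype=np.float32)      # per-row probability, [..., 1]

    x1, x2 = np.split(o1, 2, axis=-1)            # gate half, activation half

    # Recompute the forward pass instead of reading stored activations.
    sig = 1.0 / (1.0 + np.exp(-x1))              # sigmoid(x1)
    act = x1 * sig                               # silu(x1)
    o2 = act * x2                                # unweighted SwiGLU output
    o2_s = o2 * p                                # probability-weighted output

    # dL/dp: reduce g * SwiGLU output over the intermediate dimension.
    probs_grad = np.sum(g * o2, axis=-1, keepdims=True)

    # Gradient reaching the unweighted output, then the chain rule per half.
    g_o2 = g * p
    dx1 = g_o2 * x2 * sig * (1.0 + x1 * (1.0 - sig))  # silu'(x1) factor
    dx2 = g_o2 * act

    do1 = np.concatenate([dx1, dx2], axis=-1)    # same layout as o1
    return do1, probs_grad, o2_s


rng = np.random.default_rng(0)
o1 = rng.standard_normal((4, 8, 2 * 16)).astype(np.float32)
do2_s = rng.standard_normal((4, 8, 16)).astype(np.float32)
probs = rng.random((4, 8, 1)).astype(np.float32)

do1, probs_grad, o2_s = swiglu_weighted_bwd_reference(o1, do2_s, probs)
print(do1.shape, probs_grad.shape, o2_s.shape)  # (4, 8, 32) (4, 8, 1) (4, 8, 16)
```

Working in float32 keeps the reference simple; a bfloat16 kernel will generally differ from it by rounding, so comparisons against it need loose tolerances.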

Parameters
  • o1 (Tensor) – Forward pass input tensor with dtype bfloat16 and shape […, intermediate_size * 2]. This is the same input used in the forward SwiGLU computation. The tensor is split into two halves along the last dimension:
      - Left half [0:intermediate_size]: x1 values (gate inputs)
      - Right half [intermediate_size:]: x2 values (activation inputs)

  • do2_s (Tensor) – Upstream gradient tensor with dtype bfloat16 and shape […, intermediate_size]. Contains gradients flowing back from the next layer, representing ∂L/∂output before probability weighting. Each element corresponds to the gradient of one output element.

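  • unzipped_probs (Tensor) – Probability weighting tensor with dtype float32 and shape […], matching the leading (batch) dimensions of o1 and do2_s; the example passes it with a trailing dimension of size 1, i.e. […, 1]. Each probability value was used to weight the corresponding row’s output in the forward pass.

  • name (str, optional) – For details, please refer to Name. Generally, no setting is required. Default: None.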
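The shape contract of the parameters can be checked before calling the op. This sketch uses NumPy arrays as stand-ins for Paddle tensors; check_swiglu_bwd_inputs is a hypothetical helper, dtype checks are omitted because NumPy has no native bfloat16, and unzipped_probs is accepted as either […] or […, 1] because the text describes the former while the example passes the latter.

```python
import numpy as np


def check_swiglu_bwd_inputs(o1, do2_s, unzipped_probs):
    """Hypothetical helper: validate the documented shapes, return intermediate_size."""
    if o1.shape[-1] % 2 != 0:
        raise ValueError("o1's last dimension must be even (x1 half + x2 half)")
    intermediate_size = o1.shape[-1] // 2
    batch_shape = o1.shape[:-1]

    expected_do2_s = batch_shape + (intermediate_size,)
    if do2_s.shape != expected_do2_s:
        raise ValueError(f"do2_s must have shape {expected_do2_s}, got {do2_s.shape}")

    # Accept both [...] and [..., 1] for the per-row probabilities.
    if unzipped_probs.shape not in (batch_shape, batch_shape + (1,)):
        raise ValueError(
            f"unzipped_probs must have shape {batch_shape} or {batch_shape + (1,)}, "
            f"got {unzipped_probs.shape}"
        )
    return intermediate_size


o1 = np.arange(2 * 3 * 8, dtype=np.float32).reshape(2, 3, 8)
do2_s = np.ones((2, 3, 4), dtype=np.float32)
probs = np.full((2, 3, 1), 0.5, dtype=np.float32)

n = check_swiglu_bwd_inputs(o1, do2_s, probs)
x1, x2 = o1[..., :n], o1[..., n:]  # gate half, activation half
print(n, x1[0, 0], x2[0, 0])  # 4 [0. 1. 2. 3.] [4. 5. 6. 7.]
```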

Returns

  • do1 (Tensor). Input gradients with dtype bfloat16 and shape […, intermediate_size * 2]. The layout matches o1:
      - [0:intermediate_size]: ∂L/∂x1 (gradients w.r.t. gate inputs)
      - [intermediate_size:]: ∂L/∂x2 (gradients w.r.t. activation inputs)

  • probs_grad (Tensor). Probability gradients with dtype float32 and shape […]. Each element is ∂L/∂prob for the corresponding batch item, computed as the sum of (∂L/∂output_i * SwiGLU_output_i) across the intermediate dimension.

  • o2_s (Tensor). Recomputed forward output with dtype bfloat16 and shape […, intermediate_size]. Contains SwiGLU(x1, x2) * prob. Recomputing it here avoids storing the forward activations, trading extra computation for lower memory use.
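Written out per row, the three returned quantities follow from o2_s = SwiGLU(x1, x2) · prob by the chain rule. The derivation below is a sketch, not taken from the kernel source: it assumes SwiGLU(x1, x2) = SiLU(x1) · x2, writes p for the row's probability, g_i for the do2_s entries (g_i = ∂L/∂o^s_{2,i}), σ for the logistic sigmoid, and d for intermediate_size.

```latex
\begin{aligned}
\sigma_i &= \frac{1}{1 + e^{-x_{1,i}}}, \qquad
\operatorname{SiLU}(x_{1,i}) = x_{1,i}\,\sigma_i, \qquad
o_{2,i} = \operatorname{SiLU}(x_{1,i})\,x_{2,i}, \qquad
o^{s}_{2,i} = p\,o_{2,i} \\[4pt]
\frac{\partial L}{\partial p} &= \sum_{i=1}^{d} g_i\,o_{2,i} \\
\frac{\partial L}{\partial x_{2,i}} &= g_i\,p\,\operatorname{SiLU}(x_{1,i}) \\
\frac{\partial L}{\partial x_{1,i}} &= g_i\,p\,x_{2,i}\,\sigma_i\bigl(1 + x_{1,i}(1 - \sigma_i)\bigr)
\end{aligned}
```

The last line uses the SiLU derivative d/dx [x σ(x)] = σ(x) + x σ(x)(1 − σ(x)) = σ(x)(1 + x(1 − σ(x))). The first line is o2_s, the ∂L/∂p line is probs_grad, and the two input-gradient lines form the halves of do1.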

Return type

tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]

Examples

>>> import paddle
>>> import paddle.incubate.nn.functional as F
>>> paddle.set_device('gpu')

>>> batch_size, seq_len = 32, 128
>>> intermediate_size = 2048

>>> o1 = paddle.randn([batch_size, seq_len, intermediate_size * 2], dtype='bfloat16')
>>> do2_s = paddle.randn([batch_size, seq_len, intermediate_size], dtype='bfloat16')
>>> expert_probs = paddle.rand([batch_size, seq_len, 1], dtype='float32')

>>> do1, probs_grad, o2_s = F.fused_swiglu_weighted_bwd(o1, do2_s, expert_probs)
>>> print(do1.shape)
[32, 128, 4096]
>>> print(probs_grad.shape)
[32, 128, 1]
>>> print(o2_s.shape)
[32, 128, 2048]
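Before comparing the fused kernel with any reference, it can help to confirm the gradient formulas themselves. This NumPy sketch checks the analytic do1 and probs_grad in float64 against central finite differences for the loss L = sum(do2_s * o2_s), which makes do2_s exactly ∂L/∂o2_s. It assumes SwiGLU(x1, x2) = SiLU(x1) · x2 and probabilities of shape […, 1]; all helper names are hypothetical, and Paddle is not called.

```python
import numpy as np


def weighted_swiglu(o1, p):
    """Forward pass o2_s = silu(x1) * x2 * p, with x1/x2 the halves of o1."""
    x1, x2 = np.split(o1, 2, axis=-1)
    return x1 / (1.0 + np.exp(-x1)) * x2 * p


def closed_form_grads(o1, g, p):
    """Analytic do1 and probs_grad for the loss L = sum(g * o2_s)."""
    x1, x2 = np.split(o1, 2, axis=-1)
    sig = 1.0 / (1.0 + np.exp(-x1))
    act = x1 * sig
    probs_grad = np.sum(g * act * x2, axis=-1, keepdims=True)
    dx1 = g * p * x2 * sig * (1.0 + x1 * (1.0 - sig))
    dx2 = g * p * act
    return np.concatenate([dx1, dx2], axis=-1), probs_grad


def numerical_grad(fn, x, eps=1e-6):
    """Central finite differences of scalar fn() w.r.t. array x (perturbed in place)."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        plus = fn()
        x[idx] = orig - eps
        minus = fn()
        x[idx] = orig  # restore the original value
        grad[idx] = (plus - minus) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
o1 = rng.standard_normal((2, 3, 8))
g = rng.standard_normal((2, 3, 4))
p = rng.random((2, 3, 1))


def loss():
    return float(np.sum(g * weighted_swiglu(o1, p)))


do1, probs_grad = closed_form_grads(o1, g, p)
print(np.max(np.abs(do1 - numerical_grad(loss, o1))))
print(np.max(np.abs(probs_grad - numerical_grad(loss, p))))
```

Both printed maxima should be tiny (on the order of finite-difference error); a large value would point to a mistake in the formulas rather than in bfloat16 rounding.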