moe_permute
- paddle.nn.functional. moe_permute ( hidden_states: Tensor, scale: Tensor | None, expert_routemap_topk: Tensor, expert_prob_topk: Tensor, num_experts: int, tokens_per_expert: list, padding_alignment: int, do_gather: bool = True, name: str | None = None ) tuple[Tensor, Tensor, Tensor, Tensor] [source]
-
Permute tokens for Mixture of Experts (MoE) computation in distributed training scenarios.
Note
This function reorganizes input tokens based on expert assignments to prepare for expert computation. It handles both bfloat16 and float8_e4m3fn data types with proper scaling for float8 inputs.
This function is typically used in pair of moe_unpermute to provide complete MoE functionality.
For float8 inputs, proper scaling must be provided via the scale parameter.
The padding_alignment parameter affects memory efficiency but not correctness.
Any output tokens can find an exact-match in the original input tokens.
This permute function has overcomed the aadiff issue, is deterministic.
- Parameters
-
hidden_states (Tensor) – The input tensor containing tokens to be permuted, stored in row-major layout. Supported data types: bfloat16 or float8_e4m3fn. Shape: [sequence_length, token_dimension]
scale (Tensor|None) – Scaling factors required when hidden_states is of float8 type. For float8 inputs, this tensor provides the scaling factors for dequantization. Shape: [sequence_length, ceil(token_dimension / 128)] Data type: float32
expert_routemap_topk (Tensor) – Tensor indicating expert assignments for each token (top-k experts). Each value represents the expert index the token is assigned to (-1 indicates not assigned). Shape: [sequence_length, top_k_experts] Data type: int32 Value range: [-1, num_experts)
expert_prob_topk (Tensor) – Tensor containing routing probabilities for top-k experts. Shape: [sequence_length, top_k_experts] Data type: float32
num_experts (int) – Total number of experts in the MoE layer, limited between 1 and 64.
tokens_per_expert (list[int]) – List where each element indicates the number of tokens assigned to the corresponding expert.
padding_alignment (int) – Tokens alignment requirement for expert buffers (in bytes). Must be a power of 2. Typical values are 16, 32 or 64 for optimal memory access.
do_gather (bool) – Decide whether do actual tokens gather operation or not, default is True.
name (str|None, optional) – Name prefix for the operation (optional). Default: None
- Returns
-
-
- hidden_states_unzipped (Tensor): The permuted and broadcasted input tensor.
-
Shape: [total_tokens_after_broadcast, token_dimension] Data type: same as input hidden_states
-
- zipped_expertwise_rowmap (Tensor): Mapping tensor used to restore original order (unpermute).
-
Shape: [sequence_length, num_experts] Data type: int32
-
- token_prob_unzipped (Tensor): Flattened expert probabilities aligned with permuted tokens.
-
Shape: [total_tokens_after_broadcast, 1] Data type: float32
-
- scale_unzipped (Tensor): Broadcasted scale tensor (only valid for float8 inputs).
-
Shape: [total_tokens_after_broadcast, ceil(token_dimension / 128)] Data type: float32
-
- Return type
-
tuple[Tensor, Tensor, Tensor, Tensor]
Examples
>>> >>> >>> import paddle >>> import numpy as np >>> import paddle.nn.functional as F >>> hidden_states = paddle.randn([3, 128], dtype='bfloat16') >>> expert_routemap_topk = paddle.to_tensor([[-1, 0, -1, -1, 2, -1, -1, -1], ... [1, -1, -1, -1, -1, -1, -1, -1], ... [-1, -1, -1, -1, -1, -1, 1, -1]], ... dtype='int32') >>> expert_prob_topk= paddle.to_tensor([[0.0, 0.6, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0], ... [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]], ... dtype='float32') >>> num_experts = 3 >>> tokens_per_expert = [1, 2, 1] >>> padding_alignment = 2 >>> hidden_states_unzipped, zipped_expertwise_rowmap, token_prob_unzipped, scale_unzipped = F.moe_permute( ... hidden_states, ... None, ... expert_routemap_topk, ... expert_prob_topk, ... num_experts, ... tokens_per_expert, ... padding_alignment, ... ) >>> # weighted by probs. >>> hidden_states_unzipped = (hidden_states_unzipped.astype("float32") * token_prob_unzipped.astype("float32").unsqueeze(-1)).astype("bfloat16") >>> zipped_tokens, zipped_probs = F.moe_unpermute(hidden_states_unzipped, zipped_expertwise_rowmap, expert_routemap_topk, token_prob_unzipped,3,3) >>> np.testing.assert_allclose(zipped_tokens.numpy(), hidden_states.numpy(), rtol=1e-05, atol=1e-06)