moe_unpermute
- paddle.nn.functional.moe_unpermute(hidden_states_unzipped: Tensor, zipped_expertwise_rowmap: Tensor, expert_routemap_topk: Tensor, token_prob_unzipped: Tensor, total_zipped_tokens: int, num_experts: int, use_mix_precision: bool = True, name: str | None = None) → tuple[Tensor, Tensor] [source]
-
Restores the original token order of hidden states previously permuted by moe_permute, scattering the broadcasted rows back to their source tokens and gathering the corresponding routing probabilities.
- Parameters
hidden_states_unzipped (Tensor) – The input Tensor containing broadcasted and permuted hidden states. Shape: (seqlen_broadcasted, token_len). Dtype: bfloat16.
zipped_expertwise_rowmap (Tensor) – The input Tensor recording the mapping relationship for unpermute operation. Shape: (seqlen, num_experts). Dtype: int32.
expert_routemap_topk (Tensor) – The input Tensor indicating which expert each token is assigned to. Shape: (seqlen, 8). Value range: [-1, num_experts]. Dtype: int32.
token_prob_unzipped (Tensor) – The input Tensor containing flattened expert probabilities corresponding to hidden_states_unzipped. Shape: (seqlen_broadcasted, 1). Dtype: float32.
total_zipped_tokens (int) – The total number of tokens before permutation, used to allocate the output buffer. Dtype: int32.
num_experts (int) – The number of experts. Dtype: int32.
use_mix_precision (bool, optional) – Whether to use mixed precision during accumulation. This option significantly improves precision when the number of experts exceeds 4. Default: True.
name (str|None, optional) – Name for the operation. Default: None.
- Returns
- A tuple containing:
hidden_states (Tensor): The output Tensor with unpermuted tokens. Shape: (seqlen, token_len). Dtype: bfloat16.
expert_prob_topk (Tensor): The output Tensor with unpermuted probabilities. Shape: (seqlen, topk). Dtype: float32.
- Return type
tuple[Tensor, Tensor]
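Conceptually, the unpermute step scatters each broadcasted row back to the token it came from and accumulates the (probability-weighted) contributions from the token's top-k experts. The sketch below illustrates that idea in plain NumPy; the function name `unpermute_sketch` and the flat `row_to_token` index are simplified stand-ins of our own, not the kernel's actual `zipped_expertwise_rowmap` layout.

```python
import numpy as np

def unpermute_sketch(hidden_unzipped, row_to_token, num_tokens):
    """Scatter-add each broadcasted row back to its source token.

    hidden_unzipped: (num_broadcast_rows, token_len) rows, assumed already
        weighted by their routing probabilities.
    row_to_token: (num_broadcast_rows,) index of the original token each
        broadcasted row belongs to (simplified stand-in for the rowmap).
    """
    token_len = hidden_unzipped.shape[1]
    out = np.zeros((num_tokens, token_len), dtype=hidden_unzipped.dtype)
    # np.add.at performs an unbuffered scatter-add, so rows mapping to the
    # same token (top-k > 1) accumulate instead of overwriting each other.
    np.add.at(out, row_to_token, hidden_unzipped)
    return out

# Token 0 was routed to two experts (rows 0 and 2), token 1 to one (row 1).
rows = np.array([[0.6, 0.6], [1.0, 2.0], [0.4, 0.4]], dtype=np.float32)
mapping = np.array([0, 1, 0])
restored = unpermute_sketch(rows, mapping, 2)  # token 0 sums rows 0 and 2
```

Because the rows were pre-weighted by probabilities that sum to 1 per token, the scatter-add recovers the original hidden states, which is exactly what the `assert_allclose` check in the example below verifies.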
Examples
>>> import paddle
>>> import numpy as np
>>> import paddle.nn.functional as F
>>> hidden_states = paddle.randn([3, 128], dtype='bfloat16')
>>> expert_routemap_topk = paddle.to_tensor([[-1, 0, -1, -1, 2, -1, -1, -1],
...                                          [1, -1, -1, -1, -1, -1, -1, -1],
...                                          [-1, -1, -1, -1, -1, -1, 1, -1]],
...                                         dtype='int32')
>>> expert_prob_topk = paddle.to_tensor([[0.0, 0.6, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0],
...                                      [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
...                                      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]],
...                                     dtype='float32')
>>> num_experts = 3
>>> tokens_per_expert = [1, 2, 1]
>>> padding_alignment = 2
>>> hidden_states_unzipped, zipped_expertwise_rowmap, token_prob_unzipped, scale_unzipped = F.moe_permute(
...     hidden_states,
...     None,
...     expert_routemap_topk,
...     expert_prob_topk,
...     num_experts,
...     tokens_per_expert,
...     padding_alignment,
... )
>>> # Weight the permuted hidden states by their routing probabilities.
>>> hidden_states_unzipped = (hidden_states_unzipped.astype("float32")
...                           * token_prob_unzipped.astype("float32").unsqueeze(-1)).astype("bfloat16")
>>> zipped_tokens, zipped_probs = F.moe_unpermute(hidden_states_unzipped, zipped_expertwise_rowmap,
...                                               expert_routemap_topk, token_prob_unzipped, 3, 3)
>>> np.testing.assert_allclose(zipped_tokens.numpy(), hidden_states.numpy(), rtol=1e-05, atol=1e-06)
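The reason `use_mix_precision` helps with many experts is a general floating-point effect: keeping a running sum in a low-precision type loses low-order bits as the sum grows, while accumulating in float32 and rounding once at the end does not. The toy below demonstrates the effect with NumPy's float16 (bfloat16 has no NumPy builtin; this is an illustration of the principle, not the kernel's code path).

```python
import numpy as np

# 4096 copies of 0.1 stored in half precision.
vals = np.full(4096, 0.1, dtype=np.float16)

# Running sum kept in float16: once the sum is large, adding 0.1 rounds
# away entirely and the accumulator stagnates.
low = np.float16(0.0)
for v in vals:
    low = np.float16(low + v)

# Running sum kept in float32 ("mixed precision"), cast down once at the end.
high = np.float32(0.0)
for v in vals:
    high += np.float32(v)
high = np.float16(high)

exact = 4096 * 0.1
err_low = abs(float(low) - exact)    # large: accumulator stalled
err_high = abs(float(high) - exact)  # small: only one final rounding
```

In the MoE setting each output token sums up to `num_experts` weighted bfloat16 rows, so the more experts contribute, the more the float32 accumulation path pays off.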