SequenceParallelEnable
对标识 Layer 进行序列并行策略。
备注
被标识的 Layer 的输入张量的形状应该为 [b, s, h]。
代码示例
>>> import paddle
>>> import paddle.distributed as dist
>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> mp_config = {
...     'fc1': dist.SequenceParallelEnable()
... }