SplitPoint
class paddle.distributed.SplitPoint(value)
An enumeration that marks where the model is split relative to a specified layer.
BEGINNING: split the model before the specified layer.
END: split the model after the specified layer.
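The pure-Python sketch below illustrates the BEGINNING/END semantics on an ordered list of layer names. It does not import Paddle: the `SplitPoint` enum here is a hypothetical stand-in that only reuses the member names, and `split_stages` is a hypothetical helper, not part of `paddle.distributed`. How Paddle itself consumes these markers may differ.

```python
from enum import Enum, auto


# Hypothetical stand-in for paddle.distributed.SplitPoint; the member
# names match the docs, but the values are arbitrary (assumption).
class SplitPoint(Enum):
    BEGINNING = auto()
    END = auto()


def split_stages(layer_names, split_spec):
    """Hypothetical helper: partition ordered layer names into stages.

    split_spec maps a layer name to a SplitPoint:
      BEGINNING -> a new stage starts at that layer
      END       -> a new stage starts right after that layer
    """
    unknown = set(split_spec) - set(layer_names)
    if unknown:
        raise ValueError(f"unknown layers in split_spec: {sorted(unknown)}")

    stages = [[]]
    for name in layer_names:
        point = split_spec.get(name)
        # Split before this layer, unless the current stage is still empty.
        if point is SplitPoint.BEGINNING and stages[-1]:
            stages.append([])
        stages[-1].append(name)
        # Split after this layer: the next layer opens a new stage.
        if point is SplitPoint.END:
            stages.append([])
    # Drop the empty trailing stage left by END on the last layer.
    return [stage for stage in stages if stage]


if __name__ == "__main__":
    layers = ["a", "b", "c", "d"]
    print(split_stages(layers, {"b": SplitPoint.END}))        # [['a', 'b'], ['c', 'd']]
    print(split_stages(layers, {"b": SplitPoint.BEGINNING}))  # [['a'], ['b', 'c', 'd']]
```

Marking a layer with END and marking the following layer with BEGINNING produce the same boundary; for the two-layer MLP in the example, `{'fc1': END}` and `{'fc2': BEGINNING}` both yield the stages `[['fc1'], ['fc2']]` in this sketch.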
Examples
>>> import paddle
>>> import paddle.distributed as dist
>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))
>>>
>>> layer = MLP()
>>> # Split the model after fc1, so fc1 and fc2 end up on different sides of the split.
>>> pp_config = {
...     'fc1': dist.SplitPoint.END
... }