ContextParallel
class paddle.distributed.ContextParallel(backend: str = 'p2p')
Applies context parallel optimizations to the attention layer.
This works for any Layer that calls paddle.nn.functional.scaled_dot_product_attention.
Users can set backend='p2p' or backend='all2all' to choose between context parallel strategies.
backend='p2p' uses the Ring FlashAttention strategy, which segments q/k/v along the sequence dimension and communicates k/v between ranks. backend='all2all' uses the DeepSpeed Ulysses strategy (Paddle's SegmentParallel strategy), which inserts all2all communication before and after the sdpa computation.
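A minimal sketch of constructing a plan object for each backend (purely illustrative; see Examples below for how plan keys map to layer names):
>>> import paddle.distributed as dist
>>> # Ring FlashAttention: q/k/v segmented along the sequence dimension, k/v exchanged between ranks
>>> cp_ring = dist.ContextParallel(backend='p2p')
>>> # DeepSpeed Ulysses (SegmentParallel): all2all inserted before and after the sdpa compute
>>> cp_ulysses = dist.ContextParallel(backend='all2all')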
Parameters
backend (str) – selects the context parallel strategy; currently 'p2p' and 'all2all' are supported.
Examples
>>> import paddle
>>> import paddle.distributed as dist
>>> class SDPALayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...
...     def forward(self, q, k, v):
...         return paddle.nn.functional.scaled_dot_product_attention(q, k, v)
>>>
>>> class AttentionLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.hidden_size = 64
...         self.num_key_value_heads = 10
...         self.head_dim = 64
...         self.sdpa = SDPALayer()
...         self.q = paddle.nn.Linear(
...             self.hidden_size,
...             self.hidden_size,
...             bias_attr=False,
...         )
...         self.k = paddle.nn.Linear(
...             self.hidden_size,
...             self.num_key_value_heads * self.head_dim,
...             bias_attr=False,
...         )
...         self.v = paddle.nn.Linear(
...             self.hidden_size,
...             self.num_key_value_heads * self.head_dim,
...             bias_attr=False,
...         )
...
...     def forward(self, input):
...         q = self.q(input)
...         k = self.k(input)
...         v = self.v(input)
...         return self.sdpa(q, k, v)
>>>
>>> class LlamaLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.attention = AttentionLayer()
...
...     def forward(self, input, label):
...         return self.attention(input)
>>>
>>> class LlamaForCausalLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.llama = LlamaLayer()
...         self.weight = self.create_parameter(shape=[64, 1024])
...         self.loss_func = paddle.nn.CrossEntropyLoss()
...
...     def forward(self, input, label):
...         out = self.llama(input, label)
...         logits = paddle.matmul(out, self.weight)
...         loss = self.loss_func(logits, label)
...         return logits
>>>
>>> layer = LlamaForCausalLayer()
>>> mp_config = {
...     'llama': dist.PrepareContextParallel('p2p'),
...     'sdpa': dist.ContextParallel('p2p'),
... }
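The same plan can select the Ulysses-style strategy by switching the backend string. This variant is a sketch that assumes PrepareContextParallel accepts the same backend values as ContextParallel:
>>> # all2all variant of the plan above; keys name sub-layers of the model,
>>> # mirroring the 'llama' and 'sdpa' attributes in the example
>>> mp_config_all2all = {
...     'llama': dist.PrepareContextParallel('all2all'),
...     'sdpa': dist.ContextParallel('all2all'),
... }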