PrepareContextParallel

class paddle.distributed.PrepareContextParallel(backend: str = 'p2p') [source]

Prepare the input for context parallel optimizations.

This applies to a top-level Layer such as a whole-llama Layer, i.e. the first Layer called in the network.

Users can set backend='p2p' or backend='all2all' to select different context parallel strategies.

backend='p2p' uses the Ring FlashAttention strategy, which splits the input along the sequence dimension in a load-balanced way before the whole-llama Layer. backend='all2all' uses the DeepSpeed Ulysses strategy (Paddle's SegmentParallel strategy), which splits the input along the sequence dimension before the whole-llama Layer.
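
For intuition, the sketch below (a conceptual illustration, not Paddle's internal implementation) contrasts how the two backends could partition token indices across cp_degree ranks; the zigzag layout shown for 'p2p' is one common balanced split and is assumed here only for illustration.

>>> # Conceptual sketch only: how each backend could split the sequence dimension.
>>> def ulysses_split(seq_len, cp_degree, rank):
...     # 'all2all' (DeepSpeed Ulysses / SegmentParallel): one contiguous chunk per rank.
...     chunk = seq_len // cp_degree
...     return list(range(rank * chunk, (rank + 1) * chunk))
>>>
>>> def ring_balanced_split(seq_len, cp_degree, rank):
...     # 'p2p' (Ring FlashAttention): a balanced split, shown here as the common
...     # zigzag scheme (an assumption for illustration): each rank holds one early
...     # chunk and one late chunk so causal-attention work is evenly distributed.
...     chunk = seq_len // (2 * cp_degree)
...     head = list(range(rank * chunk, (rank + 1) * chunk))
...     tail_start = (2 * cp_degree - 1 - rank) * chunk
...     return head + list(range(tail_start, tail_start + chunk))
>>>
>>> ulysses_split(seq_len=8, cp_degree=2, rank=0)
[0, 1, 2, 3]
>>> ring_balanced_split(seq_len=8, cp_degree=2, rank=0)
[0, 1, 6, 7]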

Parameters

backend (str) – the context parallel strategy to use; currently 'p2p' and 'all2all' are supported.

Examples


>>> import paddle
>>> import paddle.distributed as dist
>>> class SDPALayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...
...     def forward(self, q, k, v):
...         return paddle.nn.functional.scaled_dot_product_attention(q, k, v)
>>>
>>> class AttentionLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.hidden_size = 64
...         self.num_key_value_heads = 10
...         self.head_dim = 64
...         self.sdpa = SDPALayer()
...         self.q = paddle.nn.Linear(
...             self.hidden_size,
...             self.hidden_size,
...             bias_attr=False,
...         )
...         self.k = paddle.nn.Linear(
...             self.hidden_size,
...             self.num_key_value_heads * self.head_dim,
...             bias_attr=False,
...         )
...         self.v = paddle.nn.Linear(
...             self.hidden_size,
...             self.num_key_value_heads * self.head_dim,
...             bias_attr=False,
...         )
...
...     def forward(self, input):
...         q = self.q(input)
...         k = self.k(input)
...         v = self.v(input)
...         return self.sdpa(q, k, v)
>>>
>>> class LlamaLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.attention = AttentionLayer()
...
...     def forward(self, input, label):
...         return self.attention(input)
>>>
>>> class LlamaForCausalLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.llama = LlamaLayer()
...         self.weight = self.create_parameter(shape=[64, 1024])
...         self.loss_func = paddle.nn.CrossEntropyLoss()
...
...     def forward(self, input, label):
...         out = self.llama(input, label)
...         logits = paddle.matmul(out, self.weight)
...         loss = self.loss_func(logits, label)
...         return loss
>>>
>>> layer = LlamaForCausalLayer()
>>> mp_config = {
...     'llama': dist.PrepareContextParallel('p2p'),
...     'sdpa': dist.ContextParallel('p2p'),
... }
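
Note that the keys in mp_config correspond to sublayer names in the model: 'llama' is the LlamaLayer attribute of LlamaForCausalLayer, so PrepareContextParallel splits its input along the sequence dimension before the whole model runs, while 'sdpa' is the SDPALayer inside AttentionLayer, so ContextParallel parallelizes the attention computation itself with the 'p2p' strategy.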