shard_weight
paddle.distributed.shard_weight(key: str, weight: Tensor, axis: int, group: Group) → ShardedWeight
Creates a ShardedWeight by splitting the input tensor along a specified axis.
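As a rough illustration of what "splitting along an axis" means, here is a sketch that uses plain Python lists in place of Paddle tensors. The helper `split_along_axis` is hypothetical and not part of Paddle. It only handles 2-D inputs and assumes the axis length divides evenly by the number of ranks. How `shard_weight` itself treats uneven sizes is not stated on this page.

```python
# Hypothetical helper, NOT part of the Paddle API: illustrates splitting a
# 2-D "tensor" (a list of equal-length rows) into equal shards along one axis.
def split_along_axis(matrix, axis, world_size):
    """Return a list of `world_size` shards of `matrix` split along `axis`."""
    if axis not in (0, 1):
        raise ValueError("this sketch only handles 2-D inputs (axis 0 or 1)")
    n = len(matrix) if axis == 0 else len(matrix[0])
    # Assumption: the axis length divides evenly across ranks.
    if n % world_size != 0:
        raise ValueError("axis length must be divisible by world_size")
    step = n // world_size
    if axis == 0:
        # Each shard is a contiguous block of rows.
        return [matrix[r * step:(r + 1) * step] for r in range(world_size)]
    # axis == 1: each shard takes the same contiguous block of columns from every row.
    return [[row[r * step:(r + 1) * step] for row in matrix] for r in range(world_size)]


weight = [[0, 1, 2, 3],
          [4, 5, 6, 7]]
shards = split_along_axis(weight, axis=1, world_size=2)
print(shards[0])  # [[0, 1], [4, 5]]
print(shards[1])  # [[2, 3], [6, 7]]
```

Splitting along axis 1 keeps every row but gives each rank a different block of columns, while splitting along axis 0 gives each rank a different block of rows.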
Parameters
key – Unique identifier for the tensor.
weight – The input tensor to be sharded.
axis – The axis along which to shard the tensor.
group – The process group used for distributed communication.
Returns
A ShardedWeight representing the local portion of the global tensor.
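Because the result covers only this process's slice, it has to record where that slice sits in the global tensor. Here is a minimal sketch of that bookkeeping. It assumes the local slice is chosen by the process's rank within `group` and is described by its offset into the global shape. `LocalShard` and `make_local_shard` are hypothetical names, and their fields are not claimed to match those of Paddle's `ShardedWeight`.

```python
from dataclasses import dataclass


@dataclass
class LocalShard:
    """Hypothetical stand-in for the metadata a sharded weight might carry."""
    key: str
    local_data: list
    global_shape: tuple
    global_offset: tuple


def make_local_shard(key, matrix, axis, rank, world_size):
    """Keep only `rank`'s slice of a 2-D list, plus its position in the whole."""
    if axis not in (0, 1):
        raise ValueError("this sketch only handles 2-D inputs (axis 0 or 1)")
    global_shape = (len(matrix), len(matrix[0]))
    n = global_shape[axis]
    # Assumption: the axis length divides evenly across ranks.
    if n % world_size != 0:
        raise ValueError("axis length must be divisible by world_size")
    step = n // world_size
    start = rank * step
    if axis == 0:
        local = matrix[start:start + step]            # block of rows
        offset = (start, 0)
    else:
        local = [row[start:start + step] for row in matrix]  # block of columns
        offset = (0, start)
    return LocalShard(key, local, global_shape, offset)


weight = [[0, 1], [2, 3], [4, 5], [6, 7]]
shard = make_local_shard("linear.weight", weight, axis=0, rank=1, world_size=2)
print(shard.local_data)     # [[4, 5], [6, 7]]
print(shard.global_offset)  # (2, 0)
```

Recording the global shape and offset alongside the local data lets each rank's piece be placed back into the full tensor later without any rank holding the whole weight.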