shard_weight

paddle.distributed.shard_weight(key: str, weight: Tensor, axis: int, group: Group) → ShardedWeight

Creates a ShardedWeight by splitting the input tensor along a specified axis.
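To picture what "splitting along an axis" means, the sketch below uses plain NumPy arrays in place of Paddle tensors. It assumes that the size along `axis` divides evenly by the number of ranks and that rank r owns the r-th contiguous chunk. Both are illustrative assumptions, not confirmed details of Paddle's implementation. `local_shard` is a hypothetical helper, not part of the Paddle API.

```python
import numpy as np


def local_shard(global_tensor: np.ndarray, axis: int, rank: int, world_size: int):
    """Hypothetical helper: return the chunk owned by `rank` and its offset.

    Assumes an even split and that rank r owns the r-th contiguous chunk
    (an illustrative assumption, not Paddle's actual code).
    """
    size = global_tensor.shape[axis]
    if size % world_size != 0:
        raise ValueError(f"dim {size} on axis {axis} is not divisible by {world_size}")
    chunk = size // world_size
    # Per-dimension offset of this chunk inside the global tensor.
    offset = [0] * global_tensor.ndim
    offset[axis] = rank * chunk
    # np.split with an integer count returns `world_size` equal-sized pieces.
    local = np.split(global_tensor, world_size, axis=axis)[rank]
    return local, tuple(offset)


# A 4x6 "global" weight split along axis 1 across 2 ranks.
w = np.arange(24).reshape(4, 6)
for r in range(2):
    local, off = local_shard(w, axis=1, rank=r, world_size=2)
    print(r, local.shape, off)
```

Each rank ends up holding a 4x3 slice. The offset records where that slice starts in the global tensor.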

Parameters
  • key – Unique identifier for the tensor.

  • weight – The input tensor to be sharded.

  • axis – The axis along which to shard the tensor.

  • group – The process group used for distributed communication.
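The following sketch shows one plausible way the four parameters could fit together. Everything in it is hypothetical. `FakeGroup` stands in for a process group and carries only this rank's index and the group size. `ShardedWeightSketch` is a simple record, not Paddle's `ShardedWeight` class. Normalizing a negative `axis` and requiring an even split are assumptions made for the example.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class FakeGroup:
    """Hypothetical stand-in for a process group: this rank and the group size."""
    rank: int
    nranks: int


@dataclass
class ShardedWeightSketch:
    """Hypothetical shard record; not Paddle's ShardedWeight class."""
    key: str
    local_tensor: np.ndarray
    global_shape: tuple
    global_offset: tuple


def shard_weight_sketch(key: str, weight: np.ndarray, axis: int,
                        group: FakeGroup) -> ShardedWeightSketch:
    # Assumption: a negative axis such as -1 counts from the last dimension.
    axis = axis % weight.ndim
    if weight.shape[axis] % group.nranks != 0:
        raise ValueError("axis size must divide evenly by the group size")
    chunk = weight.shape[axis] // group.nranks
    start = group.rank * chunk
    # Slice out this rank's contiguous chunk along `axis`.
    index = [slice(None)] * weight.ndim
    index[axis] = slice(start, start + chunk)
    offset = [0] * weight.ndim
    offset[axis] = start
    return ShardedWeightSketch(
        key=key,
        local_tensor=weight[tuple(index)],
        global_shape=tuple(weight.shape),
        global_offset=tuple(offset),
    )


w = np.ones((8, 4), dtype=np.float32)
sw = shard_weight_sketch("linear.weight", w, axis=0, group=FakeGroup(rank=1, nranks=2))
print(sw.key, sw.local_tensor.shape, sw.global_shape, sw.global_offset)
```

In this sketch, `key` only labels the shard. The `group` determines which slice the calling rank keeps.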

Returns

A ShardedWeight representing the local portion of the global tensor.
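Because each returned value is only the local portion, the global tensor is the combination of every rank's portion. The NumPy sketch below illustrates that relationship under the same assumptions as before: an even, contiguous split in rank order. `split_global` and `gather_global` are hypothetical helpers, and real Paddle code would rely on collective communication over `group` rather than a local `np.concatenate`.

```python
import numpy as np


def split_global(global_tensor: np.ndarray, axis: int, world_size: int):
    # One contiguous local portion per rank (assumes even divisibility).
    return np.split(global_tensor, world_size, axis=axis)


def gather_global(local_portions, axis: int) -> np.ndarray:
    # Concatenating the portions in rank order along the shard axis
    # rebuilds the original global tensor.
    return np.concatenate(local_portions, axis=axis)


w = np.arange(2 * 3 * 4).reshape(2, 3, 4)
parts = split_global(w, axis=2, world_size=4)
print([p.shape for p in parts])
print(np.array_equal(gather_global(parts, axis=2), w))
```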