ShardedWeight

class paddle.distributed.ShardedWeight(key: str, local_tensor: Tensor, local_shape: tuple[int, ...], global_shape: tuple[int, ...], global_offset: tuple[int, ...], is_flattened: bool = False, flattened_range: slice | None = None)

Represents a local shard of a distributed tensor parameter.

Parameters
  • key (str) – The name of the parameter.

  • local_tensor (Tensor) – The local shard of the parameter.

  • local_shape (tuple[int, ...]) – The shape of the local shard.

  • global_shape (tuple[int, ...]) – The global logical shape of the parameter.

  • global_offset (tuple[int, ...]) – The offset of the local shard within the global parameter, one entry per dimension.

  • is_flattened (bool, optional) – Whether the parameter has been flattened (used in sharding_v2 scenarios). Default is False.

  • flattened_range (slice, optional) – If the parameter is flattened, the index range of the actual local shard within local_tensor. Default is None.
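
Example

The relationship between local_shape, global_shape, and global_offset can be illustrated with a minimal sketch. It does not import Paddle: ShardMeta is a hypothetical stand-in that mirrors only the metadata fields of ShardedWeight (no local_tensor), and split_along_axis and global_slices are illustrative helpers, not part of the Paddle API. The even split along a single axis is an assumption made for the example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ShardMeta:
    """Hypothetical stand-in mirroring ShardedWeight's metadata fields."""
    key: str
    local_shape: tuple
    global_shape: tuple
    global_offset: tuple
    is_flattened: bool = False
    flattened_range: Optional[slice] = None


def split_along_axis(key, global_shape, axis, num_ranks):
    """Describe an even split of a parameter along one axis (assumed layout)."""
    assert global_shape[axis] % num_ranks == 0, "sketch assumes an even split"
    chunk = global_shape[axis] // num_ranks
    shards = []
    for rank in range(num_ranks):
        # Only the split axis shrinks; every other dimension keeps its full size.
        local_shape = tuple(
            chunk if d == axis else n for d, n in enumerate(global_shape)
        )
        # The shard starts rank * chunk elements into the split axis.
        global_offset = tuple(
            rank * chunk if d == axis else 0 for d in range(len(global_shape))
        )
        shards.append(ShardMeta(key, local_shape, global_shape, global_offset))
    return shards


def global_slices(meta):
    """Region of the global parameter covered by this shard."""
    return tuple(
        slice(start, start + size)
        for start, size in zip(meta.global_offset, meta.local_shape)
    )


shards = split_along_axis("linear.weight", (8, 4), axis=0, num_ranks=2)
for rank, meta in enumerate(shards):
    print(rank, meta.local_shape, meta.global_offset, global_slices(meta))
```

In a real program, the same key, local_shape, global_shape, and global_offset values would be passed to ShardedWeight together with the rank's local_tensor.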