build_sharded_state_dict

paddle.distributed.build_sharded_state_dict(state_dict: dict[str, Tensor], shard_rules: dict[str, int] | None = None, prefix: str = '') -> dict[str, ShardedWeight]

Converts a regular state dict to a sharded state dict based on sharding rules.

Parameters
  • state_dict – The original state dictionary containing tensors

  • shard_rules – Dictionary mapping tensor names to the axis along which each tensor is sharded. If None, it is treated as an empty dict, so no tensor is sharded (no tensor parallelism).

  • prefix – Optional prefix to prepend to all tensor keys

Returns

Dictionary with one entry per input tensor, keyed by the tensor name (with prefix prepended, if given). Each value is a ShardedWeight built according to the sharding rules.

Note

Tensors not in shard_rules will be wrapped as non-sharded ShardedWeights.
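
Example

The following is a minimal pure-Python sketch of the conversion described above, not Paddle's implementation. SketchShardedWeight and sketch_build_sharded_state_dict are hypothetical stand-ins, shape tuples take the place of real tensors so the snippet runs without Paddle, and it assumes that shard_rules is keyed by the un-prefixed tensor name.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchShardedWeight:
    """Hypothetical stand-in for ShardedWeight; not Paddle's real class."""
    key: str
    local_shape: tuple
    shard_axis: Optional[int]  # None means the tensor is not sharded


def sketch_build_sharded_state_dict(state_dict, shard_rules=None, prefix=""):
    # None is treated as an empty rule set (no tensor parallelism).
    shard_rules = shard_rules or {}
    result = {}
    for name, shape in state_dict.items():
        # The prefix is prepended to every tensor key.
        full_key = prefix + name
        # Assumption: rules are looked up by the un-prefixed name.
        # Names missing from shard_rules get a non-sharded wrapper.
        axis = shard_rules.get(name)
        result[full_key] = SketchShardedWeight(full_key, shape, axis)
    return result


# Shape tuples stand in for tensors in this sketch.
state = {"linear.weight": (1024, 256), "linear.bias": (256,)}
sharded = sketch_build_sharded_state_dict(
    state, shard_rules={"linear.weight": 1}, prefix="model."
)

for key, weight in sharded.items():
    print(key, weight.shard_axis)
```

In this sketch only linear.weight carries a sharding axis; linear.bias is still wrapped, but as a non-sharded entry, which mirrors the note above.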