build_sharded_state_dict
paddle.distributed.build_sharded_state_dict(state_dict: dict[str, Tensor], shard_rules: dict[str, int] | None = None, prefix: str = '') -> dict[str, ShardedWeight]
Converts a regular state dict to a sharded state dict based on sharding rules.
Parameters
state_dict (dict[str, Tensor]) – The original state dictionary containing tensors.
shard_rules (dict[str, int] | None) – Dictionary mapping tensor names to the axis along which each tensor is sharded. If None, it is treated as an empty dict (no tensor parallelism).
prefix (str) – Optional prefix to prepend to all tensor keys.
Returns
Dictionary with the same keys as the input (with prefix prepended, if given), where each value is converted to a ShardedWeight according to the sharding rules.
Note
Tensors not in shard_rules will be wrapped as non-sharded ShardedWeights.
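The behavior described above can be illustrated with a small pure-Python sketch. It does not call Paddle and is not Paddle's implementation: `ShardedWeightStub` and `build_sharded_state_dict_sketch` are hypothetical stand-ins, plain lists stand in for tensors, and looking up rules by the un-prefixed name is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ShardedWeightStub:
    """Hypothetical stand-in for ShardedWeight; not Paddle's class."""
    key: str
    value: Any
    shard_axis: Optional[int]  # None means "not sharded"

    @property
    def is_sharded(self) -> bool:
        return self.shard_axis is not None


def build_sharded_state_dict_sketch(state_dict, shard_rules=None, prefix=""):
    # None is treated as an empty dict, i.e. no tensor parallelism.
    rules = shard_rules or {}
    result = {}
    for name, value in state_dict.items():
        # Assumption: rules are looked up by the un-prefixed tensor name.
        axis = rules.get(name)
        full_key = prefix + name
        # Every entry is wrapped; entries without a rule get shard_axis=None.
        result[full_key] = ShardedWeightStub(full_key, value, axis)
    return result


# Plain nested lists stand in for tensors in this sketch.
state = {"linear.weight": [[1.0, 2.0], [3.0, 4.0]], "linear.bias": [0.0, 0.0]}
sharded = build_sharded_state_dict_sketch(state, {"linear.weight": 1}, prefix="model.")
for key, weight in sharded.items():
    print(key, weight.shard_axis, weight.is_sharded)
```

In this sketch, only `linear.weight` has a rule, so it is the only entry marked as sharded; `linear.bias` is still wrapped, mirroring the note about tensors that are not in shard_rules.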