load_merged_state_dict

paddle.distributed.load_merged_state_dict(path: str, prefix: Optional[str] = None, unique_id: Optional[int] = None, offload: bool = False, aoa_config: Optional[dict[str, list[str]]] = None, safetensors: bool = False) → dict[str, paddle.Tensor] [source]

Load a distributed checkpoint and merge it into an unsharded state_dict.

Parameters
  • path (str) – The directory to load checkpoint files from.

  • prefix (str) – The flat_mapping prefix of the state_dict keys, e.g. 'model'. Default is None.

  • unique_id (int) – The unique id of the checkpoint, used to distinguish between checkpoint versions. Default is None, in which case the maximum id found under path is used, so the newest checkpoint version is loaded.

  • offload (bool) – Whether to offload the checkpoint data from GPU to CPU; set it to True when GPU memory is insufficient. Default is False.

  • aoa_config (dict[str, list[str]]) – AOA config used to transform parameters during loading. Default is None.

  • safetensors (bool) – Whether to use safetensors format. Default is False.
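The offload and safetensors switches above can be combined in one call. A minimal sketch, assuming a checkpoint previously written with dist.save_state_dict to "./checkpoint" and saved with a "model" prefix (both the path and the prefix here are illustrative, not mandated by the API):

```python
# Hedged sketch: load the newest checkpoint version, staging the merged
# tensors on CPU and reading the safetensors format.
load_kwargs = dict(
    prefix="model",     # assumed: keys were flattened with a "model" prefix at save time
    offload=True,       # move checkpoint data from GPU to CPU while merging
    safetensors=True,   # checkpoint files were written in safetensors format
)

try:
    import paddle.distributed as dist
    merged = dist.load_merged_state_dict("./checkpoint", **load_kwargs)
except Exception:
    # paddle is not installed, or no checkpoint exists at this path;
    # the call above is illustrative only.
    merged = None
```

Because unique_id is left at its default of None, the newest checkpoint version under the directory is selected automatically.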

Returns

Merged state_dict.

Return type

dict

Example

>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
>>> w1 = paddle.arange(32).reshape([4, 8])
>>> mesh = dist.ProcessMesh([0, 1])
>>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
>>> state_dict = {"w1": sharded_w1}
>>> dist.save_state_dict(state_dict, ckpt_path) # save sharded checkpoint

>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
>>> unsharded_state_dict = dist.load_merged_state_dict(ckpt_path)  # load unsharded checkpoint
>>> print(f"unsharded_state_dict:{unsharded_state_dict}")
unsharded_state_dict:{'w1':
[[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
 [8 , 9 , 10, 11, 12, 13, 14, 15],
 [16, 17, 18, 19, 20, 21, 22, 23],
 [24, 25, 26, 27, 28, 29, 30, 31]]}