local_map

paddle.distributed.local_map ( func: Callable[..., Any], out_placements: list[list[dist.Placement]], in_placements: list[list[dist.Placement]] | None = None, process_mesh: ProcessMesh | None = None, reshard_inputs: bool = False ) → Callable[..., Any] [source]

The local_map API allows users to pass dist_tensors to a function that is written to operate on plain paddle.Tensor values. It works by extracting the local shards of the input dist_tensors, calling the function on them, and wrapping the outputs as dist_tensors according to out_placements.
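A minimal sketch of the call pattern, assuming a 1-D mesh over two ranks (local_sum is an illustrative helper, not part of the API; the full runnable example appears in the Example section below):

>>> import paddle
>>> import paddle.distributed as dist
>>> def local_sum(x):        # written for a plain paddle.Tensor
...     return paddle.sum(x)
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> wrapped = dist.local_map(
...     local_sum,
...     out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
...     in_placements=[[dist.Shard(0)]],
...     process_mesh=mesh,
... )
>>> # wrapped(x_dist) runs local_sum on each rank's local shard and rewraps
>>> # the per-rank results as a single dist_tensor with Partial(sum) placement.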

Parameters
  • func (Callable) – The function to be applied on each local shard of dist_tensors.

  • out_placements (list[list[dist.Placement]]) – The desired placements for each output tensor. Must be a list where each element is a list of Placement objects specifying the distribution strategy for that output tensor. The length of the outer list must match the number of outputs from func. For non-tensor outputs, the corresponding placement must be None (see the sketch after this parameter list). When there are no dist_tensor inputs, process_mesh must be specified in order to use non-None placements.

  • in_placements (Optional[list[list[dist.Placement]]], optional) – The required placements for each input tensor. If specified, must be a list where each element is a list of Placement objects defining the distribution strategy for that input tensor. The length of the outer list must match the number of input tensors. Default: None

  • process_mesh (ProcessMesh, optional) – The process mesh that all dist_tensors are placed on. If not specified, this will be inferred from the input dist_tensors’ process mesh. local_map requires all dist_tensors to be placed on the same process mesh. Must be specified when there are no dist_tensor inputs but out_placements contains non-None values. Default: None

  • reshard_inputs (bool, optional) – Whether to reshard the input dist_tensors when their placements differ from the required in_placements. If this value is False and some dist_tensor input has a different placement, an exception will be raised, as illustrated in the sketch after this parameter list. Default: False.
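A hedged sketch of the out_placements and reshard_inputs behavior described above. shard_stats is an illustrative helper, not part of the API, and the None slot for the non-tensor output follows the out_placements description:

>>> # Sketch: shard_stats returns a tensor and a plain Python int. Per the
>>> # out_placements description, the int's slot is None and the value is
>>> # returned as-is rather than being wrapped as a dist_tensor.
>>> import paddle
>>> import paddle.distributed as dist
>>> def shard_stats(x):                          # illustrative helper
...     return x * 2, int(x.shape[0])            # (tensor, non-tensor) outputs
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> wrapped_stats = dist.local_map(
...     shard_stats,
...     out_placements=[[dist.Shard(0)], None],  # None for the int output
...     in_placements=[[dist.Shard(0)]],
...     process_mesh=mesh,
...     reshard_inputs=True,  # reshard inputs whose placements differ from in_placements
... )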

Returns

A function that applies func to the local shards of the input dist_tensors and returns dist_tensors constructed according to out_placements; non-tensor outputs are returned unchanged.

Return type

Callable

Example

>>> from __future__ import annotations
>>> import paddle
>>> import paddle.distributed as dist
>>> from paddle import Tensor
>>> from paddle.distributed import ProcessMesh

>>> def custom_function(x):
...     mask = paddle.zeros_like(x)
...     if dist.get_rank() == 0:
...         mask[1:3] = 1
...     else:
...         mask[4:7] = 1
...     x = x * mask
...     mask_sum = paddle.sum(x)
...     mask_sum = mask_sum / mask.sum()
...     return mask_sum

>>> 
>>> dist.init_parallel_env()
>>> mesh = ProcessMesh([0, 1], dim_names=["x"])
>>> local_input = paddle.arange(0, 10, dtype="float32")
>>> local_input = local_input + dist.get_rank()
>>> input_dist = dist.auto_parallel.api.dtensor_from_local(
...     local_input, mesh, [dist.Shard(0)]
... )
>>> wrapped_func = dist.local_map(
...     custom_function,
...     out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
...     in_placements=[[dist.Shard(0)]],
...     process_mesh=mesh
... )
>>> output_dist = wrapped_func(input_dist)

>>> local_value = output_dist._local_value()
>>> gathered_values: list[Tensor] = []
>>> dist.all_gather(gathered_values, local_value)

>>> print(f"[Rank 0] local_value={gathered_values[0].item()}")
[Rank 0] local_value=1.5
>>> print(f"[Rank 1] local_value={gathered_values[1].item()}")
[Rank 1] local_value=6.0
>>> print(f"global_value (distributed)={output_dist.item()}")
global_value (distributed)=7.5

>>> # This case needs to be executed in a multi-card environment
>>> # export CUDA_VISIBLE_DEVICES=0,1
>>> # python -m paddle.distributed.launch {test_case}.py
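
The printed values follow from the per-rank shards: on rank 0 the local input is [0, 1, ..., 9] and the mask keeps indices 1:3, so the result is (1 + 2) / 2 = 1.5; on rank 1 the local input is [1, 2, ..., 10] and the mask keeps indices 4:7, so the result is (5 + 6 + 7) / 3 = 6.0. Because the output placement is Partial(kRedSum), output_dist.item() reduces the two local values to 1.5 + 6.0 = 7.5.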