fused_feedforward

paddle.incubate.nn.functional.fused_feedforward(x: Tensor, linear1_weight: Tensor, linear2_weight: Tensor, linear1_bias: Tensor | None = None, linear2_bias: Tensor | None = None, ln1_scale: Tensor | None = None, ln1_bias: Tensor | None = None, ln2_scale: Tensor | None = None, ln2_bias: Tensor | None = None, dropout1_rate: float = 0.5, dropout2_rate: float = 0.5, activation: str = 'relu', ln1_epsilon: float = 1e-05, ln2_epsilon: float = 1e-05, pre_layer_norm: bool = False, training: bool = True, mode: _Mode = 'upscale_in_train', ring_id: int = -1, add_residual: bool = True, name: str | None = None ) → Tensor

This is a fused operator that computes the feed-forward layer of the Transformer architecture. This operator only supports running on GPU. Its function is consistent with the following pseudocode:

>>> residual = x
>>> if pre_layer_norm:
...     out = layer_norm1(x)
... else:
...     out = x
>>> out = linear2(dropout1(activation(linear1(out))))
>>> if add_residual:
...     out = residual + dropout2(out)
... else:
...     out = dropout2(out)
>>> if not pre_layer_norm:
...     out = layer_norm2(out)
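The pseudocode above can be sketched as a plain NumPy reference implementation. This is an illustrative sketch, not the fused GPU kernel: it assumes inference mode (dropout is the identity), ReLU activation, and layer_norm with unit scale and zero bias; the function name `feedforward_reference` is hypothetical.

```python
import numpy as np

def feedforward_reference(x, w1, w2, b1=None, b2=None,
                          pre_layer_norm=False, add_residual=True, eps=1e-5):
    """Inference-mode sketch of fused_feedforward (illustrative, not the kernel)."""
    def layer_norm(t):
        # Normalize over the last (d_model) axis; unit scale, zero bias assumed.
        mu = t.mean(axis=-1, keepdims=True)
        var = t.var(axis=-1, keepdims=True)
        return (t - mu) / np.sqrt(var + eps)

    residual = x
    out = layer_norm(x) if pre_layer_norm else x
    out = out @ w1 + (0.0 if b1 is None else b1)   # linear1
    out = np.maximum(out, 0.0)                     # relu (dropout1 is identity here)
    out = out @ w2 + (0.0 if b2 is None else b2)   # linear2 (dropout2 is identity here)
    out = residual + out if add_residual else out
    if not pre_layer_norm:
        out = layer_norm(out)
    return out

x = np.random.randn(1, 8, 8).astype("float32")
w1 = np.random.randn(8, 8).astype("float32")
w2 = np.random.randn(8, 8).astype("float32")
out = feedforward_reference(x, w1, w2)
print(out.shape)  # (1, 8, 8)
```

Note that with the default `pre_layer_norm=False`, the final layer_norm is applied after the residual addition, so each output row is normalized over the last axis.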
Parameters
  • x (Tensor) – The input tensor; a 3-D tensor whose data type can be float16, float32 or float64, with shape [batch_size, sequence_length, d_model].

  • linear1_weight (Tensor) – The weight of the first linear layer. The data type is the same as x; the shape is [d_model, dim_feedforward].

  • linear2_weight (Tensor) – The weight of the second linear layer. The data type is the same as x; the shape is [dim_feedforward, d_model].

  • linear1_bias (Tensor, optional) – The bias of the first linear layer. The data type is the same as x; the shape is [dim_feedforward]. Default None.

  • linear2_bias (Tensor, optional) – The bias of the second linear layer. The data type is the same as x; the shape is [d_model]. Default None.

  • ln1_scale (Tensor, optional) – The weight of the first layer_norm. The data type is float32 or float64; the shape is the same as x. Default None.

  • ln1_bias (Tensor, optional) – The bias of the first layer_norm. The data type is float32 or float64; the shape is [d_model]. Default None.

  • ln2_scale (Tensor, optional) – The weight of the second layer_norm. The data type is float32 or float64; the shape is the same as x. Default None.

  • ln2_bias (Tensor, optional) – The bias of the second layer_norm. The data type is float32 or float64; the shape is [d_model]. Default None.

  • dropout1_rate (float, optional) – The probability of setting units to zero in the first dropout. Default 0.5.

  • dropout2_rate (float, optional) – The probability of setting units to zero in the second dropout. Default 0.5.

  • activation (str, optional) – The activation function. Default "relu".

  • ln1_epsilon (float, optional) – Small float added to the denominator of the first layer_norm to avoid division by zero. Default is 1e-5.

  • ln2_epsilon (float, optional) – Small float added to the denominator of the second layer_norm to avoid division by zero. Default is 1e-5.

  • pre_layer_norm (bool, optional) – Whether to apply layer_norm in the pre-processing stage (True) or the post-processing stage (False). Default False.

  • training (bool, optional) – A flag indicating whether it is in the training phase or not. Default True.

  • mode (str, optional) –

    ['upscale_in_train' (default) | 'downscale_in_infer']

    1. upscale_in_train (default): upscale the output at training time

      • train: out = input * mask / (1.0 - p)

      • inference: out = input

    2. downscale_in_infer: downscale the output at inference time

      • train: out = input * mask

      • inference: out = input * (1.0 - p)
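The two dropout modes can be illustrated directly in NumPy. This is a minimal sketch of the formulas above, assuming a fixed Bernoulli keep mask; both modes keep the expected output consistent between train and inference, differing only in when the 1/(1-p) scaling is applied.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.5                                  # drop probability
x = np.ones(8, dtype="float32")
mask = (rng.random(8) >= p).astype("float32")   # 1 = keep, 0 = drop

# upscale_in_train: scale surviving units up at training time,
# so inference is a plain pass-through.
train_up = x * mask / (1.0 - p)
infer_up = x

# downscale_in_infer: plain masking at training time,
# so inference must scale the output down by (1 - p).
train_down = x * mask
infer_down = x * (1.0 - p)

print(train_up, infer_up, train_down, infer_down)
```

In both modes the expected training output equals the inference output, which is why switching `training` does not change the scale of the result.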

  • ring_id (int, optional) – The ring id for distributed forward under tensor model parallelism; only NCCL is supported. Default is -1, which means tensor parallelism is not used.

  • add_residual (bool, optional) – Whether to add the residual connection at the end. Default is True.

  • name (str, optional) – Name for the operation (optional, default is None). For more information, please refer to api_guide_Name.

Returns

The output Tensor; its data type and shape are the same as x.

Return type

Tensor

Examples

>>> import paddle
>>> paddle.device.set_device('gpu')
>>> import paddle.incubate.nn.functional as F

>>> x = paddle.randn(shape=(1, 8, 8), dtype="float32")
>>> linear1_weight = paddle.randn(shape=(8, 8), dtype="float32")
>>> linear2_weight = paddle.randn(shape=(8, 8), dtype="float32")
>>> out = F.fused_feedforward(x, linear1_weight, linear2_weight)
>>> print(out.shape)
[1, 8, 8]