sdpa_kernel

paddle.nn.attention.sdpa_kernel(backends: list[paddle.nn.attention.sdpa.SDPBackend] | paddle.nn.attention.sdpa.SDPBackend, set_priority: bool = False) [source]

Context manager to select which backend to use for scaled dot product attention.

Warning

This function is beta and subject to change.

Parameters
  • backends (Union[list[SDPBackend], SDPBackend]) – A backend or list of backends for scaled dot product attention.

  • set_priority (bool, optional) – Whether the ordering of the backends is interpreted as their priority order, with the first backend tried first. Default: False.

Example

>>> import paddle
>>> from paddle.nn.functional import scaled_dot_product_attention
>>> from paddle.nn.attention import SDPBackend, sdpa_kernel
>>> # Create dummy tensors
>>> query = paddle.rand(shape=[2, 4, 8, 16])
>>> key = paddle.rand(shape=[2, 4, 8, 16])
>>> value = paddle.rand(shape=[2, 4, 8, 16])
>>> # Example 1: Only enable math backend
>>> with sdpa_kernel(SDPBackend.MATH):
...     out = scaled_dot_product_attention(query, key, value)
>>> print(out.shape)
[2, 4, 8, 16]
>>> # Example 2: Enable multiple backends
>>> with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
...     out = scaled_dot_product_attention(query, key, value)
>>> print(out.shape)
[2, 4, 8, 16]
>>> # Example 3: Set priority order for multiple backends
>>> with sdpa_kernel(
...     [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION],
...     set_priority=True,
... ):
...     out = scaled_dot_product_attention(query, key, value)
>>> print(out.shape)
[2, 4, 8, 16]
>>> 
>>> # Example 4: Flash attention (skipped due to environment requirements)
>>> with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
...     out = scaled_dot_product_attention(query, key, value)
>>> 

Use this context manager to restrict which backends scaled dot product attention may dispatch to. Upon exiting the context manager, the previous state of the backend flags is restored.
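
The sketch below illustrates the restore-on-exit behavior by nesting two contexts; it reuses the query, key, and value tensors from the examples above, and assumes the MATH and EFFICIENT_ATTENTION backends are available in the running environment.

>>> # Outer context: restrict SDPA to the math backend.
>>> with sdpa_kernel(SDPBackend.MATH):
...     # Inner context: temporarily widen the selection to two backends.
...     with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
...         out_inner = scaled_dot_product_attention(query, key, value)
...     # Leaving the inner context restores the outer (math-only) selection.
...     out_outer = scaled_dot_product_attention(query, key, value)
>>> # Leaving the outer context restores whatever backends were enabled before.
>>> out = scaled_dot_product_attention(query, key, value)
>>> print(out.shape)
[2, 4, 8, 16]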