case

paddle.static.nn.case(pred_fn_pairs, default=None, name=None)
Api_attr

Static Graph

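This operator works like an if-elif-...-else chain. It checks the pairs in order and returns the result of the callable paired with the first pred that is True. default, or the last callable in pred_fn_pairs when default is None, acts as the final else branch.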
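The selection rule, including the fallback described under Returns, can be modeled in plain Python. This is a minimal sketch, not Paddle's implementation: `case_reference` is a hypothetical helper, preds are ordinary Python bools rather than boolean Tensors, and the chosen callable is simply called instead of being built into a static graph.

```python
from typing import Callable, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def case_reference(
    pred_fn_pairs: Sequence[Tuple[bool, Callable[[], T]]],
    default: Optional[Callable[[], T]] = None,
) -> T:
    """Hypothetical plain-Python model of the selection rule of case.

    Assumption: preds are Python bools, not boolean Tensors in a static graph.
    """
    for pred, fn in pred_fn_pairs:
        if pred:               # first True pred wins, like `if` / `elif`
            return fn()
    if default is not None:
        return default()       # the `else` branch
    # No pred is True and default is None: fall back to the last callable.
    return pred_fn_pairs[-1][1]()


# Same pred values as the example below: pred_1=True, pred_2=False, pred_3=False.
print(case_reference([(True, lambda: "fn_1"), (False, lambda: "fn_2")],
                     default=lambda: "fn_3"))                    # fn_1
print(case_reference([(False, lambda: "fn_2"), (False, lambda: "fn_3")]))  # fn_3
```

The second call shows why default=None is not an error: the last pair's callable doubles as the else branch, so its own pred never changes the outcome.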

Parameters
  • pred_fn_pairs (list|tuple) – A list or tuple of (pred, fn) pairs. pred is a boolean Tensor with exactly one element (shape [] or shape [1]), and fn is a callable. All callables should return the same structure of Tensors.

  • default (callable, optional) – Callable that returns a structure of Tensors. It is called when no pred in pred_fn_pairs is True. The default value is None.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to api_guide_Name.

Returns

The Tensors returned by the callable of the first pair whose pred is True. If no pred in pred_fn_pairs is True, the Tensors returned by default, or, when default is None, the Tensors returned by the last callable in pred_fn_pairs.

Return type

Tensor|list(Tensor)

Raises
  • TypeError – If the type of pred_fn_pairs is not list or tuple.

  • TypeError – If the type of elements in pred_fn_pairs is not tuple.

  • TypeError – If the size of tuples in pred_fn_pairs is not 2.

  • TypeError – If the first element of 2-tuple in pred_fn_pairs is not a Tensor.

  • TypeError – If the second element of 2-tuple in pred_fn_pairs is not callable.

  • TypeError – If default is not None but it is not callable.
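The TypeError conditions listed above amount to a handful of type checks on the arguments. The sketch below mirrors them in plain Python under stated assumptions: `check_case_args` and `FakeTensor` are hypothetical names, `FakeTensor` stands in for a Paddle Tensor, and the error messages are illustrative rather than Paddle's own.

```python
from typing import Any, Optional


class FakeTensor:
    """Hypothetical stand-in for a boolean Paddle Tensor (illustration only)."""


def check_case_args(pred_fn_pairs: Any, default: Optional[Any] = None,
                    tensor_type: type = FakeTensor) -> None:
    """Raise TypeError in each situation listed under Raises (sketch only)."""
    if not isinstance(pred_fn_pairs, (list, tuple)):
        raise TypeError("pred_fn_pairs must be a list or tuple")
    for pair in pred_fn_pairs:
        if not isinstance(pair, tuple):
            raise TypeError("each element of pred_fn_pairs must be a tuple")
        if len(pair) != 2:
            raise TypeError("each tuple in pred_fn_pairs must have size 2")
        pred, fn = pair
        if not isinstance(pred, tensor_type):
            raise TypeError("the first element of each pair must be a Tensor")
        if not callable(fn):
            raise TypeError("the second element of each pair must be callable")
    if default is not None and not callable(default):
        raise TypeError("default must be None or a callable")


t = FakeTensor()
check_case_args([(t, lambda: 1)])  # valid arguments: returns silently
try:
    check_case_args([(t, lambda: 1)], default=42)
except TypeError as err:
    print(err)  # default must be None or a callable
```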

Examples

>>> import paddle
>>> paddle.enable_static()

>>> def fn_1():
...     return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)

>>> def fn_2():
...     return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)

>>> def fn_3():
...     return paddle.full(shape=[3], dtype='int32', fill_value=3)

>>> main_program = paddle.static.default_main_program()
>>> startup_program = paddle.static.default_startup_program()

>>> with paddle.static.program_guard(main_program, startup_program):
...     x = paddle.full(shape=[1], dtype='float32', fill_value=0.3)
...     y = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
...     z = paddle.full(shape=[1], dtype='float32', fill_value=0.2)
...
...     pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3
...     pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
...     pred_3 = paddle.equal(x, y)      # false: 0.3 == 0.1
...
...     # Call fn_1 because pred_1 is True
...     out_1 = paddle.static.nn.case(
...         pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
...
...     # Argument default is None and no pred in pred_fn_pairs is True,
...     # so fn_3 is called because it is the last callable in pred_fn_pairs.
...     out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
...
...     exe = paddle.static.Executor(paddle.CPUPlace())
...     res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
...     print(res_1, res_2)
[[1. 1.]] [3 3 3]