switch_case

paddle.static.nn.switch_case(branch_index, branch_fns, default=None, name=None)
Api_attr: Static Graph

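This operator works like a C++ switch/case statement: it runs the callable in branch_fns whose index equals branch_index, and falls back to default (or, if default is None, to the callable with the largest index) when no index matches.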

Parameters
  • branch_index (Tensor) – A Tensor whose numel should be 1 (shape [] or shape [1]) to specify which branch to execute. The data type is int32, int64 or uint8.

  • branch_fns (dict|list|tuple) – If it's a list or tuple, its elements can be (int, callable) pairs, or plain callables, in which case each callable's position in the list or tuple is used as its index. If it's a dict, each key is a Python integer index and each value is a callable. All callables must return the same structure of Tensors.

  • default (callable, optional) – Callable that returns a structure of Tensors. It is called when no index in branch_fns matches branch_index. The default value is None.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to api_guide_Name.
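The three accepted forms of branch_fns can be modeled in plain Python to show which index each callable ends up with. The sketch below uses a hypothetical helper, `normalize_branch_fns`, that is not part of the Paddle API; it assumes a list or tuple is either all (int, callable) pairs or all plain callables.

```python
from typing import Callable, Dict, List, Sequence, Tuple, Union

# A branch_fns argument in any of the three documented forms.
BranchFns = Union[Dict[int, Callable], Sequence[Callable], Sequence[Tuple[int, Callable]]]


def normalize_branch_fns(branch_fns: BranchFns) -> List[Tuple[int, Callable]]:
    """Return branch_fns as a list of (index, callable) pairs.

    Hypothetical helper for illustration only; not part of the Paddle API.
    """
    if isinstance(branch_fns, dict):
        # dict form: each key is the branch index
        return list(branch_fns.items())
    if all(callable(fn) for fn in branch_fns):
        # list/tuple of plain callables: the position in the sequence is the index
        return list(enumerate(branch_fns))
    # list/tuple of (int, callable) pairs: use the explicit indices
    return [(index, fn) for index, fn in branch_fns]


def f():
    return "f"


def g():
    return "g"


print([i for i, _ in normalize_branch_fns({1: f, 2: g})])      # [1, 2]
print([i for i, _ in normalize_branch_fns([(1, f), (2, g)])])  # [1, 2]
print([i for i, _ in normalize_branch_fns((f, g))])            # [0, 1]
```

Only the tuple of plain callables gets positional indices; the dict and pair forms keep the integers the caller supplied.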

Returns

The Tensors returned by the callable in branch_fns whose index matches branch_index. If no index matches, the Tensors returned by default when default is not None; otherwise, the Tensors returned by the callable with the largest index in branch_fns.
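The return rule above can be expressed as a small pure-Python model. `select_branch` is a hypothetical name; it takes a plain int instead of a Tensor, assumes branch_fns is already in dict form, and runs eagerly, so it does not model how Paddle builds the static graph.

```python
from typing import Callable, Dict, Optional


def select_branch(branch_index: int,
                  branch_fns: Dict[int, Callable],
                  default: Optional[Callable] = None):
    """Pure-Python model of the documented switch_case return rule."""
    if branch_index in branch_fns:
        # 1. a matching index wins
        return branch_fns[branch_index]()
    if default is not None:
        # 2. no match: fall back to default when it is given
        return default()
    # 3. no match and no default: use the callable with the largest index
    return branch_fns[max(branch_fns)]()


fns = {0: lambda: "fn_1", 4: lambda: "fn_2", 7: lambda: "fn_3"}
print(select_branch(4, fns))                       # fn_2 (index 4 matches)
print(select_branch(2, fns, default=lambda: "d"))  # d (no match, default given)
print(select_branch(2, fns))                       # fn_3 (no match, largest index is 7)
```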

Return type

Tensor|list(Tensor)

Raises
  • TypeError – If the type of branch_index is not Tensor.

  • TypeError – If the data type of branch_index is not int32, int64 or uint8.

  • TypeError – If the type of branch_fns is not dict, list or tuple.

  • TypeError – If the elements of branch_fns are not 2-tuples.

  • TypeError – If the first element of a 2-tuple in branch_fns is not an integer.

  • ValueError – If the first elements of the 2-tuples in branch_fns are not unique.

  • TypeError – If the second element of 2-tuple in branch_fns is not callable.

  • TypeError – If default is not None but it is not callable.
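The checks on branch_fns and default listed above can be sketched as a plain-Python validator. `check_switch_case_args` is hypothetical; the order of the checks and the error messages are assumptions, and the branch_index checks are left out because they need a Paddle Tensor.

```python
from typing import Callable, Optional


def check_switch_case_args(branch_fns, default: Optional[Callable] = None) -> None:
    """Hypothetical validator for branch_fns and default (not the Paddle implementation)."""
    if not isinstance(branch_fns, (dict, list, tuple)):
        raise TypeError("branch_fns must be a dict, list or tuple")
    if isinstance(branch_fns, dict):
        pairs = list(branch_fns.items())
    elif all(callable(fn) for fn in branch_fns):
        # plain callables are indexed by position
        pairs = list(enumerate(branch_fns))
    else:
        pairs = list(branch_fns)

    seen = set()
    for pair in pairs:
        if not isinstance(pair, tuple) or len(pair) != 2:
            raise TypeError("each element of branch_fns must be a 2-tuple")
        index, fn = pair
        if not isinstance(index, int):
            raise TypeError("the first element of each 2-tuple must be an integer")
        if index in seen:
            raise ValueError(f"duplicate branch index {index}")
        seen.add(index)
        if not callable(fn):
            raise TypeError("the second element of each 2-tuple must be callable")
    if default is not None and not callable(default):
        raise TypeError("default must be callable")


def fn():
    return 0


for bad in ([(1, fn), (1, fn)], [(1.5, fn)], [(1, 2)], "not a container"):
    try:
        check_switch_case_args(bad)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, err)
```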

Examples

>>> import paddle
>>> paddle.enable_static()

>>> def fn_1():
...    return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)

>>> def fn_2():
...    return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)

>>> def fn_3():
...    return paddle.full(shape=[3], dtype='int32', fill_value=3)

>>> startup_program = paddle.static.default_startup_program()
>>> main_program = paddle.static.default_main_program()
>>> with paddle.static.program_guard(main_program, startup_program):
...    index_1 = paddle.full(shape=[1], dtype='int32', fill_value=1)
...    index_2 = paddle.full(shape=[1], dtype='int32', fill_value=2)
...
...    out_1 = paddle.static.nn.switch_case(
...        branch_index=index_1,
...        branch_fns={1: fn_1, 2: fn_2},
...        default=fn_3)
...
...    out_2 = paddle.static.nn.switch_case(
...        branch_index=index_2,
...        branch_fns=[(1, fn_1), (2, fn_2)],
...        default=fn_3)
...
...    # default is None and no index matches, so fn_3 is called: it has the largest index, 7.
...    out_3 = paddle.static.nn.switch_case(
...        branch_index=index_2,
...        branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
...
...    exe = paddle.static.Executor(paddle.CPUPlace())
...    res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])

>>> print(res_1)
[[1. 1.]]

>>> print(res_2)
[[2 2]
 [2 2]]

>>> print(res_3)
[3 3 3]