1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""smart_cond and related utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.client import pywrap_tf_session as c_api
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.ops import control_flow_ops
25
26
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Pick between `true_fn()` and `false_fn()` based on `pred`.

  When `smart_constant_value(pred)` yields a non-None value, only the matching
  callable is invoked and its result is returned directly. Otherwise the
  decision is handed to `control_flow_ops.cond`, which receives both callables.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  # Both branches are validated before `pred` is inspected; `true_fn` first.
  if not callable(true_fn):
    raise TypeError("`true_fn` must be callable.")
  if not callable(false_fn):
    raise TypeError("`false_fn` must be callable.")

  static_pred = smart_constant_value(pred)

  # No static value available: let `cond` select the branch.
  if static_pred is None:
    return control_flow_ops.cond(
        pred, true_fn=true_fn, false_fn=false_fn, name=name)

  # Static value known: call only the selected branch.
  chosen_fn = true_fn if static_pred else false_fn
  return chosen_fn()
60
61
def smart_constant_value(pred):
  """Return the static value of `pred`, or None if `pred` had a dynamic value.

  Args:
    pred: A scalar, either a Python bool, the integers 1/0, or a tensor.

  Returns:
    For a Python bool or 1/0: the corresponding `True` or `False`.
    For a Tensor: the value reported by `tensor_util.constant_value`, falling
    back to `TF_TryEvaluateConstant_wrapper`; None if neither yields a value.

  Raises:
    TypeError: If `pred` is not a Tensor, a Python bool, or 1 or 0. This
      includes unhashable arguments such as lists.
  """
  # Python bools are handled first. They also compare equal to 1/0, so the
  # membership test further down would otherwise always shadow this check.
  if isinstance(pred, bool):
    return pred

  if isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
    # TODO(skyewm): consider folding this into tensor_util.constant_value.
    if pred_value is None:
      # pylint: disable=protected-access
      pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
                                                        pred._as_tf_output())
      # pylint: enable=protected-access
    return pred_value

  # Accept 1/0 as valid boolean values. Set membership hashes `pred`, so an
  # unhashable argument raises TypeError here; treat that as "not 0/1" so the
  # caller gets the descriptive error below instead of "unhashable type".
  try:
    is_zero_or_one = pred in {0, 1}
  except TypeError:
    is_zero_or_one = False
  if is_zero_or_one:
    return bool(pred)

  raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. "
                  "Found instead: %s" % type(pred))
91
92
def smart_case(pred_fn_pairs, default=None, exclusive=False, name="smart_case"):
  """Variant of tf.case that tries to resolve predicates statically.

  Forwards to `control_flow_ops._case_helper`, passing `smart_cond` and
  `allow_python_preds=True`. A predicate in `pred_fn_pairs` that is a bool or
  has a constant value causes its callable to be called or omitted depending
  on that value; otherwise this behaves like tf.case.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
                   callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  # pylint: disable=protected-access
  case_helper = control_flow_ops._case_helper
  # pylint: enable=protected-access
  return case_helper(
      smart_cond, pred_fn_pairs, default, exclusive, name,
      allow_python_preds=True)
120