1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions called by the generated code to execute an eager-mode op."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from google.protobuf import text_format
24from tensorflow.core.framework import tensor_pb2
25from tensorflow.python import pywrap_tensorflow
26from tensorflow.python.eager import core
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.util import compat
31
32
def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Run one TensorFlow operation eagerly through the C API.

  Args:
    op_name: Registered name of the op to run (see REGISTER_OP in C++).
    num_outputs: How many outputs to fetch. Passed explicitly rather than
      inferred, for performance.
    inputs: List of op inputs; each is a Tensor or a value the Tensor
      constructor can turn into one.
    attrs: Flat tuple alternating attr names (strings) and attr values.
    ctx: The value of context.context().
    name: Optional user-supplied op name, appended to error messages.

  Returns:
    List of output Tensor objects (empty when the op has no outputs).

  Raises:
    The exception built by `core._status_to_exception` when execution returns
    a non-OK status; its message gets " name: <name>" appended when `name` is
    given. `core._SymbolicException` when the C call raises TypeError and at
    least one input is a Keras symbolic tensor; any other TypeError is
    re-raised unchanged.
  """
  # pylint: disable=protected-access
  device = ctx.device_name
  try:
    outputs = pywrap_tensorflow.TFE_Py_Execute(
        ctx._handle, device, op_name, inputs, attrs, num_outputs)
  except core._NotOkStatusException as status_err:
    if name is None:
      msg = status_err.message
    else:
      msg = status_err.message + " name: " + name
    six.raise_from(core._status_to_exception(status_err.code, msg), None)
  except TypeError as type_err:
    has_symbolic_input = any(ops._is_keras_symbolic_tensor(x) for x in inputs)
    if has_symbolic_input:
      raise core._SymbolicException
    # Re-raise the bound exception explicitly (not bare `raise`): the any()
    # call above runs user code that could replace the "current" exception.
    raise type_err
  # pylint: enable=protected-access
  return outputs
73
74
def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Drop-in replacement for `execute` that also runs execution callbacks.

  Runs the op through `quick_execute`, then calls every entry of
  `ctx.post_execution_callbacks`, in iteration order, as
  `callback(op_name, inputs, attrs, outputs, name)`.

  Args:
    op_name, num_outputs, inputs, attrs, ctx, name: Same as `quick_execute`.

  Returns:
    The output list produced by `quick_execute`.
  """
  outputs = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  callbacks = ctx.post_execution_callbacks
  for fn in callbacks:
    fn(op_name, inputs, attrs, outputs, name)
  return outputs
82
83
# Module-level entry point for running an op. Bound to `quick_execute` by
# default; `execute_with_callbacks` has the same signature and can be
# monkey-patched in here to also run `ctx.post_execution_callbacks`.
execute = quick_execute
85
86
def record_gradient(unused_op_name, unused_inputs, unused_attrs, unused_results,
                    unused_name):
  """No-op stand-in for gradient recording; ignores all arguments.

  Import backprop if you want gradients recorded.
  """
  del unused_op_name, unused_inputs, unused_attrs, unused_results, unused_name
  return None
91
92
def make_float(v, arg_name):
  """Return `v` as a Python float after checking it is a real number.

  Args:
    v: Candidate attr value; accepted when it is an instance of
      `compat.real_types`.
    arg_name: Attr name, used only in the error message.

  Returns:
    `float(v)`.

  Raises:
    TypeError: If `v` is not an instance of `compat.real_types`.
  """
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))
98
99
def make_int(v, arg_name):
  """Return `int(v)`, rejecting strings and values `int()` cannot convert.

  Args:
    v: Candidate attr value.
    arg_name: Attr name, used only in the error message.

  Returns:
    `int(v)`.

  Raises:
    TypeError: If `v` is a string, or if `int(v)` raises ValueError or
      TypeError.
  """

  def _bad_int():
    # Built lazily so repr(v) only runs on the failure path.
    return TypeError("Expected int for argument '%s' not %s." %
                     (arg_name, repr(v)))

  if isinstance(v, six.string_types):
    # Strings are refused outright even though int("3") would succeed.
    raise _bad_int()
  try:
    return int(v)
  except (ValueError, TypeError):
    raise _bad_int()
109
110
def make_str(v, arg_name):
  """Return `v` as bytes after checking it is a bytes or text string.

  Args:
    v: Candidate attr value; accepted when it is an instance of
      `compat.bytes_or_text_types`.
    arg_name: Attr name, used only in the error message.

  Returns:
    `compat.as_bytes(v)`, so unicode input is converted to bytes.

  Raises:
    TypeError: If `v` is not a bytes or text string.
  """
  if isinstance(v, compat.bytes_or_text_types):
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))
116
117
def make_bool(v, arg_name):
  """Return `v` unchanged if it is a Python bool.

  Args:
    v: Candidate attr value. Only real `bool` instances are accepted, so
      ints such as 0/1 are rejected.
    arg_name: Attr name, used only in the error message.

  Returns:
    `v` itself.

  Raises:
    TypeError: If `v` is not a bool.
  """
  if isinstance(v, bool):
    return v
  raise TypeError("Expected bool for argument '%s' not %s." %
                  (arg_name, repr(v)))
123
124
def make_type(v, arg_name):
  """Return the datatype enum of the base dtype of `v`.

  Args:
    v: Anything `dtypes.as_dtype` accepts.
    arg_name: Attr name, used only in the error message.

  Returns:
    `dtypes.as_dtype(v).base_dtype.as_datatype_enum`.

  Raises:
    TypeError: If `dtypes.as_dtype(v)` raises TypeError.
  """
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return base.as_datatype_enum
133
134
def make_shape(v, arg_name):
  """Convert `v` into a list of dimension sizes.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape
      (anything `tensor_shape.as_shape` accepts).
    arg_name: String, used only in error messages.

  Returns:
    None if the rank is unknown; otherwise `shape.as_list()`, a list of ints
    with None wherever a dimension is unknown.

  Raises:
    TypeError, ValueError: Re-raised with `arg_name` in the message when
      `tensor_shape.as_shape(v)` raises the same type.
  """
  try:
    converted = tensor_shape.as_shape(v)
  except TypeError as err:
    raise TypeError("Error converting %s to a TensorShape: %s." %
                    (arg_name, err))
  except ValueError as err:
    raise ValueError("Error converting %s to a TensorShape: %s." %
                     (arg_name, err))
  return None if converted.ndims is None else converted.as_list()
155
156
def make_tensor(v, arg_name):
  """Return `v` as a TensorProto.

  Args:
    v: A `tensor_pb2.TensorProto` (returned as-is) or a string containing a
      text-format TensorProto (parsed with `text_format.Merge`).
    arg_name: Attr name, used only in the error message.

  Returns:
    A `tensor_pb2.TensorProto`.

  Raises:
    TypeError: If `v` is neither a TensorProto nor a string.
  """
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  if not isinstance(v, six.string_types):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'." %
        (repr(v), arg_name))
  proto = tensor_pb2.TensorProto()
  text_format.Merge(v, proto)
  return proto
168
169
def args_to_matching_eager(l, ctx, default_dtype=None):
  """Convert sequence `l` to eager Tensors that all share one dtype.

  The common dtype is chosen as follows:
    * If every element of `l` is already an EagerTensor, `l` is returned
      unchanged together with the first element's datatype enum.
    * Otherwise, if some element is an EagerTensor, its dtype is used and
      every element is converted to it.
    * Otherwise, the first element is converted (preferring `default_dtype`)
      and its resulting dtype is used for the remaining elements.

  Args:
    l: A sequence of EagerTensors and/or values convertible to Tensors.
    ctx: The eager context, forwarded to `ops.internal_convert_to_tensor`.
    default_dtype: Optional preferred dtype used when no element of `l`
      already fixes one. When `l` is empty, this dtype's enum is returned.

  Returns:
    A pair `(datatype_enum, tensors)`.

  Raises:
    IndexError: If `l` is empty and `default_dtype` is None.
  """
  if not l and default_dtype is not None:
    # Nothing to convert, so the dtype can only come from the default.
    # Without this guard, the all-EagerTensor fast path below runs vacuously
    # on an empty sequence and fails on l[0].
    return dtypes.as_dtype(default_dtype).as_datatype_enum, []

  EagerTensor = ops.EagerTensor  # pylint: disable=invalid-name
  for x in l:
    if not isinstance(x, EagerTensor):
      break
  else:  # note: intentional for-else
    return l[0]._datatype_enum(), l  # pylint: disable=protected-access
  # TODO(josh11b): Could we do a better job if we also passed in the
  # allowed dtypes when that was known?

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(t, EagerTensor):
      dtype = t.dtype
      break

  internal_convert_to_tensor = ops.internal_convert_to_tensor
  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.
    ret = []
    for t in l:
      ret.append(internal_convert_to_tensor(
          t, dtype,
          preferred_dtype=default_dtype,
          ctx=ctx,
          accept_symbolic_tensors=False))
      if dtype is None:
        dtype = ret[-1].dtype
  else:
    ret = [internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]

  return dtype.as_datatype_enum, ret
205
206
def convert_to_mixed_eager_tensors(values, ctx):
  """Convert each value to an eager Tensor independently (dtypes may differ).

  Args:
    values: Iterable of values convertible with `ops.internal_convert_to_tensor`.
    ctx: The eager context, forwarded to the conversion.

  Returns:
    A pair `(type_enums, tensors)` where `type_enums[i]` is the datatype enum
    of `tensors[i]`.
  """
  tensors = [ops.internal_convert_to_tensor(t, ctx=ctx) for t in values]
  # pylint: disable=protected-access
  type_enums = [tensor._datatype_enum() for tensor in tensors]
  # pylint: enable=protected-access
  return type_enums, tensors
211
212
def args_to_mixed_eager_tensors(lists, ctx):
  """Converts a list of same-length lists of values to eager tensors.

  Values at the same position across all lists are converted to one shared
  dtype. That dtype is taken from the first EagerTensor found at that
  position (scanning `lists` in order). If there is none, the value from
  `lists[0]` is converted first and its resulting dtype is used for the rest.

  Args:
    lists: A sequence of at least two sequences, all of the same length.
    ctx: The eager context, forwarded to `ops.internal_convert_to_tensor`.

  Returns:
    A pair `(types, lists_ret)`. `types[i]` is the datatype enum used for
    position `i`, and `lists_ret[j]` holds the converted tensors for
    `lists[j]`.

  Raises:
    ValueError: If the lists do not all have the same length.
  """
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  for l in lists[1:]:
    if len(l) != len(lists[0]):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)."
          % (len(lists[0]), len(l), lists[0], l))

  # One output list per input list, including lists[0].
  lists_ret = [[] for _ in lists]

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(len(lists[0])):
    dtype = None
    # If any list has a Tensor at this position, use that dtype.
    for l in lists:
      if isinstance(l[i], ops.EagerTensor):
        dtype = l[i].dtype
        break
    if dtype is None:
      # Convert the first one unconstrained and use its dtype for the rest.
      first = ops.internal_convert_to_tensor(lists[0][i], ctx=ctx)
      lists_ret[0].append(first)
      dtype = first.dtype
      start = 1
    else:
      # Convert everything to the found dtype.
      start = 0
    for j in range(start, len(lists)):
      lists_ret[j].append(
          ops.internal_convert_to_tensor(lists[j][i], dtype=dtype, ctx=ctx))
    types.append(dtype.as_datatype_enum)
  return types, lists_ret
249