1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""xla is an experimental library that provides XLA support APIs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22
23from six.moves import xrange  # pylint: disable=redefined-builtin
24
25from tensorflow.compiler.jit.ops import xla_ops
26from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
27from tensorflow.core.framework import attr_value_pb2
28from tensorflow.python.distribute import summary_op_util
29from tensorflow.python.eager import context
30from tensorflow.python.eager import def_function
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import control_flow_ops
34from tensorflow.python.ops import variable_scope
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.util import compat
37from tensorflow.python.util import nest
38from tensorflow.python.util import tf_inspect
39from tensorflow.python.util.compat import collections_abc
40from tensorflow.python.util.deprecation import deprecated
41from tensorflow.python.util.tf_export import tf_export
42
43_XLA_COMPILE_ATTR = '_xla_compile_id'
44_MAX_WARNING_LINES = 5
45
46# Operations that indicate some error in the users graph. For example, XLA
47# computation should not have any Placeholder op.
48_DENYLISTED_OPS = set([
49    'Placeholder',
50])
51
52# XLA doesn't currently support reading of intermediate tensors, thus some ops
53# are not supported.
54_UNSUPPORTED_OPS = set([
55    'AudioSummary',
56    'AudioSummaryV2',
57    'HistogramSummary',
58    'ImageSummary',
59    'MergeSummary',
60    'Print',
61    'ScalarSummary',
62    'TensorSummary',
63    'TensorSummaryV2',
64])
65
66
@tf_export('xla.experimental.compile')
@deprecated(
    None, 'xla.experimental.compile is deprecated. Consider using '
    'tf.function(jit_compile=True)',
    warn_once=True)
def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
  """Builds an operator that compiles and runs `computation` with XLA.

  NOTE: In eager mode, `computation` will have `@tf.function` semantics: the
  compiled computation is built inside a `tf.function` which is then called
  immediately, and its result is returned.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, 'inputs' should be a list of n
      tensors.

      `computation` may return a list of operations and tensors.  Tensors must
      come before operations in the returned list.  The return value of
      `compile` is a list of tensors corresponding to the tensors from the
      output of `computation`.

      All `Operation`s returned from `computation` will be executed when
      evaluating any of the returned output tensors.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
      can be a nested structure containing values that are convertible to
      tensors. Note that passing an N-dimensional list of compatible values
      will result in an N-dimensional list of scalar tensors rather than a
      single Rank-N tensor. If you need different behavior, convert part of
      inputs to tensors with `tf.convert_to_tensor`.

  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include:
      1) None output: a NoOp would be returned which control-depends on
         computation.
      2) Single value output: a list containing the value would be returned.
      3) Operation-only outputs: a NoOp would be returned which
         control-depends on computation.
      TODO(b/121383831): Investigate into removing these special cases.

  Raises:
    TypeError: if `inputs` is not a sequence (e.g. a list or tuple).
    ValueError: if `computation` returns values that are neither Operations nor
      convertible to Tensors, returns Tensors after Operations in a flat
      output, returns Operations inside a non-flat output structure, or if XLA
      compiled computations are nested.
    NotImplementedError: if a non-resource (reference) variable is used inside
      `computation`.

  Known issues:
    When a tf.random operation is built with XLA, the implementation doesn't
      pass the user provided seed to the XLA compiler. As such, the XLA compiler
      generates a random number and uses it as a seed when compiling the
      operation. This implementation causes a violation of the Tensorflow
      defined semantics in two aspects. First, changing the value of the user
      defined seed doesn't change the numbers generated by the operation.
      Second, when a seed is not specified, running the program multiple times
      will generate the same numbers.

  """
  if context.executing_eagerly():
    # Graph construction below requires a graph context, so trace the whole
    # compilation inside a tf.function and run it immediately.
    @def_function.function
    def xla_compile_wrapper():
      return _compile_internal(computation, inputs)

    return xla_compile_wrapper()

  return _compile_internal(computation, inputs)
128
129
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside a
  xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
  a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with Tensorflow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    # Ops whose type is in `_UNSUPPORTED_OPS`; reported (not rejected) later.
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    """Logs a warning listing the unsupported ops collected by `AddOp`.

    At most `_MAX_WARNING_LINES` ops are listed individually; any remaining
    ones are summarized as a count in a second warning.
    """
    if self._unsupported_ops:
      op_str = '\n'.join([
          '  %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op.

    A control input is "internal" when this context appears somewhere in the
    chain of its control flow context and their outer contexts.

    Args:
      op: the Operation whose control inputs are partitioned.

    Returns:
      A pair `(internal_control_inputs, external_control_inputs)`. Only the
      internal ones are re-attached to `op`.
    """
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Create op in XLACompileContext and notifies outer context recursively.

    Tags `op` with the `_xla_compile_id` attribute, prevents feeding/fetching
    it, rewires its inputs and control inputs through this context, and
    finally forwards it to the outer context via `AddInnerOp`.

    Raises:
      NotImplementedError: if any input of `op` has a reference dtype.
      ValueError: if `op` already carries the `_xla_compile_id` attribute
        (i.e. XLA compiled computations would be nested).
    """
    # pylint: disable=protected-access
    if op.type in _DENYLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph. ', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested, (operator '
                       'name: %s)' % op.name)

    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may cause
    # mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      for index in xrange(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts. The
    # loop variable is `ctxt` (not `context`) so it does not shadow the
    # module-level `context` import.
    output_names = [x.name for x in op.outputs]
    ctxt = self
    while ctxt is not None:
      # pylint: disable=protected-access
      ctxt._values.update(output_names)
      ctxt = ctxt._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively.

    Returns:
      The tensor to use inside this context: the value returned by the outer
      context's `AddValue` for `val` if one was recorded, otherwise `val`.
    """
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    """Notifies this context about an op created in a nested inner context.

    `AddOp` already forwards `op` to the outer context (by calling
    `self._outer_context.AddInnerOp(op)`), so it is deliberately not forwarded
    a second time here; doing so would notify the outer context twice for the
    same op.
    """
    self.AddOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the XLACompileContext to
    # be None as the XLACompileContext does not get nested nor does the
    # grad_state outside the XLACompileContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any; otherwise False."""
    while_ctxt = self.GetWhileContext()
    if while_ctxt:
      return while_ctxt.back_prop
    return False
309
310
def _compile_internal(computation, inputs=None):
  """Builds graph operators that compile and symbolically execute computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
      can be a nested structure containing values that are convertible to
      tensors. Note that passing an N-dimensional list of compatible values
      will result in an N-dimensional list of scalar tensors rather than a
      single Rank-N tensor. If you need different behavior, convert part of
      inputs to tensors with `tf.convert_to_tensor`.

  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include: 1) None output 2) Single
    value output 3) Operation-only outputs

  Raises:
    ValueError: If any element in computation outputs is neither an Operation
      nor a value that can be converted to a tensor.
    ValueError: If computation outputs is non-flat and contains any Operations.
    TypeError: If `inputs` is not a sequence (e.g. a list or tuple).
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections_abc.Sequence):
    raise TypeError('inputs must be a list')

  # Flatten inputs.
  flat_inputs = nest.flatten(inputs)
  # Converts inputs to Tensors.
  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  # Named `xla_context` so it does not shadow the module-level `context`
  # import.
  xla_context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    xla_context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)
    try:
      with _disable_summary_context():
        outputs = computation(*computation_inputs)
    finally:
      # Restore the variable scope even if `computation` raised, so the
      # caller's scope does not keep `use_resource=True`.
      vscope.set_use_resource(saved_use_resource)

    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    xla_context.ExitResult(output_tensors)
  finally:
    xla_context.report_unsupported_operations()
    xla_context.Exit()

  # When XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise final
  # outputs would be empty and there is no way to trigger returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wraps the outputs in identity operators that carries control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned non-flat output structure, pack output tensors
  # back into same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors
412
413
def is_flat(outputs):
  """Reports whether `outputs` has no nested structure.

  `outputs` counts as flat when it is `None`, a single value, or a sequence
  none of whose elements is itself a sequence, a mapping, or an attrs-defined
  object. A mapping or an attrs-defined object at the top level is never flat.

  Only sequences, mappings and attrs classes (detected via the
  `__attrs_attrs__` class attribute) are recognized as structures. Any other
  object, including a single user-defined object, is treated as a flat leaf;
  if it cannot be converted to a Tensor, an error is raised later.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    True if `outputs` is flat, False otherwise.
  """

  def _looks_structured(value):
    return (isinstance(value, (collections_abc.Sequence,
                               collections_abc.Mapping)) or
            hasattr(value.__class__, '__attrs_attrs__'))

  # A sequence containing any nested structure is non-flat.
  if isinstance(outputs, collections_abc.Sequence):
    if any(_looks_structured(item) for item in outputs):
      return False

  # A top-level dict or attrs object is non-flat; everything else is flat.
  top_level_structured = (
      isinstance(outputs, collections_abc.Mapping) or
      hasattr(outputs.__class__, '__attrs_attrs__'))
  return not top_level_structured
454
455
def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments.

  Each output Tensor is wrapped in an Identity op created under the Tensor's
  original device (if it has one).

  Args:
    outputs: Output from `computation` inside `xla.compile`: `None`, a single
      value, or a flat sequence of Tensors (or convertible values) followed by
      Operations.

  Returns:
    A pair `(output_tensors, output_operations)`. `output_operations` always
    ends with an extra `no_op` appended here.

  Raises:
    ValueError: if an output is neither an Operation nor convertible to a
      Tensor, or if any Tensor output appears after an Operation.
  """
  # Following code segment is to preserve legacy behavior. Previously we only
  # supported flat outputs and thus for consistency it was nice to convert even
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, make it a tuple.
  if not isinstance(outputs, collections_abc.Sequence):
    outputs = (outputs,)

  # Append `no_op` here so that return value of this function always contains
  # at least one op that can trigger XlaLaunch node. Build a new tuple instead
  # of using `+=`, which would mutate a list returned by the user's
  # computation in place (and fails for Sequence types without `+`).
  outputs = tuple(outputs) + (control_flow_ops.no_op(),)
  try:
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]
  except Exception as e:
    raise ValueError(
        'XLA computation function return values must all either be Operations'
        ' or convertible to Tensors. Got error: "%s"' % str(e))

  # Separates the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  # The outputs are correctly ordered iff the first len(output_tensors) entries
  # are all Tensors. Check with isinstance rather than comparing lists with
  # `!=`: on a mismatch, list comparison calls `__eq__` on the mismatched
  # Tensor/Operation elements, and TensorFlow may overload Tensor equality.
  if any(isinstance(o, ops.Operation) for o in outputs[:len(output_tensors)]):
    raise ValueError(
        'XLA computation function must return zero or more Tensor values '
        'followed by zero or more Operations.')

  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else ''):
      new_output_tensors.append(array_ops.identity(t))

  return new_output_tensors, output_operations
506
507
def _postprocess_non_flat_outputs(outputs):
  """Validates a nested (non-flat) output structure and adds device placement.

  Every leaf of `outputs` is converted to a Tensor and routed through an
  Identity op created under the leaf's original device (if any). This ensures
  that even values which merely pass through the computation get a node
  inside the compile context.

  Args:
    outputs: Output from `computation` inside `xla.compile`, a nested structure
      understood by `nest.flatten`.

  Returns:
    A pair `(tensors, [])`. `tensors` holds the Identity-wrapped Tensors in
    `nest.flatten` order. The second element is always an empty list, because
    Operations are not allowed in non-flat outputs.

  Raises:
    ValueError: if a leaf is an `Operation` or cannot be converted to a Tensor.
  """

  def _wrap_leaf(leaf):
    # Operations cannot be carried inside a nested output structure.
    if isinstance(leaf, ops.Operation):
      raise ValueError(
          'xla.compile does not support Operation as return value in non-flat '
          'output structure. You can set returned Operations as control '
          'dependencies of returned Tensors so Operations are triggered when '
          'Tensors are evaluated. Operation found: "%s"' % leaf.name)

    try:
      tensor = ops.convert_to_tensor(leaf)
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be '
          'Operations or convertible to Tensors. Got error: "%s"' % str(e))

    # Touch the value inside the compile context with an Identity node.
    placement = tensor.device if tensor.device else ''
    with ops.device(placement):
      return array_ops.identity(tensor)

  wrapped = [_wrap_leaf(leaf) for leaf in nest.flatten(outputs)]
  return wrapped, []
541
542
@contextlib.contextmanager
def _disable_summary_context():
  """Context manager under which summary ops are skipped.

  xla.compile() does not support summary ops yet. While this context is
  active, `summary_op_util.skip_summary` is replaced with a function that
  always returns True. The original function is restored on exit, including
  when the body raises. This is a temporary workaround.

  Yields:
    None.
  """

  def _always_skip():
    return True

  saved_skip_fn = summary_op_util.skip_summary
  summary_op_util.skip_summary = _always_skip
  try:
    yield
  finally:
    summary_op_util.skip_summary = saved_skip_fn
561
562
563class _CapturedObject(object):
564  """A placeholder to capture an object."""
565
566  def __init__(self):
567    self._object = None
568
569  def capture(self, o):
570    if self._object:
571      raise RuntimeError(
572          'InternalError: _CapturedObject can capture only once. Please file '
573          'bug.')
574
575    self._object = o
576
577  def get(self):
578    return self._object
579
580
581def _get_scaffold(captured_scaffold_fn):
582  """Retrieves the Scaffold from `captured_scaffold_fn`."""
583  scaffold_fn = captured_scaffold_fn.get()
584
585  if not scaffold_fn:
586    return None
587
588  scaffold = scaffold_fn()
589  if scaffold is None:
590    raise ValueError(
591        'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
592
593  return scaffold
594
595
def check_function_argument_count(func, input_arity, infeed_queue):
  """Checks that `func` can be called with the number of arguments supplied.

  Args:
    func: the Python function that will be called to generate the body of an XLA
      computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue whose
      `number_of_tuple_elements` additional arguments are also passed to
      `func`.

  Returns:
    None if `func` accepts the total number of supplied arguments. Otherwise
    an error string such as "exactly 2 arguments", "at least 1 argument" or
    "at most 3 arguments".
  """

  def _describe(qualifier, count):
    suffix = '' if count == 1 else 's'
    return '%s %d argument%s' % (qualifier, count, suffix)

  supplied = input_arity
  if infeed_queue is not None:
    supplied += infeed_queue.number_of_tuple_elements

  spec = tf_inspect.getargspec(func)
  max_positional = len(spec.args)
  default_count = len(spec.defaults) if spec.defaults is not None else 0
  required = max_positional - default_count
  accepts_varargs = spec.varargs is not None

  # Too few arguments for the required positional parameters.
  if supplied < required:
    if default_count or accepts_varargs:
      return _describe('at least', required)
    return _describe('exactly', max_positional)

  # Too many arguments, and no *args to absorb the extras.
  if not accepts_varargs and supplied > max_positional:
    if default_count:
      return _describe('at most', max_positional)
    return _describe('exactly', max_positional)

  # Either *args accepts any count above the minimum, or the supplied count
  # lies within the accepted range.
  return None
642