1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""XLA utility functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.util import tf_inspect
24
25
def is_flat(outputs):
  """Checks if outputs is a flat structure.

  The following structures and values are considered flat:
    1) None
    2) A single object
    3) A list or tuple of Tensors/Operations

  The only structures that this function understands are sequences and
  dictionaries. E.g. this means that if outputs contains a single
  user-defined Object, it is considered to be flat. Errors are raised later on
  if that Object cannot be converted to a Tensor.

  Note that `str` is itself a sequence, so a non-empty string, or any sequence
  containing a string, is reported as non-flat.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A boolean indicating whether outputs is flat.
  """
  # `collections.Sequence` was a deprecated alias that Python 3.10 removed.
  # The ABCs live in `collections.abc`. Fall back to `collections` only on
  # Python 2, where `collections.abc` does not exist.
  try:
    from collections import abc as collections_abc
  except ImportError:
    collections_abc = collections

  # A dict is always treated as non-flat.
  if isinstance(outputs, dict):
    return False

  # A list or tuple is non-flat if any element is itself a sequence or a dict.
  if isinstance(outputs, collections_abc.Sequence):
    return not any(
        isinstance(o, (collections_abc.Sequence, dict)) for o in outputs)

  # Getting here means outputs is a single non-structured value (incl. None).
  return True
59
60
def check_function_argument_count(func, input_arity, infeed_queue):
  """Checks that `func` can be called with the number of inputs supplied.

  The supplied count is `input_arity`, plus the number of tuple elements
  provided by `infeed_queue` when one is given. That count is compared
  against the positional parameters reported by `tf_inspect.getargspec`,
  accounting for defaulted parameters and `*args`.

  Args:
    func: the Python function that will be called to generate the body of an
      XLA computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply
      additional arguments to the function.

  Returns:
    None if `func` can be called with the supplied number of arguments.
    Otherwise, a string describing how many arguments `func` accepts, e.g.
    'exactly 2 arguments', 'at least 1 argument' or 'at most 3 arguments'.
  """
  def describe(qualifier, count):
    plural_suffix = '' if count == 1 else 's'
    return '%s %d argument%s' % (qualifier, count, plural_suffix)

  supplied = input_arity
  if infeed_queue is not None:
    # The infeed contributes one argument per tuple element.
    supplied += infeed_queue.number_of_tuple_elements

  spec = tf_inspect.getargspec(func)
  max_positional = len(spec.args)
  num_defaults = len(spec.defaults or ())
  min_positional = max_positional - num_defaults
  accepts_varargs = spec.varargs is not None

  if supplied < min_positional:
    # Too few inputs to fill the parameters that have no default value.
    # With no defaults and no *args, min_positional == max_positional.
    if num_defaults == 0 and not accepts_varargs:
      return describe('exactly', max_positional)
    return describe('at least', min_positional)

  if not accepts_varargs and supplied > max_positional:
    # Too many inputs, and there is no *args to absorb the surplus.
    qualifier = 'exactly' if num_defaults == 0 else 'at most'
    return describe(qualifier, max_positional)

  # Either *args accepts any surplus beyond the minimum, or the supplied count
  # falls within [min_positional, max_positional].
  return None
107