1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Define tflite op hints (intrinsic operations).
16
17This essentially allows defining a TensorFlow API for tflite operations in
18Python with hints on how they are represented in TensorFlow Lite. This basically
19is a form of tflite intrinsic. It wraps a subpart of a TensorFlow execution
20graph and is useful for LSTMs and other complicated TensorFlow constructions
21that are difficult to pattern match in TOCO, but are represented by a single
22accelerated tflite op.
23
24Example:
25  def tflite_cool_activation(input):
26    # A cool activation function.
27    custom = tf.lite.OpHint("cool_activation")
28    input, = custom.add_inputs(input)
29    output = tf.sigmoid(input) * input
30    output, = custom.add_outputs(output)
31    return output
32
33  image = tf.placeholder(tf.float32, (1, 16, 16, 1))
34  output = tf.identity(tflite_cool_activation(image))
35
36  session = tf.Session()
37
38  graphdef_to_convert = tf.lite.convert_op_hints_to_stubs(session)
39  tflite_graph = tf.lite.toco_convert(graphdef_to_convert, [image], [output])
40  with open("/tmp/graph.fb", "wb") as fp:
41    fp.write(tflite_graph)
42
43How does it work?:
44
45OpHint is a helper that you use when defining a vanilla python function.
46It allows you to wrap arguments with tf.identities with some custom attributes.
47These attributes allow you to find the original block of ops that was created.
48For example, if you use cool_activation above you essentially get:
49
50a_input = tf.identity()
51result = tf.multiply(tf.sigmoid(a_input), a_input)
52output = tf.identity()
53
54a_input, output are identities that have parameters representing
55what argument they are, what the name of the function they should turn into
56in tf lite as well as a guid that uniquely identifies a particular invocation.
57
Once you have built your whole TensorFlow graph, you can run it and train it
as usual. After you have done that, you need to convert the graph into a form
in which these identity-wrapped subgraphs are replaced with stub ops. These
stub ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later.
63"""
64
65# TODO(aselle): Make this use generic graph transformations.
66# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
67
68from __future__ import absolute_import
69from __future__ import division
70from __future__ import print_function
71
72import collections as _collections
73import copy as _copy
74import json as _json
75import uuid as _uuid
76import six as _six
77
78from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
79from tensorflow.core.framework import graph_pb2 as _graph_pb2
80from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
81from tensorflow.python.framework import ops as _ops
82# TODO(aselle): publicize these apis if we continue to use these.
83from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
84from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
85from tensorflow.python.ops import array_ops as _array_ops
86from tensorflow.python.util import compat as _compat
87from tensorflow.python.util.all_util import remove_undocumented
88from tensorflow.python.util.tf_export import tf_export as _tf_export
89
90
@_tf_export("lite.OpHint")
class OpHint(object):
  """A class that helps build tflite function invocations.

  It allows you to take a bunch of TensorFlow ops and annotate the construction
  such that toco knows how to convert it to tflite. This embeds a pseudo
  function in a TensorFlow graph. This allows embedding high-level API usage
  information in a lower level TensorFlow implementation so that an alternative
  implementation can be substituted later.

  Essentially, any "input" into this pseudo op is fed into an identity, and
  attributes are added to that input before being used by the constituent ops
  that make up the pseudo op. A similar process is done to any output that
  is to be exported from the current op.

  """
  # TODO(aselle): When TensorFlow functions functionality works for arbitrary
  # constructs, this mechanism can be retired and changed to use python defun's.

  # Attr constants that are used for representation in the GraphDef. These
  # will be used on every Identity op that is involved in a total OpHint.

  # Name of the OpHint function (cosmetic).
  FUNCTION_NAME_ATTR = "_tflite_function_name"
  # UUID of the function (each OpHint gets a new uuid).
  FUNCTION_UUID_ATTR = "_tflite_function_uuid"
  # The input index of the input (or nothing if it is an output).
  FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
  # The output index of the output (or nothing if it is an input).
  FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
  # An index that orders aggregate arguments. Aggregate arguments are ones
  # that are separate but will be fused horizontally. For example a static LSTM
  # has a lstm cell for each time step. Each one has a separate opHint, but a
  # fused SequentialLSTM will treat this as a single tensor.
  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
  # The way in which multiple parts of the aggregate argument will be joined
  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
  # On fused OpHint stub, the order of inputs that the final LSTM call will
  # have. What this means is that the TensorFlow order might be
  # "foo", "bar", "stuff" and you might want the TF lite op order to be
  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
  # attribute to [2, 0, 1, -1].
  TFLITE_INPUT_INDICES = "_tflite_input_indices"
  # OpHint level.
  FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
  # Ophint internal mapping, this is for high level Ophint only.
  # This basically contains three kinds of mapping:
  #   1) How parental ophinted inputs map to the first child ophinted inputs;
  #   2) How internal children nodes are connected;
  #   3) How parental ophinted outputs map to the last child ophinted outputs.
  CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"

  # Types of aggregations
  #  stack: stacks all ophints with matching tags. i.e. for a static rnn.
  #   specifically, this is good for an input or output to a static rnn cell.
  AGGREGATE_STACK = "stack"
  # first: only takes the first output (one with lowest sort index)
  # of matching tags. This is good for the input state to an RNN.
  AGGREGATE_FIRST = "first"
  # aggregation last takes only the last tag (one with highest sort index).
  # This is good for an output value on the last stack item of a
  # static rnn.
  AGGREGATE_LAST = "last"

  class OpHintArgumentTracker(object):
    """Conceptually tracks indices of arguments of "OpHint functions".

    The inputs and outputs of these functions each use their own instance
    of this class so they can have independent numbering.
    """

    def __init__(self,
                 function_name,
                 unique_function_id,
                 node_name_prefix,
                 attr_name,
                 level=1,
                 children_inputs_mappings=None):
      """Initialize ophint argument.

      Args:
        function_name: Name of the function that this tracks arguments for.
        unique_function_id: UUID of function that this tracks arguments for.
        node_name_prefix: How identities that are created are named.
        attr_name: Name of attribute to use to store the index for this hint.
          i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
        level: Hierarchical level of the Ophint node, a number.
        children_inputs_mappings: Inputs/Outputs mapping for children hints.
      """

      # The global index is the argument index of the op. This is in contrast
      # to the sort index which is the sequence number of a particular instance
      # of a given global index. For example, you may have called add hint
      # twice with the tag "foo". Then the global index will be 0 for both
      # and the sort index will be 0 for the first added and 1 for the second.
      self._function_name = function_name
      self._unique_function_id = unique_function_id
      self._next_global_index = 0  # The absolute global index
      self._used_global_indices = set()
      self._tag_to_global_index = {}  # The argument index a given tag maps to
      self._tag_to_next_sort_index = {}  # The current index for each tag
      self._node_name_prefix = node_name_prefix
      self._attr_name = attr_name
      self._level = level
      self._children_inputs_mappings = children_inputs_mappings

    def _get_new_global_index(self, index_override):
      """Return the next unused argument index in order or use an override.

      Args:
        index_override: An index to use instead of the next available or None
          to use the next available.

      Returns:
        A valid global_index to use for the next hint argument.

      Raises:
        ValueError: If the index_override is already used by another hint.
      """
      if index_override is None:
        global_index = self._next_global_index
      else:
        if index_override in self._used_global_indices:
          # Include the offending index; the placeholder was never filled.
          raise ValueError(
              "Index %d was already used by another call to add" %
              index_override)
        global_index = index_override
      # Make next_global_index valid: skip past any indices already claimed,
      # including ones claimed through overrides.
      self._used_global_indices.add(global_index)
      while self._next_global_index in self._used_global_indices:
        self._next_global_index += 1
      return global_index

    def add(self, arg, tag=None, name=None, aggregate=None,
            index_override=None):
      """Return a wrapped tensor of an input tensor as an argument.

      Args:
        arg: A TensorFlow tensor that should be considered an argument.
        tag: String tag to identify arguments that should be packed.
        name: Name of argument. This is included in the Identity hint op names.
        aggregate: Strategy to aggregate.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
          Note, aggregate is only valid if tag is specified.
        index_override: Specify what input/output index should this be in the
          final stub. i.e. add(arg0, index_override=1);
          add(arg1, index_override=0) will make the final stub be
          stub_func(inputs=[arg1, arg0], outputs=[]) rather than the default
          call order based ordering.

      Returns:
        A tensor representing the wrapped argument.

      Raises:
        ValueError: When indices are not consistent, or when `tag` and
          `aggregate` are not specified together.
      """

      # Find the appropriate index
      if tag is None:
        if aggregate is not None:
          raise ValueError("You must specify `tag` if using aggregate.")
        global_index = self._get_new_global_index(index_override)
        sort_index = None
      else:
        if aggregate is None:
          raise ValueError("You must specify `aggregate` if using tag.")
        if tag not in self._tag_to_global_index:
          self._tag_to_global_index[tag] = (
              self._get_new_global_index(index_override))
          self._tag_to_next_sort_index[tag] = 0
        elif (index_override is not None and
              index_override != self._tag_to_global_index[tag]):
          # `is not None` so that an override of 0 is still checked.
          raise ValueError(
              "Tag %r was called with two indices %r and %r" %
              (tag, index_override, self._tag_to_global_index[tag]))
        global_index = self._tag_to_global_index[tag]
        sort_index = self._tag_to_next_sort_index[tag]
        self._tag_to_next_sort_index[tag] += 1

      uuid = self._unique_function_id
      name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
                                    uuid, global_index, sort_index, name)

      identity_op = _array_ops.identity(arg, name=name)

      # pylint: disable=protected-access
      identity_op.op._set_attr(
          OpHint.FUNCTION_NAME_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._function_name)))
      identity_op.op._set_attr(
          OpHint.FUNCTION_UUID_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._unique_function_id)))
      identity_op.op._set_attr(
          self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
      identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
                               _attr_value_pb2.AttrValue(i=self._level))
      if self._children_inputs_mappings:
        # The mapping dict is serialized as JSON into a string attr.
        identity_op.op._set_attr(
            OpHint.CHILDREN_INPUTS_MAPPINGS,
            _attr_value_pb2.AttrValue(
                s=_compat.as_bytes(_json.dumps(
                    self._children_inputs_mappings))))

      if sort_index is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_SORT_INDEX_ATTR,
            _attr_value_pb2.AttrValue(i=sort_index))
      if aggregate is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_AGGREGATE_ATTR,
            _attr_value_pb2.AttrValue(s=_compat.as_bytes((aggregate))))
      # pylint: enable=protected-access
      return identity_op

  def __init__(self,
               function_name,
               level=1,
               children_inputs_mappings=None,
               **kwargs):
    """Create a OpHint.

    Args:
      function_name: Name of the function (the custom op name in tflite)
      level: OpHint level.
      children_inputs_mappings: Children OpHint inputs/outputs mapping.
        Must be None when level == 1, and a dict otherwise. The dict should
        look like below:
        "parent_first_child_input":
            [{"parent_ophint_input_index": num,
              "first_child_ophint_input_index": num}, ...]
        "parent_last_child_output":
            [{"parent_output_index": num, "child_output_index": num}, ...]
        "internal_children_input_output":
            [{"child_input_index": num, "child_output_index": num}, ...]
      **kwargs: Keyword arguments of any constant attributes for the function.
    """
    self._function_name = function_name
    self._level = level
    if self._level == 1:
      assert children_inputs_mappings is None
    else:
      assert isinstance(children_inputs_mappings, dict)
    self._children_inputs_mappings = children_inputs_mappings
    if self._children_inputs_mappings is not None:
      self._validate_children_inputs_mappings(self._children_inputs_mappings)
    self._unique_function_id = _uuid.uuid1().hex  # TODO(aselle): Unique enough?
    self._attrs_to_store_later = kwargs
    self._stored_attrs = False
    self._inputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "InputHint",
        OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
    self._outputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "OutputHint",
        OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
        self._children_inputs_mappings)

  def _validate_children_inputs_mappings(self, children_inputs_mappings):
    """Validate children inputs mappings is in the right format.

    Args:
      children_inputs_mappings: the Children ophint inputs/outputs mapping.

    Raises:
      AssertionError: If the mapping is malformed. Note these checks are
        `assert` statements and are skipped when Python runs with -O.
    """
    assert isinstance(children_inputs_mappings, dict)
    assert "parent_first_child_input" in children_inputs_mappings
    assert "parent_last_child_output" in children_inputs_mappings
    assert "internal_children_input_output" in children_inputs_mappings

    # Each section is a list of dicts, each of which must carry these keys.

    def assert_dictlist_has_keys(dictlist, keys):
      for dikt in dictlist:
        assert isinstance(dikt, dict)
        for key in keys:
          assert key in dikt

    assert_dictlist_has_keys(
        children_inputs_mappings["parent_first_child_input"],
        ["parent_ophint_input_index", "first_child_ophint_input_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["parent_last_child_output"],
        ["parent_output_index", "child_output_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["internal_children_input_output"],
        ["child_input_index", "child_output_index"])

  def _setattr(self, dest_op, name, value):
    """Store `value` as a tensor-valued attr `name` on `dest_op`'s op.

    NOTE(review): reads the "value" attr of the converted tensor's op, which
    presumably assumes `value` converts to a Const op — confirm.
    """
    tensor_value = _ops.convert_to_tensor(value)
    # pylint: disable=protected-access
    dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
        tensor=tensor_value.op.node_def.attr["value"].tensor))
    # pylint: enable=protected-access

  def add_input(self, *args, **kwargs):
    """Add a wrapped input argument to the hint.

    Args:
      *args: The input tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple copies
          of state or inputs.
        "aggregate" aggregation strategy that is valid only for tag non None.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped input tensor.
    """
    return self._inputs.add(*args, **kwargs)

  def add_output(self, *args, **kwargs):
    """Add a wrapped output argument to the hint.

    Args:
      *args: The output tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple copies
          of state or inputs.
        "aggregate" aggregation strategy that is valid only for tag non None.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped output tensor.
    """
    return self._outputs.add(*args, **kwargs)

  def add_inputs(self, *args, **kwargs):
    """Add a sequence of inputs to the function invocation.

    Args:
      *args: List of inputs to be converted (should be Tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.
    Returns:
      Wrapped inputs (identity standins that have additional metadata). These
      are also tf.Tensor's.
    """
    if "names" in kwargs:
      return [
          self._inputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._inputs.add(arg) for arg in args]

  def add_outputs(self, *args, **kwargs):
    """Add a sequence of outputs to the function invocation.

    Args:
      *args: List of outputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.
    Returns:
      Wrapped outputs (identity standins that have additional metadata). These
      are also tf.Tensor's.
    """
    if "names" in kwargs:
      return [
          self._outputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._outputs.add(arg) for arg in args]
462
463
464class _LiteOperand(object):
465  """Abstract operand for a tflite hint function._dynamic_rnn_loop.
466
467  This is a base class that handles representing arguments to an OpHint.
468  It also is able to serialize operands to the stubbed graph_def.
469  Child classes are responsible for being able to
470  store information about the hint identity operators. They are also responsible
471  for knowing how to serialize to output graphdefs.
472
473  Typically this will be implemented by holding one or more identity nodes
474  that were previously discovered as hints.
475  """
476
477  def aggregate_and_return_name_for_input(self, out_graphdef):
478    """This adds the node(s) to out_graphdef and returns the input node name.
479
480    Args:
481      out_graphdef: A graphdef that is ready to have this input added.
482
483    Returns:
484      The output that the stub should use as an input for this operand.
485
486    Raises:
487      RuntimeError: if the method is not implemented.
488    """
489    del out_graphdef
490    raise RuntimeError("Unimplemented abstract method.")
491
492  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
493                                           out_graphdef):
494    """Add node(s) to graph representing output operands and returns type.
495
496    Args:
497      fused_op_name: name of the fused op stub name.
498      output_index: Output index that we are currently processing from stub.
499      out_graphdef: The destination graphdef we are currently building up.
500
501    Returns:
502      The datatype of this identity.
503
504    Raises:
505      RuntimeError: if the method is not implemented.
506    """
507    del fused_op_name, output_index, out_graphdef
508    raise RuntimeError("Unimplemented abstract method.")
509
510
class _LiteSingleOperand(_LiteOperand):
  """A simple operand that is non-aggregated (i.e. most hints).

  Wraps exactly one hint Identity node.
  """

  def __init__(self, node):
    _LiteOperand.__init__(self)
    self.node = node
    self.name = _tensor_name_base(node.name)

  def flatten(self):
    """Return the list of node names backing this operand (a single name)."""
    return [self.name]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """Return the hint node's name; nothing is added to `out_graphdef`."""
    return self.name

  def aggregate_and_return_name_for_output(self, fused_op_name, index,
                                           out_graphdef):
    """Re-create the output hint node, fed from output `index` of the stub.

    Args:
      fused_op_name: Name of the fused stub op.
      index: Output index of the stub this operand reads from.
      out_graphdef: The graphdef being built; receives a copy of the node.

    Returns:
      The DataType enum value stored in the hint node's "T" attr.
    """
    output_node = _copy.deepcopy(self.node)
    del output_node.input[:]
    output_node.input.append(_tensorflow_output_name(fused_op_name, index))
    out_graphdef.node.extend([output_node])
    # Identity hint nodes carry their dtype in the "T" attr (the same attr the
    # aggregated operands read). Indexing a nonexistent "type" attr would
    # yield 0 (DT_INVALID) and insert an empty entry into self.node.
    return self.node.attr["T"].type

  def __str__(self):
    return str(self.name)
535
536
class _LiteAggregateOperand(_LiteOperand):
  """An operand for a tflite hint function that is aggregated from many.

  For example, an LSTM is a grid of operators that are all related. Inputs
  going into them may need to be fused, so they should all be tracked as
  related arguments.
  """

  def __init__(self, aggregation):
    _LiteOperand.__init__(self)
    self.aggregation = aggregation
    self.names = {}  # sort index -> base node name
    self.nodes = {}  # sort index -> node proto
    self.flattened = None  # cache for flatten_nodes()

  def add(self, sort, node):
    """Record `node` as the part of this aggregate at sort index `sort`."""
    self.names[sort] = _tensor_name_base(node.name)
    self.nodes[sort] = node

  def flatten_nodes(self):
    """Return a list of all the node protos in aggregation sorted order.

    The result is computed once and cached in `self.flattened`.

    Returns:
      All nodes ordered by sort index for AGGREGATE_STACK, or only the first
      (AGGREGATE_FIRST) or last (AGGREGATE_LAST) of them.

    Raises:
      RuntimeError: If the sort indices are not exactly 0..N-1, i.e. some
        part of the aggregate was never added.
      ValueError: If the aggregation type is unknown.
    """
    if not self.flattened:
      count = len(self.nodes)
      flattened = [None] * count
      for idx, node in _six.iteritems(self.nodes):
        # With `count` distinct keys all inside [0, count), every slot gets
        # filled; any key outside that range means some slot is missing.
        if not 0 <= idx < count:
          raise RuntimeError("Aggregate was missing argument.")
        flattened[idx] = node
      if self.aggregation == OpHint.AGGREGATE_FIRST:
        flattened = flattened[:1]
      elif self.aggregation == OpHint.AGGREGATE_LAST:
        flattened = flattened[-1:]
      elif self.aggregation == OpHint.AGGREGATE_STACK:
        pass
      else:
        raise ValueError(
            "Invalid aggregation type %r specified" % self.aggregation)
      self.flattened = flattened
    return self.flattened

  def flatten(self):
    """Return a list of all node names in aggregation sorted order."""
    return [_tensor_name_base(x.name) for x in self.flatten_nodes()]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the nodes to out_graphdef and returns an aggregated output.

    In particular, if you have 4 inputs to a hint stub, this will be the
    node that you can use as an output. I.e. you have 4 timesteps from a
    static rnn, then a fused UnidirectionalLSTM will expect 1 input with
    all 4 time steps. So here we make a pack and return the output name of
    that pack.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The name of a pack that aggregates this node (or the single node's
      name when only one node remains after aggregation).
    """
    flattened = self.flatten_nodes()
    if len(flattened) == 1:
      return _tensor_name_base(flattened[0].name)
    else:
      new_node = _node_def_pb2.NodeDef()
      new_node.op = "Pack"
      new_node.name = "OpHintStack-%s" % flattened[0].name
      new_node.attr["N"].i = len(flattened)
      new_node.attr["T"].type = flattened[0].attr["T"].type
      for discrete in flattened:
        new_node.input.append(_tensor_name_base(discrete.name))
      out_graphdef.node.extend([new_node])
      return new_node.name

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """This adds to `out_graphdef` all the unaggregated outputs.

    I.e. we are outputting from a fused stub, but we need to make it compatible
    with the unfused original graph so we insert an unpack. Ideally in a later
    stage the unpack -> pack sequences will be removed.

    Args:
      fused_op_name: The name of the stub we are in the process of fusing.
      output_index: The output index of the stub this object represents.
      out_graphdef: The graphdef we are in the process of building.

    Returns:
      The type of the aggregated output (so we can finish building the stub
      op).
    """
    flattened = self.flatten_nodes()
    if len(flattened) == 1:
      temp_op = _LiteSingleOperand(flattened[0])
      return temp_op.aggregate_and_return_name_for_output(
          fused_op_name, output_index, out_graphdef)
    else:
      stack_node = _node_def_pb2.NodeDef()
      stack_node.op = "Unpack"
      stack_node.name = "OpHintUnstack-%s" % flattened[0].name
      stack_node.attr["num"].i = len(flattened)
      output_type = flattened[0].attr["T"].type
      stack_node.attr["T"].type = output_type
      stack_node.input.append(_tensorflow_output_name(
          fused_op_name, output_index))
      out_graphdef.node.extend([stack_node])

      # Re-create each original hint node, fed from its slot of the unpack.
      for idx, discrete in enumerate(flattened):
        output_node = _copy.deepcopy(discrete)
        del output_node.input[:]
        output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
        out_graphdef.node.extend([output_node])

      return output_type

  def __str__(self):
    s = "\t\t\tAGGREGATE %s\n" % self.aggregation
    # six.iteritems works on both Python 2 and 3 (dict.iteritems is py2-only).
    for sort, val in sorted(_six.iteritems(self.names)):
      s += "\t\t\t%d: %s\n" % (sort, val)
    return s
655
656
657class _LiteFuncCall(object):
658  """Represent a TensorFlow Lite custom function.
659
660  This is uses to accumulate found hints in the graphdef into a single
661  conceptual unit.
662
663  Attributes:
664    inputs: inputs to the op (hash from index # to argument)
665    outputs: outputs to the op (hash from index # to argument)
666    function_name: the tflite custom op name to use
667    uuid: a unique call id for this particular call  (i.e.
668      multiple function calls would have the same function_name but different
669      uuids.
670    params: A param name to key value for op constant data. I.e. for
671      axis on a reduction, strides on a convolution, etc.
672    level: Level of the OpHint.
673    children_inputs_mappings: If the Ophint has children, children inputs
674      mappings indicate how their inputs & outputs are mapped.
675  """
676
677  def __init__(self):
678    self.inputs = {}
679    self.outputs = {}
680    self.function_name = None
681    self.uuid = None
682    self.params = {}
683    self.level = -1
684    self.children_inputs_mappings = {}
685
686  def flattened_inputs_and_outputs(self):
687    """Return a list of inputs and outputs in a flattened format.
688
689    Returns:
690      Tuple of (inputs, outputs). where input and output i a list of names.
691    """
692    def _flatten(input_or_output_dict):
693      flattened_items = []
694      for item in input_or_output_dict.values():
695        flattened_items.extend(item.flatten())
696      return flattened_items
697
698    return _flatten(self.inputs), _flatten(self.outputs)
699
700  def __str__(self):
701    def format_args(items):
702      s = ""
703      for idx, item in items.iteritems():
704        s += ("\t\t%d:\n" % idx) + str(item)
705      return s
706
707    inputs_str = "\tInputs\n" + format_args(self.inputs)
708    outputs_str = "\tOutputs\n" + format_args(self.outputs)
709
710    return (
711        "tflite function %s call %s level %d "
712        "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
713        (self.function_name, self.uuid, self.level, inputs_str, outputs_str))
714
715
def _find_all_hints_in_nodes(nodes):
  """Look at all the input nodes and group the OpHint ones into calls.

  Args:
    nodes: An iterable of TensorFlow NodeDef protos (e.g. `graph_def.node`)
      to look for LiteFuncCalls in.

  Returns:
    A `collections.defaultdict` mapping each call's uuid (the bytes value of
    the FUNCTION_UUID_ATTR) to a `_LiteFuncCall` describing that call.
  """
  func_calls = _collections.defaultdict(_LiteFuncCall)
  param_attr_prefix = "_tflite_attr_"

  def put_operand(stuff, index, sort, operand, aggregation):
    """Add a given index into the function structure."""
    if sort is None:
      stuff[index] = _LiteSingleOperand(operand)
    else:
      if index not in stuff:
        stuff[index] = _LiteAggregateOperand(aggregation)
      stuff[index].add(sort, operand)

  for node in nodes:
    attr = node.attr
    # This is an op hint if it has a non-empty FUNCTION_UUID_ATTR, otherwise
    # skip. Test membership before indexing: indexing a protobuf message map
    # with a missing key inserts an empty entry into the node.
    if (OpHint.FUNCTION_UUID_ATTR not in attr
        or not attr[OpHint.FUNCTION_UUID_ATTR].s):
      continue
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s

    # Start building function
    call_def = func_calls[uuid]
    call_def.uuid = uuid
    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
    call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
    # Get sorting and aggregation information

    sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
            if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1: sort = None
    aggregation = None
    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)

    if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
      call_def.children_inputs_mappings = _json.loads(
          _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))

    # Add the input or output
    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
                  sort, node, aggregation)
    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
                  sort, node, aggregation)

    # Remember attributes, keyed by their name without the marker prefix.
    for a in attr:
      if a.startswith(param_attr_prefix):
        call_def.params[a[len(param_attr_prefix):]] = attr[a].tensor

  return func_calls
777
778
def _extract_topology_sequence_mapping(nodes):
  """Maps each node's base name to its position in `nodes`.

  Args:
    nodes: An iterable of `NodeDef`s, in the order they appear in the graph.

  Returns:
    A dict from `_tensor_name_base(node.name)` to the node's zero-based index.
    If two nodes share a base name, the later index wins.
  """
  return {
      _tensor_name_base(node.name): position
      for position, node in enumerate(nodes)
  }
782
783
def _find_children_hints_in_while_loop(function_def, nodes_mapping):
  """Find children hints and all nodes inside the while loop.

  Args:
    function_def: Function def of the while loop body. It is read but not
      modified.
    nodes_mapping: Dict mapping a while-loop body input_arg name to the name of
      the real node (outside the loop) that feeds it.

  Returns:
    A tuple (ordered_children_hints, new_nodes):
      ordered_children_hints: the `_LiteFuncCall`s found in the body, ordered
        by the earliest position (in the body's node order) of any of their
        outputs.
      new_nodes: deep copies of every body node, with inputs that name a loop
        argument re-pointed at the real outer node.
  """
  new_nodes = []

  # Copy each node first and re-map inputs on the copy, so the caller's
  # function_def (which belongs to the source graph's library) is left intact.
  for node in function_def.node_def:
    new_node = _copy.deepcopy(node)
    new_node.input[:] = [
        nodes_mapping.get(input_name, input_name)
        for input_name in new_node.input
    ]
    new_nodes.append(new_node)

  name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
  children_hints = _find_all_hints_in_nodes(new_nodes)
  children_hints_q = []
  # Order each hint by its earliest output in the body.
  for hint in _six.itervalues(children_hints):
    _, output_names = hint.flattened_inputs_and_outputs()
    seq = min(name_to_seq_num[output_name] for output_name in output_names)
    children_hints_q.append((seq, hint))
  children_hints_q.sort(key=lambda tup: tup[0])
  ordered_children_hints = [x[1] for x in children_hints_q]
  return ordered_children_hints, new_nodes
815
816
def _find_children_hints(call, graph_def):
  """Finds all children hints of `call`.

  Every node of `graph_def` is copied into a new graph def. Some `While` nodes
  lie strictly inside `call`: reachable backwards from its outputs, but neither
  reachable from its inputs nor an output itself. For each of these, the loop
  body function is looked up in the library. Children hints inside that body
  are collected, and the body's nodes (re-wired to the loop's real inputs) are
  also appended to the new graph def.

  Args:
    call: Parent OpHint that contains children ophints.
    graph_def: Original graph def.

  Returns:
    A tuple (children_hints, out, function_def_nodes):
      children_hints: the ordered children hints found inside the parent.
      out: a new graph def holding the original nodes plus the copied
        function-def nodes.
      function_def_nodes: the set of names of those copied function-def nodes.
  """
  name_to_input_name, _, _ = _extract_graph_summary(graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)

  found_hints = []
  body_node_names = set()
  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  for node in graph_def.node:
    out.node.extend([_copy.deepcopy(node)])
    base_name = _tensor_name_base(node.name)
    inside_hint = (base_name in reachable_by_output and
                   base_name not in reachable_by_input and
                   base_name not in output_nodes_set)
    # Only While nodes get special handling: their body is a function def in
    # the library, which is searched for children hints.
    if not inside_hint or node.op != "While":
      continue

    body_name = node.attr["body"].func.name
    loop_inputs = node.input
    for function_def in graph_def.library.function:
      if function_def.signature.name != body_name:
        continue
      arg_defs = function_def.signature.input_arg
      assert len(loop_inputs) == len(arg_defs)
      # Body input_arg name -> name of the real node feeding the loop.
      nodes_mapping = {
          arg_def.name: loop_input
          for arg_def, loop_input in zip(arg_defs, loop_inputs)
      }
      # TODO(b/123050804): Consider use grappler.
      hints_in_loop, body_nodes = _find_children_hints_in_while_loop(
          function_def, nodes_mapping)
      body_node_names.update(body_node.name for body_node in body_nodes)
      found_hints.extend(hints_in_loop)
      out.node.extend(body_nodes)

  return found_hints, out, body_node_names
869
870
871def _tensor_name_base(full_tensor_name):
872  """Removes the device assignment code from a tensor.
873
874  e.g. _tensor_name_base("foo:3") => "foo"
875
876  Args:
877    full_tensor_name: A tensor name that is annotated with a device placement
878      (this is what tensor flow introspection gives).
879  Returns:
880    A name without any device assignment.
881  """
882  if full_tensor_name.startswith("^"):
883    return full_tensor_name[1:]
884  return full_tensor_name.split(":")[0]
885
886
887def _tensorflow_output_name(tensor_name, output_index):
888  return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
889                                                          output_index)
890
891
892# TODO(aselle): This should be converted to grappler in the future.
893def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
894                           name_to_input_name):
895  """Checks to make sure node only connects to predecessor graph through inputs.
896
897  Args:
898    n: Node to check
899    reachable_by_input: Nodes that are reachable by all inputs of subgraph
900    input_nodes_set: The set of nodes that are "inputs".
901    name_to_input_name: Maps from name to the list of inputs.
902
903  Raises:
904    TypeError: If the given node uses items past inputs directly.
905  """
906  next_to_visit = [n]
907  visited = set()
908  while next_to_visit:
909    current_node = next_to_visit.pop()
910    visited.add(current_node)
911    if (current_node in reachable_by_input
912        and current_node not in input_nodes_set):
913      raise TypeError(
914          "Node %s uses input %s not in input_nodes." % (n, current_node))
915    if current_node not in input_nodes_set:
916      next_to_visit += [
917          input_node for input_node in name_to_input_name[current_node]
918          if input_node not in visited
919      ]
920
921
922# TODO(aselle): This should be converted to grappler in the future.
def _convert_single_op_hint_to_stub(call,
                                    graph_def,
                                    function_def_nodes=None,
                                    is_last_run=True):
  """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  A node is strictly inside the hinted region if it is reachable (walking
  inputs backwards) from the hint's outputs, is not reachable from the hint's
  inputs, and is not an output itself. Every such node is removed and replaced
  by a single stub node whose op is `call.function_name` and whose name is
  `call.uuid`.

  The returned graph_def is built in this order:
    1. all nodes reachable backwards from the hint's inputs, in their original
       graph order;
    2. any nodes the input operands add while producing the stub's inputs
       (e.g. an aggregated input may add a stacking node);
    3. the stub node, then any nodes the output operands add;
    4. the nodes classified as "after the fused region" (see below), except
       those with an input that was removed.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (that has call obviously).
    function_def_nodes: Names of nodes copied in from function defs that are
      not connected to the graph. They are never treated as "after the fused
      region".
    is_last_run: Whether it is the last run for a given pass (for OpHint has
      children). Nodes not reachable from the outputs that are reachable from
      the inputs, or listed in `function_def_nodes`, are kept only when this
      is False.

  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
      the tensorflow runtime, so all future manipulations are done in graph_def
      level.
  """
  if function_def_nodes is None:
    function_def_nodes = set()
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  nodes_after_fuse = []
  nodes_deleted_by_fuse = set()
  # Classify each node:
  #  * reachable from the outputs, but neither from the inputs nor an output
  #    itself: inside the hinted region, so it is deleted by the fuse;
  #  * reachable from neither the outputs nor the inputs, and not copied in
  #    from a function def: it comes after the fused region, so keep it;
  #  * any other node not reachable from the outputs: kept only when this is
  #    not the last run.
  # Nodes reachable from the inputs are copied separately below.
  for node in graph_def.node:
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        nodes_deleted_by_fuse.add(n)
    elif n not in reachable_by_input and n not in function_def_nodes:
      nodes_after_fuse.append(n)
    elif not is_last_run:
      nodes_after_fuse.append(n)

  # Make a new graphdef with all the pre-input and input nodes, in their
  # original graph order.
  out = _graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      reachable_by_input, key=lambda name: name_to_seq_num[name])
  for node in reachable_by_input_sorted:
    out.node.extend([_copy.deepcopy(name_to_node[node])])

  # Create any stacks to aggregate arguments into to a single input
  # i.e. for static_rnn's.
  # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
  sorted_input_indices = sorted(call.inputs.keys())
  sorted_output_indices = sorted(call.outputs.keys())
  new_node = _node_def_pb2.NodeDef()
  # Delegate to each operand to produce the proper new input for this stub node.
  # In particular, an aggregate input will now be a Pack of some previously
  # non-fused things.
  for input_index in sorted_input_indices:
    inputs = call.inputs[input_index]
    input_name = inputs.aggregate_and_return_name_for_input(out)
    new_node.input.append(input_name)
  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

  # Create the function
  new_node.op = call.function_name
  new_node.name = call.uuid
  # Extending a repeated message field stores a *copy* of `new_node`, so later
  # edits to `new_node` would never reach `out`. Remember where the stub lives
  # in `out` and write its remaining attrs to that element instead.
  stub_index = len(out.node)
  out.node.extend([new_node])

  # Now call each output argument to give them a chance to make the proper
  # output type and add it to our stub node.
  output_dtypes = []
  for output_index in sorted_output_indices:
    output = call.outputs[output_index]
    output_dtype = (
        output.aggregate_and_return_name_for_output(new_node.name, output_index,
                                                    out))
    output_dtypes.append(output_dtype)
  stub_node = out.node[stub_index]
  stub_node.attr["_output_types"].list.type[:] = output_dtypes
  # TODO(aselle): what is right here?
  stub_node.attr["_output_quantized"].b = False

  # Add post output nodes that do not depend on any deleted node.
  for n in nodes_after_fuse:
    if not any(input_name in nodes_deleted_by_fuse
               for input_name in name_to_input_name[n]):
      out.node.extend([_copy.deepcopy(name_to_node[n])])

  # Misc. graph_def data that needs copying.
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out
1030
1031
1032# TODO(aselle): This should be converted to grappler in the future.
def _remove_one_redundant_stack_unstack(in_graph_def):
  """Removes a stack->unstack pattern from in_graph_def in a returned graph.

  The search starts from a stacking node, i.e. a node whose name starts with
  "OpHintStack" or whose op starts with "Pack". From there it walks backwards
  breadth-first through Identity / Pack / "OpHintStack*" nodes. A match
  requires every path to end at an unstacking node ("OpHintUnstack*" name or an
  Unpack op) and exactly one such node to be found. When the match is
  hint-created, or nothing outside the chain consumes the chain's intermediate
  nodes, two things happen:
    * every consumer of the stacking node is rewired to the unstacking node's
      first input;
    * all visited chain nodes, including the stack and unstack, are dropped.
  At most one pattern is removed per call.

  Args:
    in_graph_def: Graph def to use as input.
  Returns:
    Simplified tuple (graph_def, changed_something) where changed_something
    is true if anything was done. When nothing is done, `in_graph_def` itself
    is returned unchanged.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      in_graph_def)
  del name_to_seq_num

  # TODO(aselle): Make this not hardcoded.
  # When True, plain Pack/Unpack ops (not only hint-named nodes) are matched.
  do_generic_pack_unpack = True

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(in_graph_def.library)
  out.versions.CopyFrom(in_graph_def.versions)
  for n in in_graph_def.node:
    node_name = _tensor_name_base(n.name)
    # Only stacking nodes can start a pattern.
    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
      continue
    next_to_visit = [node_name]
    visited = set()

    unpack_nodes = set()
    pack_node = node_name

    # Find a pattern of unstack connected to a stack (with identities
    # in between). The walk goes backwards over inputs, FIFO order.
    matches_pattern = True
    # Ends up True only if some stack in the chain is an "OpHintStack" node
    # and every unstack reached is an "OpHintUnstack" node.
    is_hint_created_stack = False
    while next_to_visit:
      current_node_name = next_to_visit[0]
      visited.add(current_node_name)
      del next_to_visit[0]
      node = name_to_node[current_node_name]
      is_op_hint_stack = node.name.startswith("OpHintStack")
      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
      if (node.op == "Identity" or is_op_hint_stack
          or (do_generic_pack_unpack and node.op == "Pack")):
        # Pass-through / stacking node: keep walking to its inputs.
        is_hint_created_stack |= is_op_hint_stack
        next_to_visit += [
            input_node for input_node in name_to_input_name[current_node_name]
            if input_node not in visited
        ]
      elif (is_op_hint_unstack
            or (do_generic_pack_unpack and node.op == "Unpack")):
        # Unstacking node: this path of the pattern ends here.
        unpack_nodes.add(node.name)
        is_hint_created_stack &= is_op_hint_unstack
      else:
        # Any other op breaks the pattern.
        matches_pattern = False
        break
      visited.add(node.name)

    if matches_pattern and len(unpack_nodes) == 1:
      pack_node = node_name

      # Check to see if anyone depends on the intermediate identity or the
      # Unstacked form
      no_external_dependency = True
      for other_n in in_graph_def.node:
        if other_n.name in visited: continue
        for input_tensor in name_to_input_name[other_n.name]:
          input_op = _tensor_name_base(input_tensor)
          if input_op in visited and input_op != pack_node:
            no_external_dependency = False
      # Proceed with the substitution if the stack/unstack pair was created
      # through hints, or that it was not, but nobody is consuming things
      # between the stack and unstack.
      if is_hint_created_stack or no_external_dependency:
        end = unpack_nodes.pop()
        end_input = name_to_node[end].input[0]
        # Copy every node outside the chain. Inputs that referred to the
        # final stack are redirected to the unstack's first input.
        for other_n in in_graph_def.node:
          node_name = _tensor_name_base(other_n.name)
          if node_name not in visited:
            new_node = _copy.deepcopy(other_n)
            new_node.input[:] = [
                (end_input if stripped == pack_node else
                 non_stripped) for stripped, non_stripped in zip(
                     name_to_input_name[node_name], new_node.input[:])
            ]
            out.node.extend([new_node])
        return out, True
  return in_graph_def, False
1120
1121
def _remove_redundant_stack_unstack(graph_def):
  """Removes redundant stack->unstack patterns until none are left.

  Args:
    graph_def: Graph def to simplify.
  Returns:
    The simplified graph def. This is the input object itself if no pattern
    was removed.
  """
  curr = graph_def
  del graph_def  # prevent accidental reuse of the un-simplified graph
  while True:
    curr, changed = _remove_one_redundant_stack_unstack(curr)
    if not changed:
      return curr
1129
1130
1131def _get_correct_mapping(original_index, nodes):
1132  # Special handle for the index is -1 case.
1133  # If it is -1, return the last index.
1134  if original_index == -1:
1135    node_indices = nodes.keys()
1136    node_indices = sorted(node_indices)
1137    return node_indices[-1]
1138  else:
1139    return original_index
1140  return original_index
1141
1142
def _convert_op_hints_to_stubs_helper(
    graph_def, write_callback=lambda sess, graph_def: None):
  """Converts a graph_def to a new graph_def where all op hints are stubbed.

  Hints are processed in descending order of their `level`. A hint of level
  >= 2 is a parent: its children hints are located, rewired according to the
  parent's `children_inputs_mappings`, and each child is stubbed. Every other
  hint is stubbed directly. Finally, redundant stack/unstack pairs created by
  the stubbing are removed.

  Args:
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional). It is called as
      `write_callback(graph_def, "initial")` after each directly stubbed hint.
  Returns:
    A new stubbed graph_def.
  Raises:
    ValueError: If a hint of level >= 2 has no children hints.
  """
  hints = _find_all_hints_in_nodes(graph_def.node)

  # (level, uuid) pairs. The sort is stable; iterating it in reverse handles
  # the highest levels first.
  hints_q = sorted(
      [(hint.level, hint.uuid) for hint in _six.itervalues(hints)],
      key=lambda tup: tup[0])

  curr_graph_def = graph_def
  del graph_def  # prevent using graph_def again (common source of error)
  for level, hint_uuid in reversed(hints_q):
    hint = hints[hint_uuid]
    if level < 2:
      curr_graph_def = _convert_single_op_hint_to_stub(hint, curr_graph_def)
      write_callback(curr_graph_def, "initial")
      continue

    children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
        hint, curr_graph_def)
    # Raised explicitly (not via `assert`) so the check also runs under -O.
    if not children_hints:
      raise ValueError("OpHint %r at level %d has no children hints." %
                       (hint_uuid, level))

    # Re-wire the children hints inputs/outputs, so a later child's inputs
    # connect to the previous child's outputs.
    children_inputs_mappings = hint.children_inputs_mappings
    last_child = len(children_hints) - 1
    for j, child_hint in enumerate(children_hints):
      if j == 0:
        # The first child takes the mapped inputs from the parent.
        for mapping in children_inputs_mappings["parent_first_child_input"]:
          parent_input_index = _get_correct_mapping(
              mapping["parent_ophint_input_index"], hint.inputs)
          child_input_index = _get_correct_mapping(
              mapping["first_child_ophint_input_index"], child_hint.inputs)
          child_hint.inputs[child_input_index] = hint.inputs[
              parent_input_index]
      else:
        # Later children take inputs from the previous child's outputs.
        previous_hint = children_hints[j - 1]
        for mapping in children_inputs_mappings[
            "internal_children_input_output"]:
          input_index = _get_correct_mapping(mapping["child_input_index"],
                                             child_hint.inputs)
          output_index = _get_correct_mapping(mapping["child_output_index"],
                                              previous_hint.outputs)
          child_hint.inputs[input_index] = previous_hint.outputs[output_index]
      if j == last_child:
        # The last child produces the parent's mapped outputs.
        for mapping in children_inputs_mappings["parent_last_child_output"]:
          parent_output_index = _get_correct_mapping(
              mapping["parent_output_index"], hint.outputs)
          child_output_index = _get_correct_mapping(
              mapping["child_output_index"], child_hint.outputs)
          child_hint.outputs[child_output_index] = hint.outputs[
              parent_output_index]

    for j, child_hint in enumerate(children_hints):
      curr_graph_def = _convert_single_op_hint_to_stub(
          child_hint, curr_graph_def, function_def_nodes, j == last_child)

  # The stubbing process can create stacks/unstacks in the case of LSTMs;
  # remove them.
  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
  return curr_graph_def
1219
1220
def find_all_hinted_output_nodes(session=None, graph_def=None):
  """Find all Ophints output nodes in the graph.

  This is used to get all the output nodes those are ophinted, it is important
  for operation like convert_variables_to_constants keep all ophints structure.
  Note: exactly one of session or graph_def must be provided.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.

  Returns:
    A list of OpHints output nodes.
  Raises:
    ValueError: If both session and graph_def are provided, or if neither is.
  """
  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")
  if session is not None:
    graph_def = session.graph_def
  elif graph_def is None:
    # Same contract as convert_op_hints_to_stubs: a source graph is required.
    raise ValueError("Must specify session or graph_def as input.")

  hints = _find_all_hints_in_nodes(graph_def.node)
  hinted_outputs_nodes = []
  for hint in _six.itervalues(hints):
    _, output_nodes = hint.flattened_inputs_and_outputs()
    hinted_outputs_nodes.extend(output_nodes)
  return hinted_outputs_nodes
1248
1249
@_tf_export("lite.experimental.convert_op_hints_to_stubs")
def convert_op_hints_to_stubs(session=None,
                              graph_def=None,
                              write_callback=lambda graph_def, comments: None):
  """Converts a graphdef with LiteOp hints into stub operations.

  This prepares a graph for toco conversion of complex intrinsic usages.
  Exactly one of `session` or `graph_def` must be supplied.

  Args:
    session: A TensorFlow session whose graph should be converted.
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional). It is called as
      `write_callback(graph_def, comments)`.
  Returns:
    A new graphdef with all ops contained in OpHints being replaced by
    a single op call with the right parameters.
  Raises:
    ValueError: If both session and graph_def are provided, or neither is.
  """
  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")

  if session is not None:
    source_graph_def = session.graph_def
  elif graph_def is not None:
    source_graph_def = graph_def
  else:
    raise ValueError("Must specify session or graph_def as input.")
  return _convert_op_hints_to_stubs_helper(source_graph_def, write_callback)
1280
1281
# Public names of this module, passed to `remove_undocumented` below, which
# presumably strips every other public symbol from the module — TODO confirm.
# NOTE(review): "convert_op_hints_to_stubs_new" is not defined in the visible
# part of this file; confirm it exists elsewhere in the module.
_allowed_symbols = [
    "OpHint", "convert_op_hints_to_stubs", "convert_op_hints_to_stubs_new",
    "find_all_hinted_output_nodes"
]
remove_undocumented(__name__, _allowed_symbols)
1287