# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Define tflite op hints (intrinsic operations).
16
17This essentially allows defining a TensorFlow API for tflite operations in
18Python with hints on how they are represented in TensorFlow Lite. This basically
19is a form of tflite intrinsic. It wraps a subpart of a TensorFlow execution
20graph and is useful for LSTMs and other complicated TensorFlow constructions
21that are difficult to pattern match in TOCO, but are represented by a single
22accelerated tflite op.
23
24Example:
25  def tflite_cool_activation(input):
26    # A cool activation function.
27    custom = tf.lite.OpHint("cool_activation")
28    input, = custom.add_inputs(input)
29    output = tf.sigmoid(input) * input
30    output, = custom.add_outputs(output)
31    return output
32
33  image = tf.compat.v1.placeholder(tf.float32, (1, 16, 16, 1))
34  output = tf.identity(tflite_cool_activation(image))
35
36  session = tf.compat.v1.Session()
37
38  graphdef_to_convert = tf.lite.experimental.convert_op_hints_to_stubs(session)
39  tflite_graph = tf.compat.v1.lite.toco_convert(
40      graphdef_to_convert, [image], [output], allow_custom_ops=True)
41  with open("/tmp/graph.fb", "wb") as fp:
42    fp.write(tflite_graph)
43
How does it work?:

OpHint is a helper that you use when defining a vanilla python function.
It allows you to wrap arguments with tf.identities with some custom attributes.
These attributes allow you to find the original block of ops that was created.
For example, if you use cool_activation above you essentially get:

a_input = tf.identity()
result = tf.multiply(tf.sigmoid(a_input), a_input)
output = tf.identity()

a_input and output are identities that carry parameters recording which
argument they are, which function name they should be turned into in
TensorFlow Lite, and a guid that uniquely identifies a particular invocation.

Once you have built your whole TensorFlow graph, you can run it and train it
as usual, but after you have done that, you need to convert the graph into
a form that replaces these subgraphs wrapped in identities with stub ops. These
ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later. The generated TensorFlow Lite flatbuffer file will
contain a custom operator called "cool_activation". The developer needs to
implement and register this operator in TensorFlow Lite in order to run
inference.
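
For illustration only (this is not part of the public API), the wrapped
identities can be inspected directly on the GraphDef; each one carries the
attributes defined on OpHint below, e.g. _tflite_function_name,
_tflite_function_uuid and _tflite_function_input_index:

  for node in session.graph_def.node:
    if "_tflite_function_uuid" in node.attr:
      print(node.name, node.attr["_tflite_function_name"].s)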
66"""
67
68# TODO(aselle): Make this use generic graph transformations.
69# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
70
71from __future__ import absolute_import
72from __future__ import division
73from __future__ import print_function
74
75import collections as _collections
76import copy as _copy
77import json as _json
78import uuid as _uuid
79import six as _six
80
81from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
82from tensorflow.core.framework import graph_pb2 as _graph_pb2
83from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
84from tensorflow.python.framework import dtypes as _dtypes
85from tensorflow.python.framework import ops as _ops
86from tensorflow.python.framework import tensor_util as _tensor_util
87# TODO(aselle): publicize these apis if we continue to use these.
88from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
89from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
90from tensorflow.python.ops import array_ops as _array_ops
91from tensorflow.python.util import compat as _compat
92from tensorflow.python.util import deprecation as _deprecation
93from tensorflow.python.util.all_util import remove_undocumented
94from tensorflow.python.util.tf_export import tf_export as _tf_export
95
96
@_tf_export(v1=["lite.OpHint"])
@_deprecation.deprecated(
    None,
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
class OpHint(object):
  """A class that helps build tflite function invocations.

  It allows you to take a bunch of TensorFlow ops and annotate the construction
  such that toco knows how to convert it to tflite. This embeds a pseudo
  function in a TensorFlow graph. This allows embedding high-level API usage
  information in a lower level TensorFlow implementation so that an alternative
  implementation can be substituted later.

  Essentially, any "input" into this pseudo op is fed into an identity, and
  attributes are added to that input before being used by the constituent ops
  that make up the pseudo op. A similar process is done to any output that
  is to be exported from the current op.

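  Example (an illustrative sketch that mirrors the module-level example;
  `x` is assumed to be a float tf.Tensor):

    custom = tf.lite.OpHint("cool_activation")
    x, = custom.add_inputs(x)
    y = tf.sigmoid(x) * x
    y, = custom.add_outputs(y)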
  """
  # TODO(aselle): When TensorFlow functions functionality works for arbitrary
  # constructs, this mechanism can be retired and changed to use python defun's.

  # Attr constants that are used for representation in the GraphDef. These
  # will be used on every Identity op that is involved in a total OpHint.

  # Name of the OpHint function (cosmetic).
  FUNCTION_NAME_ATTR = "_tflite_function_name"
  # UUID of the function (each OpHint gets a new uuid).
  FUNCTION_UUID_ATTR = "_tflite_function_uuid"
  # The input index of the input (or nothing if it is an output).
  FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
  # The output index of the output (or nothing if it is an input).
  FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
  # An index that orders aggregate arguments. Aggregate arguments are ones
  # that are separate but will be fused horizontally. For example a static LSTM
  # has an lstm cell for each time step. Each one has a separate opHint, but a
  # fused SequentialLSTM will treat this as a single tensor.
  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
  # The way in which multiple parts of the aggregate argument will be joined
  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
  # On fused OpHint stub, the order of inputs that the final LSTM call will
  # have. What this means is that the TensorFlow order might be
  # "foo", "bar", "stuff" and you might want the TF lite op order to be
  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
  # attribute to [2, 0, 1, -1].
  TFLITE_INPUT_INDICES = "_tflite_input_indices"
  # OpHint level.
  FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
  # Ophint internal mapping, this is for high level Ophint only.
  # This basically contains three kinds of mapping:
  #   1) How parental ophinted inputs map to the first child ophinted inputs;
  #   2) How internal children nodes are connected;
  #   3) How parental ophinted outputs map to the last child ophinted outputs.
  CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"

  # Types of aggregations
  #  stack: stacks all ophints with matching tags. i.e. for a static rnn.
  #   specifically, this is good for an input or output to a static rnn cell.
  AGGREGATE_STACK = "stack"
  # first: only takes the first output (one with lowest sort index)
  # of matching tags. This is good for the input state to an RNN.
  AGGREGATE_FIRST = "first"
  # aggregation last takes only the last tag (one with highest sort index).
  # This is good for an output value on the last stack item of a
  # static rnn.
  AGGREGATE_LAST = "last"

  class OpHintArgumentTracker(object):
    """Conceptually tracks indices of arguments of "OpHint functions".

    The inputs and outputs of these functions each use an instance of this
    class so that they can have independent numbering.
    """

    def __init__(self,
                 function_name,
                 unique_function_id,
                 node_name_prefix,
                 attr_name,
                 level=1,
                 children_inputs_mappings=None):
      """Initialize ophint argument.

      Args:
        function_name: Name of the function that this tracks arguments for.
        unique_function_id: UUID of function that this tracks arguments for.
        node_name_prefix: How identities that are created are named.
        attr_name: Name of attribute to use to store the index for this hint,
          i.e. FUNCTION_INPUT_INDEX_ATTR or FUNCTION_OUTPUT_INDEX_ATTR.
        level: Hierarchical level of the Ophint node, a number.
        children_inputs_mappings: Inputs/Outputs mapping for children hints.
      """

      # The global index is the argument index of the op. This is in contrast
      # to the sort index which is the sequence number of a particular instance
      # of a given global index. For example, you may have called add hint
      # twice with the tag "foo". Then the global index will be 0 for both
      # and the sort index will be 0 for the first added and 1 for the second.
      self._function_name = function_name
      self._unique_function_id = unique_function_id
      self._next_global_index = 0  # The absolute global index
      self._used_global_indices = set()
      self._tag_to_global_index = {}  # The argument index a given tag maps to
      self._tag_to_next_sort_index = {}  # The current index for each tag
      self._node_name_prefix = node_name_prefix
      self._attr_name = attr_name
      self._level = level
      self._children_inputs_mappings = children_inputs_mappings

    def _get_new_global_index(self, index_override):
      """Return the next unused argument index in order or use an override.

      Args:
        index_override: An index to use instead of the next available or None
          to use the next available.

      Returns:
        A valid global_index to use for the next hint argument.

      Raises:
        ValueError: If the index_override is already used by another hint.
      """
      if index_override is None:
        global_index = self._next_global_index
      else:
        if index_override in self._used_global_indices:
          raise ValueError("Index %d was already used by another call to add"
                           % index_override)
        global_index = index_override
      # Make next_global_index valid
      self._used_global_indices.add(global_index)
      while self._next_global_index in self._used_global_indices:
        self._next_global_index += 1
      return global_index

    def add(self, arg, tag=None, name=None, aggregate=None,
            index_override=None):
      """Return a wrapped input tensor to use as an argument of the hint.

      Args:
        arg: A TensorFlow tensor that should be considered an argument.
        tag: String tag to identify arguments that should be packed.
        name: Name of argument. This is included in the Identity hint op names.
        aggregate: Strategy to aggregate. Acceptable values are
          OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, and
          OpHint.AGGREGATE_STACK. Note, aggregate is only valid if tag is
          specified.
        index_override: Specify what input/output index this argument should
          have in the final stub. i.e. add(arg0, index_override=1);
          add(arg1, index_override=0) will make the final stub be
          stub_func(inputs=[arg1, arg0], outputs=[]) rather than the default
          call-order based ordering.

      Returns:
        A tensor representing the wrapped argument.

      Raises:
        ValueError: When indices are not consistent.
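
      Example (an illustrative sketch; this method is normally reached
      through OpHint.add_input / OpHint.add_output, e.g. for a static RNN
      where each timestep's tensor shares a tag so the converter can stack
      them into one fused input):

        cell_input = custom.add_input(
            cell_input, tag="inputs", aggregate=OpHint.AGGREGATE_STACK)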
      """

      # Find the appropriate index
      if tag is None:
        if aggregate is not None:
          raise ValueError("You must specify `tag` if using aggregate.")
        global_index = self._get_new_global_index(index_override)
        sort_index = None
      else:
        if aggregate is None:
          raise ValueError("You must specify `aggregate` if using tag.")
        if tag not in self._tag_to_global_index:
          self._tag_to_global_index[tag] = (
              self._get_new_global_index(index_override))
          self._tag_to_next_sort_index[tag] = 0
        elif (index_override and
              index_override != self._tag_to_global_index[tag]):
          raise ValueError(
              "Tag %r was called with two indices %r and %r" %
              (tag, index_override, self._tag_to_global_index[tag]))
        global_index = self._tag_to_global_index[tag]
        sort_index = self._tag_to_next_sort_index[tag]
        self._tag_to_next_sort_index[tag] += 1

      uuid = self._unique_function_id
      name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
                                    uuid, global_index, sort_index, name)

      identity_op = _array_ops.identity(arg, name=name)

      # pylint: disable=protected-access
      identity_op.op._set_attr(
          OpHint.FUNCTION_NAME_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._function_name)))
      identity_op.op._set_attr(
          OpHint.FUNCTION_UUID_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._unique_function_id)))
      identity_op.op._set_attr(
          self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
      identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
                               _attr_value_pb2.AttrValue(i=self._level))
      if self._children_inputs_mappings:
        identity_op.op._set_attr(
            OpHint.CHILDREN_INPUTS_MAPPINGS,
            _attr_value_pb2.AttrValue(
                s=_compat.as_bytes(_json.dumps(
                    self._children_inputs_mappings))))

      if sort_index is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_SORT_INDEX_ATTR,
            _attr_value_pb2.AttrValue(i=sort_index))
      if aggregate is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_AGGREGATE_ATTR,
            _attr_value_pb2.AttrValue(s=_compat.as_bytes(aggregate)))
      # pylint: enable=protected-access
      return identity_op

  def __init__(self,
               function_name,
               level=1,
               children_inputs_mappings=None,
               **kwargs):
    """Create an OpHint.

    Args:
      function_name: Name of the function (the custom op name in tflite)
      level: OpHint level.
      children_inputs_mappings: Children OpHint inputs/outputs mapping.
        children_inputs_mappings should look like the following:
        "parent_first_child_input":
            [{"parent_ophint_input_index": num,
              "first_child_ophint_input_index": num}, ...]
        "parent_last_child_output":
            [{"parent_output_index": num, "child_output_index": num}, ...]
        "internal_children_input_output":
            [{"child_input_index": num, "child_output_index": num}, ...]
      **kwargs: Keyword arguments of any constant attributes for the function.
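
    Example (an illustrative sketch of a mapping for a two-child hint, using
    the keys checked by _validate_children_inputs_mappings; the indices are
    made up):

      children_inputs_mappings = {
          "parent_first_child_input": [
              {"parent_ophint_input_index": 0,
               "first_child_ophint_input_index": 0}],
          "internal_children_input_output": [
              {"child_input_index": 0, "child_output_index": 0}],
          "parent_last_child_output": [
              {"parent_output_index": 0, "child_output_index": 0}],
      }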
    """
    self._function_name = function_name
    self._level = level
    if self._level == 1:
      assert children_inputs_mappings is None
    else:
      assert isinstance(children_inputs_mappings, dict)
    self._children_inputs_mappings = children_inputs_mappings
    if self._children_inputs_mappings is not None:
      self._validate_children_inputs_mappings(self._children_inputs_mappings)
    self._unique_function_id = _uuid.uuid1().hex  # TODO(aselle): Unique enough?
    self._attrs_to_store_later = kwargs
    self._stored_attrs = False
    self._inputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "InputHint",
        OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
    self._outputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "OutputHint",
        OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
        self._children_inputs_mappings)

  def _validate_children_inputs_mappings(self, children_inputs_mappings):
    """Validate that children_inputs_mappings is in the right format.

    Args:
      children_inputs_mappings: The children ophint inputs/outputs mapping.
    """
    assert isinstance(children_inputs_mappings, dict)
    assert "parent_first_child_input" in children_inputs_mappings
    assert "parent_last_child_output" in children_inputs_mappings
    assert "internal_children_input_output" in children_inputs_mappings

    # Validate that each mapping entry carries the expected keys.

    def assert_dictlist_has_keys(dictlist, keys):
      for dikt in dictlist:
        assert isinstance(dikt, dict)
        for key in keys:
          assert key in dikt

    assert_dictlist_has_keys(
        children_inputs_mappings["parent_first_child_input"],
        ["parent_ophint_input_index", "first_child_ophint_input_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["parent_last_child_output"],
        ["parent_output_index", "child_output_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["internal_children_input_output"],
        ["child_input_index", "child_output_index"])

  def _setattr(self, dest_op, name, value):
    tensor_value = _ops.convert_to_tensor(value)
    # pylint: disable=protected-access
    dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
        tensor=tensor_value.op.node_def.attr["value"].tensor))
    # pylint: enable=protected-access

  def add_input(self, *args, **kwargs):
    """Add a wrapped input argument to the hint.

    Args:
      *args: The input tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple copies
          of state or inputs.
        "aggregate" aggregation strategy that is valid only when tag is not
          None. Acceptable values are OpHint.AGGREGATE_FIRST,
          OpHint.AGGREGATE_LAST, and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped input tensor.
    """
    return self._inputs.add(*args, **kwargs)

  def add_output(self, *args, **kwargs):
    """Add a wrapped output argument to the hint.

    Args:
      *args: The output tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple outputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple copies
          of state or inputs.
        "aggregate" aggregation strategy that is valid only when tag is not
          None. Acceptable values are OpHint.AGGREGATE_FIRST,
          OpHint.AGGREGATE_LAST, and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped output tensor.
    """
    return self._outputs.add(*args, **kwargs)

  def add_inputs(self, *args, **kwargs):
    """Add a sequence of inputs to the function invocation.

    Args:
      *args: List of inputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped inputs (identity standins that have additional metadata). These
      are also tf.Tensor's.
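
    Example (an illustrative sketch; `x` and `y` are assumed to be
    tf.Tensor inputs):

      x, y = custom.add_inputs(x, y, names=["x", "y"])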
    """
    if "names" in kwargs:
      return [
          self._inputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._inputs.add(arg) for arg in args]

  def add_outputs(self, *args, **kwargs):
    """Add a sequence of outputs to the function invocation.

    Args:
      *args: List of outputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped outputs (identity standins that have additional metadata). These
      are also tf.Tensor's.
    """
    if "names" in kwargs:
      return [
          self._outputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._outputs.add(arg) for arg in args]


class _LiteOperand(object):
  """Abstract operand for a tflite hint function.

  This is a base class that handles representing arguments to an OpHint.
  It also is able to serialize operands to the stubbed graph_def.
  Child classes are responsible for storing information about the hint
  identity operators and for knowing how to serialize to output graphdefs.

  Typically this will be implemented by holding one or more identity nodes
  that were previously discovered as hints.
  """

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the node(s) to out_graphdef and returns the input node name.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The output that the stub should use as an input for this operand.

    Raises:
      RuntimeError: if the method is not implemented.
    """
    del out_graphdef
    raise RuntimeError("Unimplemented abstract method.")

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """Add node(s) to graph representing output operands and returns type.

    Args:
      fused_op_name: Name of the fused stub op.
      output_index: Output index that we are currently processing from stub.
      out_graphdef: The destination graphdef we are currently building up.

    Returns:
      The datatype of this identity.

    Raises:
      RuntimeError: if the method is not implemented.
    """
    del fused_op_name, output_index, out_graphdef
    raise RuntimeError("Unimplemented abstract method.")


class _LiteSingleOperand(_LiteOperand):
  """A simple operand that is non-aggregated (i.e. most hints)."""

  def __init__(self, node):
    _LiteOperand.__init__(self)
    self.node = node
    self.name = _tensor_name_base(node.name)

  def flatten(self):
    return [self.name]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    return self.name

  def aggregate_and_return_name_for_output(self, fused_op_name, index,
                                           out_graphdef):
    output_node = _copy.deepcopy(self.node)
    del output_node.input[:]
    output_node.input.append(_tensorflow_output_name(fused_op_name, index))
    out_graphdef.node.extend([output_node])
    return self.node.attr["type"].i

  def __str__(self):
    return str(self.name)


class _LiteAggregateOperand(_LiteOperand):
  """An operand for a tflite hint function that is aggregated from many.

  For example, an LSTM is a grid of operators that are all related. Inputs
  going into them may need to be fused, so they should all be tracked as
  related arguments.
  """

  def __init__(self, aggregation):
    _LiteOperand.__init__(self)
    self.aggregation = aggregation
    self.names = {}
    self.nodes = {}
    self.flattened = None

  def add(self, sort, node):
    self.names[sort] = _tensor_name_base(node.name)
    self.nodes[sort] = node

  def flatten_nodes(self):
    """Return a list of all the node protos in aggregation sorted order."""
    if not self.flattened:
      self.flattened = [None] * len(self.nodes)
      for idx, node in _six.iteritems(self.nodes):
        self.flattened[idx] = node
      for n in self.flattened:
        if n is None:
          raise RuntimeError("Aggregate was missing argument.")
      if self.aggregation == OpHint.AGGREGATE_FIRST:
        self.flattened = self.flattened[:1]
      elif self.aggregation == OpHint.AGGREGATE_LAST:
        self.flattened = self.flattened[-1:]
      elif self.aggregation == OpHint.AGGREGATE_STACK:
        pass
      else:
        raise ValueError("Invalid aggregation type %r specified" %
                         self.aggregation)
    return self.flattened

  def flatten(self):
    """Return a list of all node names in aggregation sorted order."""
    return [_tensor_name_base(x.name) for x in self.flatten_nodes()]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the nodes to out_graphdef and returns an aggregated output.

    In particular, if you have 4 inputs to a hint stub, this will be the
    node that you can use as an output. I.e. if you have 4 timesteps from a
    static rnn, then a fused UnidirectionalLSTM will expect 1 input with
    all 4 time steps. So here we make a pack and return the output name of
    that pack.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The name of a pack that aggregates this node.
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      return _tensor_name_base(flattened[0].name)
    else:
      new_node = _node_def_pb2.NodeDef()
      new_node.op = "Pack"
      new_node.name = "OpHintStack-%s" % flattened[0].name
      new_node.attr["N"].i = len(flattened)
      new_node.attr["T"].type = flattened[0].attr["T"].type
      for discrete in flattened:
        new_node.input.append(_tensor_name_base(discrete.name))
      out_graphdef.node.extend([new_node])
      return new_node.name

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """This adds to `out_graphdef` all the unaggregated outputs.

    I.e. we are outputting from a fused stub, but we need to make it compatible
    with the unfused original graph so we insert an unpack. Ideally in a later
    stage the unpack -> pack sequences will be removed.

    Args:
      fused_op_name: The name of the stub we are in the process of fusing.
      output_index: The output index this object represents.
      out_graphdef: The graphdef we are in the process of building.

    Returns:
      The type of the aggregated output (so we can finish building the stub
      op).
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      temp_op = _LiteSingleOperand(flattened[0])
      return temp_op.aggregate_and_return_name_for_output(
          fused_op_name, output_index, out_graphdef)
    else:
      stack_node = _node_def_pb2.NodeDef()
      stack_node.op = "Unpack"
      stack_node.name = "OpHintUnstack-%s" % flattened[0].name
      stack_node.attr["num"].i = len(flattened)
      output_type = flattened[0].attr["T"].type
      stack_node.attr["T"].type = output_type
      stack_node.input.append(
          _tensorflow_output_name(fused_op_name, output_index))
      out_graphdef.node.extend([stack_node])

      for idx, discrete in enumerate(flattened):
        output_node = _copy.deepcopy(discrete)
        del output_node.input[:]
        output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
        out_graphdef.node.extend([output_node])

      return output_type

  def __str__(self):
    s = "\t\t\tAGGREGATE %s\n" % self.aggregation
    for sort, val in _six.iteritems(self.names):
      s += "\t\t\t%d: %s\n" % (sort, val)
    return s


class _LiteFuncCall(object):
  """Represent a TensorFlow Lite custom function.

  This is used to accumulate found hints in the graphdef into a single
  conceptual unit.

  Attributes:
    inputs: inputs to the op (hash from index # to argument)
    outputs: outputs to the op (hash from index # to argument)
    function_name: the tflite custom op name to use
    uuid: a unique call id for this particular call (i.e. multiple function
      calls would have the same function_name but different uuids).
    params: A param name to key value for op constant data. I.e. for axis on a
      reduction, strides on a convolution, etc.
    level: Level of the OpHint.
    children_inputs_mappings: If the Ophint has children, children inputs
      mappings indicate how their inputs & outputs are mapped.
  """

  def __init__(self):
    self.inputs = {}
    self.outputs = {}
    self.function_name = None
    self.uuid = None
    self.params = {}
    self.level = -1
    self.children_inputs_mappings = {}

  def flattened_inputs_and_outputs(self):
    """Return a list of inputs and outputs in a flattened format.

    Returns:
      Tuple of (inputs, outputs), where each is a list of flattened argument
      names.
    """

    def _flatten(input_or_output_dict):
      flattened_items = []
      for item in input_or_output_dict.values():
        flattened_items.extend(item.flatten())
      return flattened_items

    return _flatten(self.inputs), _flatten(self.outputs)

  def __str__(self):

    def format_args(items):
      s = ""
      for idx, item in _six.iteritems(items):
        s += ("\t\t%d:\n" % idx) + str(item)
      return s

    inputs_str = "\tInputs\n" + format_args(self.inputs)
    outputs_str = "\tOutputs\n" + format_args(self.outputs)

    return (
        "tflite function %s call %s level %d "
        "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
        (self.function_name, self.uuid, self.level, inputs_str, outputs_str))


def _find_all_hints_in_nodes(nodes):
  """Look at all the input nodes and return a dict of LiteFuncCall objs.

  Args:
    nodes: A list of TensorFlow graph_def nodes to look for LiteFuncCalls in.

  Returns:
    A dictionary mapping each hint uuid to a `_LiteFuncCall` object describing
    that hint invocation.
  """
  func_calls = _collections.defaultdict(_LiteFuncCall)

  for node in nodes:
    attr = node.attr
    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
    if (OpHint.FUNCTION_UUID_ATTR not in attr or
        not attr[OpHint.FUNCTION_UUID_ATTR].s):
      continue
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s

    # Start building function
    call_def = func_calls[uuid]
    call_def.uuid = uuid
    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
    call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
    # Get sorting and aggregation information

    sort = (
        attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
        if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1:
      sort = None
    aggregation = None
    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)

    if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
      call_def.children_inputs_mappings = _json.loads(
          _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))

    # Add the input or output
    def put_operand(stuff, index, sort, operand, aggregation):
      """Add a given index into the function structure."""
      if sort is None:
        stuff[index] = _LiteSingleOperand(operand)
      else:
        if index not in stuff:
          stuff[index] = _LiteAggregateOperand(aggregation)
        stuff[index].add(sort, operand)

    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
                  sort, node, aggregation)
    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
                  sort, node, aggregation)

    # Remember attributes
    for a in attr:
      if a.startswith("_tflite_attr_"):
        call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor

  return func_calls


def _extract_topology_sequence_mapping(nodes):
  return dict(
      (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes))


def _find_children_hints_in_while_loop(function_def, nodes_mapping):
  """Find children hints and all nodes inside the while loop.

  Args:
    function_def: Function def of the while loop.
    nodes_mapping: Mapping from while loop input_arg names to real node names.

  Returns:
    Ordered children hints and all re-mapped nodes inside the while loop.
  """
  new_nodes = []

  # Make nodes inside function def inputs point to the real nodes.
  for node in function_def.node_def:
    for i, _ in enumerate(node.input):
      if node.input[i] in nodes_mapping:
        node.input[i] = nodes_mapping[node.input[i]]
    new_nodes.append(_copy.deepcopy(node))
  name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
  children_hints = _find_all_hints_in_nodes(new_nodes)
  children_hints_q = []
  # Ordered by the outputs.
  for hint in _six.itervalues(children_hints):
    _, output_names = hint.flattened_inputs_and_outputs()
    seq = name_to_seq_num[output_names[0]]
    for output_name in output_names:
      seq = min(seq, name_to_seq_num[output_name])
    children_hints_q.append((seq, hint))
  children_hints_q.sort(key=lambda tup: tup[0])
  ordered_children_hints = [x[1] for x in children_hints_q]
  return ordered_children_hints, new_nodes


def _find_children_hints(call, graph_def):
  """Find all children hints.

  For a given OpHint, we find all children hints inside it; we also copy all
  the nodes inside function defs (if applicable) to the original graph_def,
  and they are returned in a list as well.

  Args:
    call: Parent OpHint that contains children ophints.
    graph_def: Original graph def.

  Returns:
    Ordered children hints inside the parent ophint; new graph def that contains
    nodes inside function defs (if applicable); nodes inside function defs.
  """
  name_to_input_name, _, _ = _extract_graph_summary(graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  children_hints = []
  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  function_def_nodes = set()
  for node in graph_def.node:
    out.node.extend([_copy.deepcopy(node)])
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # Special handling for while loop function defs.
        if node.op == "While" or node.op == "StatelessWhile":
          body_name = node.attr["body"].func.name
          inputs_outside_loop = node.input
          for function_def in graph_def.library.function:
            if function_def.signature.name == body_name:
              function_inputs = function_def.signature.input_arg
              assert len(inputs_outside_loop) == len(function_inputs)
              nodes_mapping = {}
              for i, function_input in enumerate(function_inputs):
                nodes_mapping[function_input.name] = inputs_outside_loop[i]
              # TODO(b/123050804): Consider use grappler.
              (children_hints_in_loop,
               new_nodes) = _find_children_hints_in_while_loop(
                   function_def, nodes_mapping)
              function_def_nodes.update([x.name for x in new_nodes])
              children_hints.extend(children_hints_in_loop)
              out.node.extend(new_nodes)

  return children_hints, out, function_def_nodes


def _tensor_name_base(full_tensor_name):
  """Removes the output port from a tensor name.

  e.g. _tensor_name_base("foo:3") => "foo"

  Args:
    full_tensor_name: A tensor name that is annotated with an output port
      (this is what TensorFlow introspection gives).

  Returns:
    A name without any output port or control-input marker.
  """
  if full_tensor_name.startswith("^"):
    return full_tensor_name[1:]
  return full_tensor_name.split(":")[0]


def _tensorflow_output_name(tensor_name, output_index):
  return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
                                                          output_index)


# TODO(aselle): This should be converted to grappler in the future.
def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                           name_to_input_name):
  """Checks to make sure node only connects to predecessor graph through inputs.

  Args:
    n: Node to check
    reachable_by_input: Nodes that are reachable by all inputs of subgraph
    input_nodes_set: The set of nodes that are "inputs".
    name_to_input_name: Maps from name to the list of inputs.

  Raises:
    TypeError: If the given node uses items past inputs directly.
  """
  next_to_visit = [n]
  visited = set()
  while next_to_visit:
    current_node = next_to_visit.pop()
    visited.add(current_node)
    if (current_node in reachable_by_input and
        current_node not in input_nodes_set):
      raise TypeError("Node %s uses input %s not in input_nodes." %
                      (n, current_node))
    if current_node not in input_nodes_set:
      next_to_visit += [
          input_node for input_node in name_to_input_name[current_node]
          if input_node not in visited
      ]


# TODO(aselle): This should be converted to grappler in the future.
def _convert_single_op_hint_to_stub(call,
                                    graph_def,
                                    function_def_nodes=None,
                                    is_last_run=True):
  """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (that has call obviously).
    function_def_nodes: Nodes inside the function def that are not connected to
      the graph.
    is_last_run: Whether it is the last run for a given pass (for OpHints that
      have children).

  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
      the tensorflow runtime, so all future manipulations are done at the
      graph_def level.
  """
  if function_def_nodes is None:
    function_def_nodes = set()
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  nodes_after_fuse = []
  nodes_deleted_by_fuse = set()
  # Classify each node. We keep everything reachable by the inputs; nodes that
  # are reachable by the outputs but not by the inputs (and are not outputs
  # themselves) are consumed by the fuse; anything else comes after the fuse.
  for node in graph_def.node:
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        nodes_deleted_by_fuse.add(n)
    elif n not in reachable_by_input and n not in function_def_nodes:
      # n is a node that comes after all the fusings, so keep it.
      nodes_after_fuse.append(n)
    else:
      # In the last run, n is a node in the graph that is not connected to the
      # chain of dependencies; we delete it. Otherwise we keep it.
      if not is_last_run:
        nodes_after_fuse.append(n)

  # Make a new graphdef with all the pre-input and input nodes
  out = _graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([_copy.deepcopy(name_to_node[node])])

  # Create any stacks to aggregate arguments into a single input
  # i.e. for static_rnn's.
  # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
  sorted_input_indices = list(call.inputs.keys())
  sorted_input_indices.sort()
  sorted_output_indices = list(call.outputs.keys())
  sorted_output_indices.sort()
  new_node = _node_def_pb2.NodeDef()
  # Delegate to each operand to produce the proper new input for this stub node.
  # In particular, an aggregate input will now be a Pack of some previously
  # non-fused things.

  optional_input_node = _node_def_pb2.NodeDef()
  optional_input_node.name = "Const" + str(_uuid.uuid1().hex)
  optional_input_node.op = "Const"
  optional_input_node.attr["dtype"].CopyFrom(
      _attr_value_pb2.AttrValue(type=_dtypes.float32.as_datatype_enum))
  optional_input_node.attr["value"].CopyFrom(
      _attr_value_pb2.AttrValue(
          tensor=_tensor_util.make_tensor_proto([-1], _dtypes.float32, [1])))
  out.node.extend([optional_input_node])

  max_index = max(sorted_input_indices) + 1
  for cur_index in range(max_index):
    if cur_index in sorted_input_indices:
      inputs = call.inputs[cur_index]
      input_name = inputs.aggregate_and_return_name_for_input(out)
      new_node.input.append(input_name)
    else:
      new_node.input.append(optional_input_node.name)

  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

  # Create the function
  new_node.op = call.function_name
  new_node.name = call.uuid
  out.node.extend([new_node])

  # Now call each output argument to give them a chance to make the proper
  # output type and add it to our new_node.
  output_dtypes = []
  max_output_index = max(sorted_output_indices) + 1
  for cur_index in range(max_output_index):
    if cur_index in sorted_output_indices:
      output = call.outputs[cur_index]
      output_dtype = (
          output.aggregate_and_return_name_for_output(new_node.name, cur_index,
                                                      out))
    else:
      output_dtype = optional_input_node.attr["type"].i
    output_dtypes.append(output_dtype)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  # TODO(aselle): what is right here?
  new_node.attr["_output_quantized"].b = False

  # Add post output nodes that do not depend on the outputs
  for n in nodes_after_fuse:
    should_keep = True
    for input_name in name_to_input_name[n]:
      if input_name in nodes_deleted_by_fuse:
        should_keep = False
    if should_keep:
      out.node.extend([_copy.deepcopy(name_to_node[n])])

  # Misc. graph_def data that needs copying.
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out


# TODO(aselle): This should be converted to grappler in the future.
def _remove_one_redundant_stack_unstack(in_graph_def):
  """Removes one stack->unstack pattern from in_graph_def and returns a graph.

  Args:
    in_graph_def: Graph def to use as input.

  Returns:
    Simplified tuple (graph_def, changed_something) where changed_something
    is true if anything was done.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      in_graph_def)
  del name_to_seq_num

  # TODO(aselle): Make this not hardcoded.
  do_generic_pack_unpack = True

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(in_graph_def.library)
  out.versions.CopyFrom(in_graph_def.versions)
  for n in in_graph_def.node:
    node_name = _tensor_name_base(n.name)
    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
      continue
    next_to_visit = [node_name]
    visited = set()

    unpack_nodes = set()
    pack_node = node_name

    # Find a pattern of unstack connected to a stack (with identities
    # in between).
    matches_pattern = True
    is_hint_created_stack = False
    while next_to_visit:
      current_node_name = next_to_visit[0]
      visited.add(current_node_name)
      del next_to_visit[0]
      node = name_to_node[current_node_name]
      is_op_hint_stack = node.name.startswith("OpHintStack")
      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
      if (node.op == "Identity" or is_op_hint_stack or
          (do_generic_pack_unpack and node.op == "Pack")):
        is_hint_created_stack |= is_op_hint_stack
        next_to_visit += [
            input_node for input_node in name_to_input_name[current_node_name]
            if input_node not in visited
        ]
      elif (is_op_hint_unstack or
            (do_generic_pack_unpack and node.op == "Unpack")):
        unpack_nodes.add(node.name)
        is_hint_created_stack &= is_op_hint_unstack
      else:
        matches_pattern = False
        break
      visited.add(node.name)

    if matches_pattern and len(unpack_nodes) == 1:
      pack_node = node_name

      # Check to see if anyone depends on the intermediate identity or the
      # unstacked form.
      no_external_dependency = True
      for other_n in in_graph_def.node:
        if other_n.name in visited:
          continue
        for input_tensor in name_to_input_name[other_n.name]:
          input_op = _tensor_name_base(input_tensor)
          if input_op in visited and input_op != pack_node:
            no_external_dependency = False
      # Proceed with the substitution if the stack/unstack pair was created
      # through hints, or if it was not but nobody is consuming anything
      # between the stack and unstack.
      if is_hint_created_stack or no_external_dependency:
        end = unpack_nodes.pop()
        end_input = name_to_node[end].input[0]
        # All nodes that depend on the final stack need to be rewired to use
        # the stack's input instead.
        for other_n in in_graph_def.node:
          node_name = _tensor_name_base(other_n.name)
          if node_name not in visited:
            new_node = _copy.deepcopy(other_n)
            new_node.input[:] = [
                (end_input if stripped == pack_node else non_stripped)
                for stripped, non_stripped in zip(name_to_input_name[node_name],
                                                  new_node.input[:])
            ]
            out.node.extend([new_node])
        return out, True
  return in_graph_def, False


def _remove_redundant_stack_unstack(graph_def):
  curr = graph_def
  del graph_def
  changed_stuff = True
  while changed_stuff:
    curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
  return curr


def _get_correct_mapping(original_index, nodes):
  # Special handling for the -1 index: if it is -1, return the last index.
  if original_index == -1:
    node_indices = nodes.keys()
    node_indices = sorted(node_indices)
    return node_indices[-1]
  return original_index


def _convert_op_hints_to_stubs_helper(
    graph_def, write_callback=lambda sess, graph_def: None):
  """Converts a graph_def to a new graph_def where all op hints are stubbed.

  Args:
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new stubbed graph_def.
  """
  hints = _find_all_hints_in_nodes(graph_def.node)

  hints_q = []
  for hint in _six.itervalues(hints):
    hints_q.append((hint.level, hint.uuid))

  hints_q.sort(key=lambda tup: tup[0])
  for i in range(len(hints_q) - 1, -1, -1):
    level, hint_uuid = hints_q[i]

  curr_graph_def = graph_def
  del graph_def  # prevent using graph_def again (common source of error)
  for i in range(len(hints_q) - 1, -1, -1):
    level, hint_uuid = hints_q[i]
    if level >= 2:
      children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
          hints[hint_uuid], curr_graph_def)
      # pylint: disable=superfluous-parens
      assert (len(children_hints) > 0)  #  pylint: disable=g-explicit-length-test
      # pylint: enable=superfluous-parens

      # Re-wire the children hints inputs/outputs, so that a later child's
      # inputs connect to the previous child node's outputs.
      children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
      for j, child_hint in enumerate(children_hints):
        if j == 0:
          for mapping in children_inputs_mappings["parent_first_child_input"]:
            parent_input_index = _get_correct_mapping(
                mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
            child_input_index = _get_correct_mapping(
                mapping["first_child_ophint_input_index"], child_hint.inputs)
            child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
                parent_input_index]
        else:
          for mapping in children_inputs_mappings[
              "internal_children_input_output"]:
            input_index = _get_correct_mapping(mapping["child_input_index"],
                                               child_hint.inputs)
            output_index = _get_correct_mapping(mapping["child_output_index"],
                                                children_hints[j - 1].outputs)
            child_hint.inputs[input_index] = children_hints[
                j - 1].outputs[output_index]
        if j == len(children_hints) - 1:
          for mapping in children_inputs_mappings["parent_last_child_output"]:
            parent_output_index = _get_correct_mapping(
                mapping["parent_output_index"], hints[hint_uuid].outputs)
            child_output_index = _get_correct_mapping(
                mapping["child_output_index"], child_hint.outputs)
            child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
                parent_output_index]

      for j, child_hint in enumerate(children_hints):
        curr_graph_def = _convert_single_op_hint_to_stub(
            child_hint, curr_graph_def, function_def_nodes,
            j == len(children_hints) - 1)
    else:
      curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
                                                       curr_graph_def)
      write_callback(curr_graph_def, "initial")
  # The stubbing process can create stacks/unstacks in the case of LSTMs;
  # remove them.
  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
  return curr_graph_def


def find_all_hinted_output_nodes(session=None, graph_def=None):
  """Find all OpHint output nodes in the graph.

  This is used to get all the output nodes that are ophinted; it is important
  for operations like convert_variables_to_constants to keep all the ophint
  structure. Note: only one of session or graph_def should be used, not both.
  Why can this be useful? Some TensorFlow ops (e.g. bidirectional rnn) can
  generate multiple outputs for an unfused subgraph. If not all output nodes
  are consumed, graph optimization can potentially drop the unused nodes and
  leave the ophints in an invalid state (due to missing ophinted output
  nodes). So it's important for us to find all those hinted output nodes and
  make sure they're not discarded.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.

  Returns:
    A list of OpHint output node names.
  Raises:
    ValueError: If both session and graph_def are provided.
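
  Example (an illustrative sketch of a freezing workflow; `output_names` is
  an assumed list of the model's own output node names):

    hinted_output_nodes = find_all_hinted_output_nodes(session=session)
    frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        session, session.graph_def, output_names + hinted_output_nodes)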
  """
  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")
  hinted_outputs_nodes = []
  if session is not None:
    hints = _find_all_hints_in_nodes(session.graph_def.node)
  elif graph_def is not None:
    hints = _find_all_hints_in_nodes(graph_def.node)
  for hint in _six.itervalues(hints):
    _, output_nodes = hint.flattened_inputs_and_outputs()
    hinted_outputs_nodes.extend(output_nodes)
  return hinted_outputs_nodes


def is_ophint_converted(graph_def):
  if graph_def is None:
    raise ValueError("Must provide the graph_def.")
  ophint_converted = False
  for node in graph_def.node:
    attr = node.attr
    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      ophint_converted = True
      break
  return ophint_converted


@_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"])
@_deprecation.deprecated(
    None,
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
def convert_op_hints_to_stubs(session=None,
                              graph_def=None,
                              write_callback=lambda graph_def, comments: None):
  """Converts a graphdef with LiteOp hints into stub operations.

  This is used to prepare for toco conversion of complex intrinsic usages.
  Note: only one of session or graph_def should be used, not both.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new graphdef with all ops contained in OpHints being replaced by
    a single op call with the right parameters.
  Raises:
    ValueError: If both session and graph_def are provided.
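
  Example (an illustrative sketch that mirrors the module-level example;
  `image` and `output` are the model's input and output tensors):

    stubbed_graph_def = convert_op_hints_to_stubs(session=session)
    tflite_model = tf.compat.v1.lite.toco_convert(
        stubbed_graph_def, [image], [output], allow_custom_ops=True)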
  """

  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")

  if session is not None:
    return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
  elif graph_def is not None:
    return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
  else:
    raise ValueError("Must specify session or graph_def as input.")


_allowed_symbols = [
    "OpHint",
    "convert_op_hints_to_stubs",
    "convert_op_hints_to_stubs_new",
    "find_all_hinted_output_nodes",
    "is_ophint_converted",
]
remove_undocumented(__name__, _allowed_symbols)