# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define tflite op hints (intrinsic operations).

This essentially allows defining a TensorFlow API for tflite operations in
Python with hints on how they are represented in TensorFlow Lite. This is
basically a form of tflite intrinsic. It wraps a subpart of a TensorFlow
execution graph and is useful for LSTMs and other complicated TensorFlow
constructions that are difficult to pattern match in TOCO, but are represented
by a single accelerated tflite op.

Example:
  def tflite_cool_activation(input):
    # A cool activation function.
    custom = tf.lite.OpHint("cool_activation")
    input, = custom.add_inputs(input)
    output = tf.sigmoid(input) * input
    output, = custom.add_outputs(output)
    return output

  image = tf.compat.v1.placeholder(tf.float32, (1, 16, 16, 1))
  output = tf.identity(tflite_cool_activation(image))

  session = tf.compat.v1.Session()

  graphdef_to_convert = tf.lite.experimental.convert_op_hints_to_stubs(session)
  tflite_graph = tf.compat.v1.lite.toco_convert(
      graphdef_to_convert, [image], [output], allow_custom_ops=True)
  with open("/tmp/graph.fb", "wb") as fp:
    fp.write(tflite_graph)

How does it work?:

OpHint is a helper that you use when defining a vanilla python function.
It allows you to wrap arguments with tf.identities with some custom attributes.
These attributes allow you to find the original block of ops that was created.
For example, if you use cool_activation above you essentially get:

a_input = tf.identity()
result = tf.multiply(tf.sigmoid(a_input), a_input)
output = tf.identity()

a_input and output are identities that carry attributes recording which
argument they represent, the name of the function they should turn into in
TF Lite, and a guid that uniquely identifies a particular invocation.

Once you have built your whole tensorflow graph, you can run it and train it
as usual, but after you have done that, you need to convert the graph into
a form that replaces these identity-wrapped subgraphs with stub ops. These
ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later. The generated TensorFlow Lite flatbuffer file will
contain a custom operator called "cool_activation". Developers need to
implement and register this operator in TensorFlow Lite in order to run
inference.
"""

# TODO(aselle): Make this use generic graph transformations.
# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
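
# Rough sketch of what the hint machinery records (illustrative only): each
# wrapped argument becomes an Identity op carrying attributes such as
# _tflite_function_name="cool_activation", _tflite_function_uuid=<hex uuid>,
# and _tflite_function_input_index=0. The stubbing pass below pattern-matches
# on exactly these attributes to recover the pseudo-function boundaries.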

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as _collections
import copy as _copy
import json as _json
import uuid as _uuid
import six as _six

from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util
# TODO(aselle): publicize these apis if we continue to use these.
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.util import compat as _compat
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export as _tf_export


@_tf_export(v1=["lite.OpHint"])
@_deprecation.deprecated(
    None,
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
class OpHint(object):
  """A class that helps build tflite function invocations.

  It allows you to take a bunch of TensorFlow ops and annotate the construction
  such that toco knows how to convert it to tflite. This embeds a pseudo
  function in a TensorFlow graph. This allows embedding high-level API usage
  information in a lower level TensorFlow implementation so that an alternative
  implementation can be substituted later.

  Essentially, any "input" into this pseudo op is fed into an identity, and
  attributes are added to that input before being used by the constituent ops
  that make up the pseudo op. A similar process is done to any output that
  is to be exported from the current op.

  """
  # TODO(aselle): When TensorFlow functions functionality works for arbitrary
  # constructs, this mechanism can be retired and changed to use python defun's.

  # Attr constants that are used for representation in the GraphDef. These
  # will be used on every Identity op that is involved in a total OpHint.

  # Name of the OpHint function (cosmetic).
  FUNCTION_NAME_ATTR = "_tflite_function_name"
  # UUID of the function (each OpHint gets a new uuid).
  FUNCTION_UUID_ATTR = "_tflite_function_uuid"
  # The input index of the input (or nothing if it is an output).
  FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
  # The output index of the output (or nothing if it is an input).
  FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
  # An index that orders aggregate arguments. Aggregate arguments are ones
  # that are separate but will be fused horizontally. For example a static LSTM
  # has an LSTM cell for each time step. Each one has a separate OpHint, but a
  # fused SequentialLSTM will treat this as a single tensor.
  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
  # The way in which multiple parts of the aggregate argument will be joined
  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
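  # For example (illustrative), per-timestep tensors of a static RNN that
  # share a tag and use AGGREGATE_STACK become one stacked operand of the
  # fused op, while AGGREGATE_FIRST / AGGREGATE_LAST keep only one of them.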
  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
  # On fused OpHint stub, the order of inputs that the final LSTM call will
  # have. What this means is that the TensorFlow order might be
  # "foo", "bar", "stuff" and you might want the TF lite op order to be
  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
  # attribute to [2, 0, 1, -1].
  TFLITE_INPUT_INDICES = "_tflite_input_indices"
  # OpHint level.
  FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
  # Ophint internal mapping, this is for high level Ophint only.
  # This basically contains three kinds of mapping:
  # 1) How parental ophinted inputs map to the first child ophinted inputs;
  # 2) How internal children nodes are connected;
  # 3) How parental ophinted outputs map to the last child ophinted outputs.
  CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"

  # Types of aggregations
  # stack: stacks all ophints with matching tags, i.e. for a static rnn.
  #   Specifically, this is good for an input or output to a static rnn cell.
  AGGREGATE_STACK = "stack"
  # first: only takes the first output (one with lowest sort index)
  #   of matching tags. This is good for the input state to an RNN.
  AGGREGATE_FIRST = "first"
  # last: takes only the last tag (one with highest sort index).
  #   This is good for an output value on the last stack item of a
  #   static rnn.
  AGGREGATE_LAST = "last"

  class OpHintArgumentTracker(object):
    """Conceptually tracks indices of arguments of "OpHint functions".

    The inputs and outputs of these functions each use an instance
    of the class so they can have independent numbering.
    """

    def __init__(self,
                 function_name,
                 unique_function_id,
                 node_name_prefix,
                 attr_name,
                 level=1,
                 children_inputs_mappings=None):
      """Initialize ophint argument.

      Args:
        function_name: Name of the function that this tracks arguments for.
        unique_function_id: UUID of function that this tracks arguments for.
        node_name_prefix: How identities that are created are named.
        attr_name: Name of attribute to use to store the index for this hint,
          i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX.
        level: Hierarchical level of the Ophint node, a number.
        children_inputs_mappings: Inputs/Outputs mapping for children hints.
      """

      # The global index is the argument index of the op. This is in contrast
      # to the sort index which is the sequence number of a particular instance
      # of a given global index. For example, you may have called add
      # twice with the tag "foo". Then the global index will be 0 for both
      # and the sort index will be 0 for the first added and 1 for the second.
      self._function_name = function_name
      self._unique_function_id = unique_function_id
      self._next_global_index = 0  # The absolute global index
      self._used_global_indices = set()
      self._tag_to_global_index = {}  # The argument index a given tag maps to
      self._tag_to_next_sort_index = {}  # The current index for each tag
      self._node_name_prefix = node_name_prefix
      self._attr_name = attr_name
      self._level = level
      self._children_inputs_mappings = children_inputs_mappings

    def _get_new_global_index(self, index_override):
      """Return the next unused argument index in order or use an override.

      Args:
        index_override: An index to use instead of the next available or None
          to use the next available.

      Returns:
        A valid global_index to use for the next hint argument.

      Raises:
        ValueError: If the index_override is already used by another hint.
      """
      if index_override is None:
        global_index = self._next_global_index
      else:
        if index_override in self._used_global_indices:
          raise ValueError("Index %d was already used by another call to add."
                           % index_override)
        global_index = index_override
      # Make next_global_index valid
      self._used_global_indices.add(global_index)
      while self._next_global_index in self._used_global_indices:
        self._next_global_index += 1
      return global_index

    def add(self, arg, tag=None, name=None, aggregate=None,
            index_override=None):
      """Return a wrapped tensor of an input tensor as an argument.

      Args:
        arg: A TensorFlow tensor that should be considered an argument.
        tag: String tag to identify arguments that should be packed.
        name: Name of argument. This is included in the Identity hint op names.
        aggregate: Strategy to aggregate.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
          Note, aggregate is only valid if tag is specified.
        index_override: Specify what input/output index this should be in the
          final stub. i.e. add(arg0, index_override=1);
          add(arg1, index_override=0) will make the final stub be
          stub_func(inputs=[arg1, arg0], outputs=[]) rather than the default
          call-order-based ordering.

      Returns:
        A tensor representing the wrapped argument.

      Raises:
        ValueError: When indices are not consistent.
      """

      # Find the appropriate index
      if tag is None:
        if aggregate is not None:
          raise ValueError("You must specify `tag` if using aggregate.")
        global_index = self._get_new_global_index(index_override)
        sort_index = None
      else:
        if aggregate is None:
          raise ValueError("You must specify `aggregate` if using tag.")
        if tag not in self._tag_to_global_index:
          self._tag_to_global_index[tag] = (
              self._get_new_global_index(index_override))
          self._tag_to_next_sort_index[tag] = 0
        elif (index_override and
              index_override != self._tag_to_global_index[tag]):
          raise ValueError(
              "Tag %r was called with two indices %r and %r" %
              (tag, index_override, self._tag_to_global_index[tag]))
        global_index = self._tag_to_global_index[tag]
        sort_index = self._tag_to_next_sort_index[tag]
        self._tag_to_next_sort_index[tag] += 1

      uuid = self._unique_function_id
      name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix,
                                    self._function_name, uuid, global_index,
                                    sort_index, name)

      identity_op = _array_ops.identity(arg, name=name)

      # pylint: disable=protected-access
      identity_op.op._set_attr(
          OpHint.FUNCTION_NAME_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._function_name)))
      identity_op.op._set_attr(
          OpHint.FUNCTION_UUID_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._unique_function_id)))
      identity_op.op._set_attr(
          self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
      identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
                               _attr_value_pb2.AttrValue(i=self._level))
      if self._children_inputs_mappings:
        identity_op.op._set_attr(
            OpHint.CHILDREN_INPUTS_MAPPINGS,
            _attr_value_pb2.AttrValue(
                s=_compat.as_bytes(_json.dumps(
                    self._children_inputs_mappings))))

      if sort_index is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_SORT_INDEX_ATTR,
            _attr_value_pb2.AttrValue(i=sort_index))
      if aggregate is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_AGGREGATE_ATTR,
            _attr_value_pb2.AttrValue(s=_compat.as_bytes(aggregate)))
      # pylint: enable=protected-access
      return identity_op

  def __init__(self,
               function_name,
               level=1,
               children_inputs_mappings=None,
               **kwargs):
    """Create an OpHint.

    Args:
      function_name: Name of the function (the custom op name in tflite).
      level: OpHint level.
      children_inputs_mappings: Children OpHint inputs/outputs mapping.
        children_inputs_mappings should look like the following:
        "parent_first_child_input":
            [{"parent_ophint_input_index": num,
              "first_child_ophint_input_index": num}, ...]
        "parent_last_child_output":
            [{"parent_output_index": num, "child_output_index": num}, ...]
        "internal_children_input_output":
            [{"child_input_index": num, "child_output_index": num}, ...]
      **kwargs: Keyword arguments of any constant attributes for the function.
    """
    self._function_name = function_name
    self._level = level
    if self._level == 1:
      assert children_inputs_mappings is None
    else:
      assert isinstance(children_inputs_mappings, dict)
    self._children_inputs_mappings = children_inputs_mappings
    if self._children_inputs_mappings is not None:
      self._validate_children_inputs_mappings(self._children_inputs_mappings)
    self._unique_function_id = _uuid.uuid1().hex  # TODO(aselle): Unique enough?
    self._attrs_to_store_later = kwargs
    self._stored_attrs = False
    self._inputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "InputHint",
        OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
    self._outputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "OutputHint",
        OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
        self._children_inputs_mappings)

  def _validate_children_inputs_mappings(self, children_inputs_mappings):
    """Validate that children inputs mappings are in the right format.

    Args:
      children_inputs_mappings: The children ophint inputs/outputs mapping.
    """
    assert isinstance(children_inputs_mappings, dict)
    assert "parent_first_child_input" in children_inputs_mappings
    assert "parent_last_child_output" in children_inputs_mappings
    assert "internal_children_input_output" in children_inputs_mappings

    # validate parent_first_child_input.

    def assert_dictlist_has_keys(dictlist, keys):
      for dikt in dictlist:
        assert isinstance(dikt, dict)
        for key in keys:
          assert key in dikt

    assert_dictlist_has_keys(
        children_inputs_mappings["parent_first_child_input"],
        ["parent_ophint_input_index", "first_child_ophint_input_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["parent_last_child_output"],
        ["parent_output_index", "child_output_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["internal_children_input_output"],
        ["child_input_index", "child_output_index"])

  def _setattr(self, dest_op, name, value):
    tensor_value = _ops.convert_to_tensor(value)
    # pylint: disable=protected-access
    dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
        tensor=tensor_value.op.node_def.attr["value"].tensor))
    # pylint: enable=protected-access

  def add_input(self, *args, **kwargs):
    """Add a wrapped input argument to the hint.

    Args:
      *args: The input tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple
          copies of state or inputs.
        "aggregate" aggregation strategy that is valid only when tag is not
          None. Acceptable values are OpHint.AGGREGATE_FIRST,
          OpHint.AGGREGATE_LAST, and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped input tensor.
    """
    return self._inputs.add(*args, **kwargs)

  def add_output(self, *args, **kwargs):
    """Add a wrapped output argument to the hint.

    Args:
      *args: The output tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple
          copies of state or inputs.
        "aggregate" aggregation strategy that is valid only when tag is not
          None. Acceptable values are OpHint.AGGREGATE_FIRST,
          OpHint.AGGREGATE_LAST, and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped output tensor.
    """
    return self._outputs.add(*args, **kwargs)

  def add_inputs(self, *args, **kwargs):
    """Add a sequence of inputs to the function invocation.

    Args:
      *args: List of inputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped inputs (identity standins that have additional metadata). These
      are also tf.Tensors.
    """
    if "names" in kwargs:
      return [
          self._inputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._inputs.add(arg) for arg in args]

  def add_outputs(self, *args, **kwargs):
    """Add a sequence of outputs to the function invocation.

    Args:
      *args: List of outputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped outputs (identity standins that have additional metadata). These
      are also tf.Tensors.
    """
    if "names" in kwargs:
      return [
          self._outputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._outputs.add(arg) for arg in args]
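

# A minimal usage sketch (illustrative only, not executed here; `x`, `y` and
# the per-timestep list `cell_outputs` are hypothetical tensors). Plain
# arguments use add_inputs/add_outputs; per-timestep arguments of a static RNN
# share a tag plus an aggregation strategy so they fuse into one stacked
# operand of the final stub:
#
#   hint = OpHint("cool_custom_op")
#   x, y = hint.add_inputs(x, y, names=["x", "y"])
#   # ... build the TensorFlow ops that implement the function ...
#   wrapped_outputs = [
#       hint.add_output(out, tag="out", aggregate=OpHint.AGGREGATE_STACK)
#       for out in cell_outputs
#   ]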


class _LiteOperand(object):
  """Abstract operand for a tflite hint function.

  This is a base class that handles representing arguments to an OpHint.
  It also is able to serialize operands to the stubbed graph_def.
  Child classes are responsible for being able to
  store information about the hint identity operators. They are also
  responsible for knowing how to serialize to output graphdefs.

  Typically this will be implemented by holding one or more identity nodes
  that were previously discovered as hints.
  """

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the node(s) to out_graphdef and returns the input node name.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The output that the stub should use as an input for this operand.

    Raises:
      RuntimeError: if the method is not implemented.
    """
    del out_graphdef
    raise RuntimeError("Unimplemented abstract method.")

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """Add node(s) to graph representing output operands and returns type.

    Args:
      fused_op_name: The name of the fused op stub.
      output_index: Output index that we are currently processing from stub.
      out_graphdef: The destination graphdef we are currently building up.

    Returns:
      The datatype of this identity.

    Raises:
      RuntimeError: if the method is not implemented.
    """
    del fused_op_name, output_index, out_graphdef
    raise RuntimeError("Unimplemented abstract method.")


class _LiteSingleOperand(_LiteOperand):
  """A simple operand that is non-aggregated (i.e. most hints)."""

  def __init__(self, node):
    _LiteOperand.__init__(self)
    self.node = node
    self.name = _tensor_name_base(node.name)

  def flatten(self):
    return [self.name]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    return self.name

  def aggregate_and_return_name_for_output(self, fused_op_name, index,
                                           out_graphdef):
    output_node = _copy.deepcopy(self.node)
    del output_node.input[:]
    output_node.input.append(_tensorflow_output_name(fused_op_name, index))
    out_graphdef.node.extend([output_node])
    return self.node.attr["type"].i

  def __str__(self):
    return str(self.name)


class _LiteAggregateOperand(_LiteOperand):
  """An operand for a tflite hint function that is aggregated from many.

  For example, an LSTM is a grid of operators that are all related. Inputs
  going into them may need to be fused, so they should all be tracked as
  related arguments.
  """

  def __init__(self, aggregation):
    _LiteOperand.__init__(self)
    self.aggregation = aggregation
    self.names = {}
    self.nodes = {}
    self.flattened = None

  def add(self, sort, node):
    self.names[sort] = _tensor_name_base(node.name)
    self.nodes[sort] = node

  def flatten_nodes(self):
    """Return a list of all the node protos in aggregation sorted order."""
    if not self.flattened:
      self.flattened = [None] * len(self.nodes)
      for idx, node in _six.iteritems(self.nodes):
        self.flattened[idx] = node
      for n in self.flattened:
        if n is None:
          raise RuntimeError("Aggregate was missing argument.")
      if self.aggregation == OpHint.AGGREGATE_FIRST:
        self.flattened = self.flattened[:1]
      elif self.aggregation == OpHint.AGGREGATE_LAST:
        self.flattened = self.flattened[-1:]
      elif self.aggregation == OpHint.AGGREGATE_STACK:
        pass
      else:
        raise ValueError("Invalid aggregation type %r specified" %
                         self.aggregation)
    return self.flattened

  def flatten(self):
    """Return a list of all node names in aggregation sorted order."""
    return [_tensor_name_base(x.name) for x in self.flatten_nodes()]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the nodes to out_graphdef and returns an aggregated output.

    In particular, if you have 4 inputs to a hint stub, this will be the
    node that you can use as an output. I.e. if you have 4 timesteps from a
    static rnn, then a fused UnidirectionalLSTM will expect 1 input with
    all 4 time steps. So here we make a pack and return the output name of
    that pack.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The name of a pack that aggregates this node.
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      return _tensor_name_base(flattened[0].name)
    else:
      new_node = _node_def_pb2.NodeDef()
      new_node.op = "Pack"
      new_node.name = "OpHintStack-%s" % flattened[0].name
      new_node.attr["N"].i = len(flattened)
      new_node.attr["T"].type = flattened[0].attr["T"].type
      for discrete in flattened:
        new_node.input.append(_tensor_name_base(discrete.name))
      out_graphdef.node.extend([new_node])
      return new_node.name

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """This adds to `out_graphdef` all the unaggregated outputs.

    I.e. we are outputting from a fused stub, but we need to make it compatible
    with the unfused original graph so we insert an unpack. Ideally in a later
    stage the unpack -> pack sequences will be removed.

    Args:
      fused_op_name: The name of the stub we are in the process of fusing.
      output_index: The output index this object represents.
      out_graphdef: The graphdef we are in the process of building.

    Returns:
      The type of the aggregated output (so we can finish building the stub
      op).
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      temp_op = _LiteSingleOperand(flattened[0])
      return temp_op.aggregate_and_return_name_for_output(
          fused_op_name, output_index, out_graphdef)
    else:
      stack_node = _node_def_pb2.NodeDef()
      stack_node.op = "Unpack"
      stack_node.name = "OpHintUnstack-%s" % flattened[0].name
      stack_node.attr["num"].i = len(flattened)
      output_type = flattened[0].attr["T"].type
      stack_node.attr["T"].type = output_type
      stack_node.input.append(
          _tensorflow_output_name(fused_op_name, output_index))
      out_graphdef.node.extend([stack_node])

      for idx, discrete in enumerate(flattened):
        output_node = _copy.deepcopy(discrete)
        del output_node.input[:]
        output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
        out_graphdef.node.extend([output_node])

      return output_type

  def __str__(self):
    s = "\t\t\tAGGREGATE %s\n" % self.aggregation
    for sort, val in _six.iteritems(self.names):
      s += "\t\t\t%d: %s\n" % (sort, val)
    return s


class _LiteFuncCall(object):
  """Represent a TensorFlow Lite custom function.

  This is used to accumulate found hints in the graphdef into a single
  conceptual unit.

  Attributes:
    inputs: inputs to the op (hash from index # to argument)
    outputs: outputs to the op (hash from index # to argument)
    function_name: the tflite custom op name to use
    uuid: a unique call id for this particular call (i.e. multiple function
      calls would have the same function_name but different uuids).
    params: A map from parameter name to value for op constant data, i.e. the
      axis of a reduction, the strides of a convolution, etc.
    level: Level of the OpHint.
    children_inputs_mappings: If the Ophint has children, children inputs
      mappings indicate how their inputs & outputs are mapped.
  """

  def __init__(self):
    self.inputs = {}
    self.outputs = {}
    self.function_name = None
    self.uuid = None
    self.params = {}
    self.level = -1
    self.children_inputs_mappings = {}

  def flattened_inputs_and_outputs(self):
    """Return a list of inputs and outputs in a flattened format.

    Returns:
      Tuple of (inputs, outputs), where each is a list of names.
    """

    def _flatten(input_or_output_dict):
      flattened_items = []
      for item in input_or_output_dict.values():
        flattened_items.extend(item.flatten())
      return flattened_items

    return _flatten(self.inputs), _flatten(self.outputs)

  def __str__(self):

    def format_args(items):
      s = ""
      for idx, item in _six.iteritems(items):
        s += ("\t\t%d:\n" % idx) + str(item)
      return s

    inputs_str = "\tInputs\n" + format_args(self.inputs)
    outputs_str = "\tOutputs\n" + format_args(self.outputs)

    return (
        "tflite function %s call %s level %d "
        "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
        (self.function_name, self.uuid, self.level, inputs_str, outputs_str))


def _find_all_hints_in_nodes(nodes):
  """Look at all the input nodes and return the hinted function calls.

  Args:
    nodes: A list of TensorFlow NodeDef protos to look for LiteFuncCalls in.

  Returns:
    A dict mapping function uuid to the `_LiteFuncCall` objects found in
    `nodes`.
  """
  func_calls = _collections.defaultdict(_LiteFuncCall)

  for node in nodes:
    attr = node.attr
    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
    if (OpHint.FUNCTION_UUID_ATTR not in attr or
        not attr[OpHint.FUNCTION_UUID_ATTR].s):
      continue
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s

    # Start building function
    call_def = func_calls[uuid]
    call_def.uuid = uuid
    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
    call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
    # Get sorting and aggregation information

    sort = (
        attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
        if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1:
      sort = None
    aggregation = None
    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)

    if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
      call_def.children_inputs_mappings = _json.loads(
          _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))

    # Add the input or output
    def put_operand(stuff, index, sort, operand, aggregation):
      """Add a given index into the function structure."""
      if sort is None:
        stuff[index] = _LiteSingleOperand(operand)
      else:
        if index not in stuff:
          stuff[index] = _LiteAggregateOperand(aggregation)
        stuff[index].add(sort, operand)

    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
                  sort, node, aggregation)
    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
                  sort, node, aggregation)

    # Remember attributes
    for a in attr:
      if a.startswith("_tflite_attr_"):
        call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor
"")] = attr[a].tensor 798 799 return func_calls 800 801 802def _extract_topology_sequence_mapping(nodes): 803 return dict( 804 (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes)) 805 806 807def _find_children_hints_in_while_loop(function_def, nodes_mapping): 808 """Find children hints and all nodes inside the while loop. 809 810 Args: 811 function_def: Function def of the while loop. 812 nodes_mapping: While loop input_arg : real node name. 813 814 Returns: 815 Ordered children hints and all re-mapped nodes inside the while loop. 816 """ 817 new_nodes = [] 818 819 # Make nodes inside function def inputs point to the real nodes. 820 for node in function_def.node_def: 821 for i, _ in enumerate(node.input): 822 if node.input[i] in nodes_mapping: 823 node.input[i] = nodes_mapping[node.input[i]] 824 new_nodes.append(_copy.deepcopy(node)) 825 name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def) 826 children_hints = _find_all_hints_in_nodes(new_nodes) 827 children_hints_q = [] 828 # Ordered by the outputs. 829 for hint in _six.itervalues(children_hints): 830 _, output_names = hint.flattened_inputs_and_outputs() 831 seq = name_to_seq_num[output_names[0]] 832 for output_name in output_names: 833 seq = min(seq, name_to_seq_num[output_name]) 834 children_hints_q.append((seq, hint)) 835 children_hints_q.sort(key=lambda tup: tup[0]) 836 ordered_children_hints = [x[1] for x in children_hints_q] 837 return ordered_children_hints, new_nodes 838 839 840def _find_children_hints(call, graph_def): 841 """Find all children hints. 842 843 For a given OpHint, we find all children hints inside it, we also copy all the 844 nodes inside function defs (if applicable) to the original graph_def, they are 845 returned in a list as well. 846 847 Args: 848 call: Parent OpHint that contains children ophints. 849 graph_def: Original graph def. 850 851 Returns: 852 Ordered children hints inside the parent ophint; new graph def that contains 853 nodes inside function defs (if applicable); nodes inside function defs. 854 """ 855 name_to_input_name, _, _ = _extract_graph_summary(graph_def) 856 input_names, output_names = call.flattened_inputs_and_outputs() 857 858 reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name) 859 reachable_by_output = _bfs_for_reachable_nodes(output_names, 860 name_to_input_name) 861 output_nodes_set = set(output_names) 862 children_hints = [] 863 out = _graph_pb2.GraphDef() 864 out.library.CopyFrom(graph_def.library) 865 out.versions.CopyFrom(graph_def.versions) 866 function_def_nodes = set() 867 for node in graph_def.node: 868 out.node.extend([_copy.deepcopy(node)]) 869 n = _tensor_name_base(node.name) 870 if n in reachable_by_output: 871 if n not in reachable_by_input and n not in output_nodes_set: 872 # special handle for while loop function def. 873 if node.op == "While" or node.op == "StatelessWhile": 874 body_name = node.attr["body"].func.name 875 inputs_outside_loop = node.input 876 for function_def in graph_def.library.function: 877 if function_def.signature.name == body_name: 878 function_inputs = function_def.signature.input_arg 879 assert len(inputs_outside_loop) == len(function_inputs) 880 nodes_mapping = {} 881 for i, function_input in enumerate(function_inputs): 882 nodes_mapping[function_input.name] = inputs_outside_loop[i] 883 # TODO(b/123050804): Consider use grappler. 
              (children_hints_in_loop,
               new_nodes) = _find_children_hints_in_while_loop(
                   function_def, nodes_mapping)
              function_def_nodes.update([x.name for x in new_nodes])
              children_hints.extend(children_hints_in_loop)
              out.node.extend(new_nodes)

  return children_hints, out, function_def_nodes


def _tensor_name_base(full_tensor_name):
  """Removes the output index and control marker from a tensor name.

  e.g. _tensor_name_base("foo:3") => "foo"

  Args:
    full_tensor_name: A tensor name that is annotated with an output index
      (this is what TensorFlow introspection gives).

  Returns:
    The bare op name, without any output index annotation.
  """
  if full_tensor_name.startswith("^"):
    return full_tensor_name[1:]
  return full_tensor_name.split(":")[0]


def _tensorflow_output_name(tensor_name, output_index):
  return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
                                                          output_index)


# TODO(aselle): This should be converted to grappler in the future.
def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                           name_to_input_name):
  """Checks that a node only connects to the predecessor graph through inputs.

  Args:
    n: Node to check.
    reachable_by_input: Nodes that are reachable from the subgraph inputs.
    input_nodes_set: The set of nodes that are "inputs".
    name_to_input_name: Maps from name to the list of inputs.

  Raises:
    TypeError: If the given node uses items past inputs directly.
  """
  next_to_visit = [n]
  visited = set()
  while next_to_visit:
    current_node = next_to_visit.pop()
    visited.add(current_node)
    if (current_node in reachable_by_input and
        current_node not in input_nodes_set):
      raise TypeError("Node %s uses input %s not in input_nodes." %
                      (n, current_node))
    if current_node not in input_nodes_set:
      next_to_visit += [
          input_node for input_node in name_to_input_name[current_node]
          if input_node not in visited
      ]


# TODO(aselle): This should be converted to grappler in the future.
def _convert_single_op_hint_to_stub(call,
                                    graph_def,
                                    function_def_nodes=None,
                                    is_last_run=True):
  """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (that contains `call`).
    function_def_nodes: Nodes inside the function def that are not connected
      to the graph.
    is_last_run: Whether it is the last run for a given pass (used when the
      OpHint has children).

  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
    the tensorflow runtime, so all future manipulations are done at the
    graph_def level.
  """
  if function_def_nodes is None:
    function_def_nodes = set()
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  nodes_after_fuse = []
  nodes_deleted_by_fuse = set()
  # Classify each node. We want to keep everything that is reachable from the
  # inputs; nodes reachable only from the outputs are swallowed by the fuse,
  # and nodes reachable from neither are things that come after the fusing.
  for node in graph_def.node:
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        nodes_deleted_by_fuse.add(n)
    elif n not in reachable_by_input and n not in function_def_nodes:
      # n is a node that comes after all the fusings, so keep it.
      nodes_after_fuse.append(n)
    else:
      # In the last run, n is a node that is randomly in the graph but not
      # connected to the chain of dependencies, so we delete it. Otherwise
      # we keep it.
      if not is_last_run:
        nodes_after_fuse.append(n)

  # Make a new graphdef with all the pre-input and input nodes
  out = _graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([_copy.deepcopy(name_to_node[node])])

  # Create any stacks to aggregate arguments into a single input
  # i.e. for static_rnn's.
  # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
  sorted_input_indices = list(call.inputs.keys())
  sorted_input_indices.sort()
  sorted_output_indices = list(call.outputs.keys())
  sorted_output_indices.sort()
  new_node = _node_def_pb2.NodeDef()
  # Delegate to each operand to produce the proper new input for this stub
  # node. In particular, an aggregate input will now be a Pack of some
  # previously non-fused things.

  optional_input_node = _node_def_pb2.NodeDef()
  optional_input_node.name = "Const" + str(_uuid.uuid1().hex)
  optional_input_node.op = "Const"
  optional_input_node.attr["dtype"].CopyFrom(
      _attr_value_pb2.AttrValue(type=_dtypes.float32.as_datatype_enum))
  optional_input_node.attr["value"].CopyFrom(
      _attr_value_pb2.AttrValue(
          tensor=_tensor_util.make_tensor_proto([-1], _dtypes.float32, [1])))
  out.node.extend([optional_input_node])

  max_index = max(sorted_input_indices) + 1
  for cur_index in range(max_index):
    if cur_index in sorted_input_indices:
      inputs = call.inputs[cur_index]
      input_name = inputs.aggregate_and_return_name_for_input(out)
      new_node.input.append(input_name)
    else:
      new_node.input.append(optional_input_node.name)

  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

  # Create the function
  new_node.op = call.function_name
  new_node.name = call.uuid
  out.node.extend([new_node])

  # Now call each output argument to give them a chance to make the proper
  # output type and add it to our new_node.
  output_dtypes = []
  max_output_index = max(sorted_output_indices) + 1
  for cur_index in range(max_output_index):
    if cur_index in sorted_output_indices:
      output = call.outputs[cur_index]
      output_dtype = (
          output.aggregate_and_return_name_for_output(new_node.name, cur_index,
                                                      out))
    else:
      output_dtype = optional_input_node.attr["type"].i
    output_dtypes.append(output_dtype)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  # TODO(aselle): what is right here?
  new_node.attr["_output_quantized"].b = False

  # Add post output nodes that do not depend on the outputs
  for n in nodes_after_fuse:
    should_keep = True
    for input_name in name_to_input_name[n]:
      if input_name in nodes_deleted_by_fuse:
        should_keep = False
    if should_keep:
      out.node.extend([_copy.deepcopy(name_to_node[n])])

  # Misc. graph_def data that needs copying.
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out


# TODO(aselle): This should be converted to grappler in the future.
def _remove_one_redundant_stack_unstack(in_graph_def):
  """Removes a stack->unstack pattern from in_graph_def in a returned graph.

  Args:
    in_graph_def: Graph def to use as input.

  Returns:
    Simplified tuple (graph_def, changed_something) where changed_something
    is true if anything was done.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      in_graph_def)
  del name_to_seq_num

  # TODO(aselle): Make this not hardcoded.
  do_generic_pack_unpack = True

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(in_graph_def.library)
  out.versions.CopyFrom(in_graph_def.versions)
  for n in in_graph_def.node:
    node_name = _tensor_name_base(n.name)
    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
      continue
    next_to_visit = [node_name]
    visited = set()

    unpack_nodes = set()
    pack_node = node_name

    # Find a pattern of unstack connected to a stack (with identities
    # in between).
    matches_pattern = True
    is_hint_created_stack = False
    while next_to_visit:
      current_node_name = next_to_visit[0]
      visited.add(current_node_name)
      del next_to_visit[0]
      node = name_to_node[current_node_name]
      is_op_hint_stack = node.name.startswith("OpHintStack")
      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
      if (node.op == "Identity" or is_op_hint_stack or
          (do_generic_pack_unpack and node.op == "Pack")):
        is_hint_created_stack |= is_op_hint_stack
        next_to_visit += [
            input_node for input_node in name_to_input_name[current_node_name]
            if input_node not in visited
        ]
      elif (is_op_hint_unstack or
            (do_generic_pack_unpack and node.op == "Unpack")):
        unpack_nodes.add(node.name)
        is_hint_created_stack &= is_op_hint_unstack
      else:
        matches_pattern = False
        break
      visited.add(node.name)

    if matches_pattern and len(unpack_nodes) == 1:
      pack_node = node_name

      # Check to see if anyone depends on the intermediate identity or the
      # Unstacked form
      no_external_dependency = True
      for other_n in in_graph_def.node:
        if other_n.name in visited:
          continue
        for input_tensor in name_to_input_name[other_n.name]:
          input_op = _tensor_name_base(input_tensor)
          if input_op in visited and input_op != pack_node:
            no_external_dependency = False
      # Proceed with the substitution if the stack/unstack pair was created
      # through hints, or if it was not, as long as nobody is consuming
      # anything between the stack and unstack.
      if is_hint_created_stack or no_external_dependency:
        end = unpack_nodes.pop()
        end_input = name_to_node[end].input[0]
        # All nodes that depend on the final stack need to be redone to use
        # the stack's input instead.
        for other_n in in_graph_def.node:
          node_name = _tensor_name_base(other_n.name)
          if node_name not in visited:
            new_node = _copy.deepcopy(other_n)
            new_node.input[:] = [
                (end_input if stripped == pack_node else non_stripped)
                for stripped, non_stripped in zip(name_to_input_name[node_name],
                                                  new_node.input[:])
            ]
            out.node.extend([new_node])
        return out, True
  return in_graph_def, False


def _remove_redundant_stack_unstack(graph_def):
  curr = graph_def
  del graph_def
  changed_stuff = True
  while changed_stuff:
    curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
  return curr


def _get_correct_mapping(original_index, nodes):
  # Special handling for the case where the index is -1.
  # If it is -1, return the last index.
  if original_index == -1:
    node_indices = nodes.keys()
    node_indices = sorted(node_indices)
    return node_indices[-1]
  return original_index


def _convert_op_hints_to_stubs_helper(
    graph_def, write_callback=lambda graph_def, comments: None):
  """Converts a graph_def to a new graph_def where all op hints are stubbed.

  Args:
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new stubbed graph_def.
  """
  hints = _find_all_hints_in_nodes(graph_def.node)

  hints_q = []
  for hint in _six.itervalues(hints):
    hints_q.append((hint.level, hint.uuid))

  hints_q.sort(key=lambda tup: tup[0])

  curr_graph_def = graph_def
  del graph_def  # prevent using graph_def again (common source of error)
  for i in range(len(hints_q) - 1, -1, -1):
    level, hint_uuid = hints_q[i]
    if level >= 2:
      children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
          hints[hint_uuid], curr_graph_def)
      # pylint: disable=superfluous-parens
      assert (len(children_hints) > 0)  # pylint: disable=g-explicit-length-test
      # pylint: enable=superfluous-parens

      # Re-wire the children hints inputs/outputs, so a later child's inputs
      # connect to the previous child node's outputs.
      children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
      for j, child_hint in enumerate(children_hints):
        if j == 0:
          for mapping in children_inputs_mappings["parent_first_child_input"]:
            parent_input_index = _get_correct_mapping(
                mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
            child_input_index = _get_correct_mapping(
                mapping["first_child_ophint_input_index"], child_hint.inputs)
            child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
                parent_input_index]
        else:
          for mapping in children_inputs_mappings[
              "internal_children_input_output"]:
            input_index = _get_correct_mapping(mapping["child_input_index"],
                                               child_hint.inputs)
            output_index = _get_correct_mapping(mapping["child_output_index"],
                                                children_hints[j - 1].outputs)
            child_hint.inputs[input_index] = children_hints[
                j - 1].outputs[output_index]
        if j == len(children_hints) - 1:
          for mapping in children_inputs_mappings["parent_last_child_output"]:
            parent_output_index = _get_correct_mapping(
                mapping["parent_output_index"], hints[hint_uuid].outputs)
            child_output_index = _get_correct_mapping(
                mapping["child_output_index"], child_hint.outputs)
            child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
                parent_output_index]

      for j, child_hint in enumerate(children_hints):
        curr_graph_def = _convert_single_op_hint_to_stub(
            child_hint, curr_graph_def, function_def_nodes,
            j == len(children_hints) - 1)
    else:
      curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
                                                       curr_graph_def)
      write_callback(curr_graph_def, "initial")
  # The stubbing process can create stacks/unstacks in the case of LSTMs;
  # remove them.
  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
  return curr_graph_def


def find_all_hinted_output_nodes(session=None, graph_def=None):
  """Find all OpHint output nodes in the graph.

  This is used to get all the output nodes that are ophinted; it is important
  for operations like convert_variables_to_constants to keep the ophint
  structure. Note: only one of session or graph_def should be used, not both.
  Why can this be useful? Some TensorFlow ops (e.g. bidirectional rnn) can
  generate multiple outputs for the unfused subgraph. If not all output nodes
  are consumed, graph optimization can potentially drop the unused nodes and
  leave the ophints in an invalid state (due to missing ophinted output
  nodes). So it's important for us to find all those hinted output nodes and
  make sure they're not discarded.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.

  Returns:
    A list of OpHint output node names.

  Raises:
    ValueError: If both session and graph_def are provided.
1285 """ 1286 if session is not None and graph_def is not None: 1287 raise ValueError("Provide only one of session and graph_def.") 1288 hinted_outputs_nodes = [] 1289 if session is not None: 1290 hints = _find_all_hints_in_nodes(session.graph_def.node) 1291 elif graph_def is not None: 1292 hints = _find_all_hints_in_nodes(graph_def.node) 1293 for hint in _six.itervalues(hints): 1294 _, output_nodes = hint.flattened_inputs_and_outputs() 1295 hinted_outputs_nodes.extend(output_nodes) 1296 return hinted_outputs_nodes 1297 1298 1299def is_ophint_converted(graph_def): 1300 if graph_def is None: 1301 raise ValueError("Must provide the graph_def.") 1302 ophint_converted = False 1303 for node in graph_def.node: 1304 attr = node.attr 1305 if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: 1306 ophint_converted = True 1307 break 1308 return ophint_converted 1309 1310 1311@_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"]) 1312@_deprecation.deprecated( 1313 None, 1314 "Please follow instructions under " 1315 "https://www.tensorflow.org/lite/convert/operation_fusion for operation" 1316 "fusion in tflite." 1317) 1318def convert_op_hints_to_stubs(session=None, 1319 graph_def=None, 1320 write_callback=lambda graph_def, comments: None): 1321 """Converts a graphdef with LiteOp hints into stub operations. 1322 1323 This is used to prepare for toco conversion of complex intrinsic usages. 1324 Note: only one of session or graph_def should be used, not both. 1325 1326 Args: 1327 session: A TensorFlow session that contains the graph to convert. 1328 graph_def: A graph def that we should convert. 1329 write_callback: A function pointer that can be used to write intermediate 1330 steps of graph transformation (optional). 1331 1332 Returns: 1333 A new graphdef with all ops contained in OpHints being replaced by 1334 a single op call with the right parameters. 1335 Raises: 1336 ValueError: If both session and graph_def are provided. 1337 """ 1338 1339 if session is not None and graph_def is not None: 1340 raise ValueError("Provide only one of session and graph_def.") 1341 1342 if session is not None: 1343 return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback) 1344 elif graph_def is not None: 1345 return _convert_op_hints_to_stubs_helper(graph_def, write_callback) 1346 else: 1347 raise ValueError("Must specify session or graph_def as input.") 1348 1349 1350_allowed_symbols = [ 1351 "OpHint", 1352 "convert_op_hints_to_stubs", 1353 "convert_op_hints_to_stubs_new", 1354 "find_all_hinted_output_nodes", 1355 "is_ophint_converted", 1356] 1357remove_undocumented(__name__, _allowed_symbols) 1358