# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compiled parallel-for loop."""
# pylint: disable=missing-docstring,g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import string
import sys
import traceback

import numpy as np
import six

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity


# TODO(agarwal): remove flag.
flags.DEFINE_bool(
    "op_conversion_fallback_to_while_loop", True,
    "DEPRECATED: Flag is ignored.")


def _variant_handle_data(t):
  """Fetches handle data for a variant tensor `t`, or None if unavailable."""
  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  if not handle_data.is_set:
    return None
  return handle_data.shape_and_type


def _is_variant_with_internal_stacking(t):
  """Identifies variant tensors which pfor always maintains as scalars.

  For these, the pfor tensor is recorded as "stacked" if the content of the
  variant tensor (e.g. the elements of a TensorList) are all stacked.

  Args:
    t: A tensor to identify.
  Returns:
    True if `t` is a TensorList/Optional, False if not, None if unknown.
  """
  if t.dtype != dtypes.variant:
    return False
  shapes_and_types = _variant_handle_data(t)
  if shapes_and_types is None or not shapes_and_types:
    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
    # make this an error instead of assuming TensorLists have handle data.
    return None  # Presumed not a TensorList/Optional
  return (shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST or
          shapes_and_types[0].specialized_type == types_pb2.ST_OPTIONAL)


def _parse_variant_shapes_and_types(t):
  """Extracts shape and dtype information from a variant tensor `t`."""
  shapes_and_types = _variant_handle_data(t)
  if shapes_and_types is None or not shapes_and_types:
    raise ValueError("Required handle data not set for {!r}".format(t))
  if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
    return shapes_and_types
  else:
    if shapes_and_types[0].specialized_type != types_pb2.ST_INVALID:
      return shapes_and_types
    else:
      raise ValueError(
          "Attempted to stack a variant-dtype tensor with no type set ({!r})"
          .format(t))


def _stack(t, length):
  """Stacks `t` `length` times."""
  # Note that this stacking may currently be triggered, for example, when a
  # loop invariant tensor with dtype variant is input to a while_loop which then
  # produces a loop dependent output. Simply stacking the variants may not be
  # suitable since operations on stacked handles may expect a vectorized version
  # of the variant.
  if t.dtype == dtypes.variant:
    shapes_and_types = _parse_variant_shapes_and_types(t)
    if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
      if len(shapes_and_types) != 1:
        raise ValueError(
            "Expected handle data of length 1, got {!r} of length {}"
            .format(shapes_and_types, len(shapes_and_types)))
      return wrap(
          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
          True)
    else:
      raise ValueError(
          ("Attempted to stack an unhandled variant-dtype tensor of "
           "type {!r} ({!r})").format(shapes_and_types[0].specialized_type, t))
  ones = array_ops.ones_like(array_ops.shape(t))
  ones = array_ops.reshape(ones, [-1])
  length = array_ops.reshape(length, [-1])
  multiples = array_ops.concat([length, ones], 0)
  t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
  return wrap(t, True)


# The following stateful ops can be safely called once, and with the same
# signature as the unconverted version, if their inputs are loop invariant.
# TODO(agarwal): implement a strategy for converting Variable reads/writes.
# The plan is to map each read/write in the loop_fn to a corresponding merged
# read/write in the converted graph. Writes need to be mergeable (e.g.
# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
# loop_fn, doing a one-to-one conversion will simulate executing such
# instructions in lock-step across all iterations.
passthrough_stateful_ops = set([
    "VariableV2",
    "VarHandleOp",
    "VariableShape",
    "ReadVariableOp",
    "StackV2",
    "TensorArrayWriteV3",
    "TensorArrayReadV3",
    "TensorArraySizeV3",
])


# Ops which we will treat like stateful for the purpose of vectorization.
# Typically this is used to force pfor converters to run for these ops.
force_stateful_ops = set([
    # We vectorize this since we need to change the element shape set on the
    # list.
    "TensorListReserve",
])


def _is_stateful_pfor_op(op):
  if isinstance(op, WhileOp):
    return op.is_stateful
  if op.type == "Const":
    # Const didn't have an op_def.
    return False
  if op.type in passthrough_stateful_ops:
    return False
  if op.type in force_stateful_ops:
    return True
  assert hasattr(op, "op_def") and op.op_def is not None, op
  return op.op_def.is_stateful


# pylint: disable=protected-access
class WhileOp(object):
  """Object for storing state for converting the outputs of a while_loop."""

  def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config):
    """Initializer.

    Args:
      exit_node: A tensor output from the while_loop.
      pfor_ops: list of ops inside the current pfor loop.
      fallback_to_while_loop: If True, fall back to a while loop when the
        conversion of an op is not supported.
      pfor_config: PForConfig object used while constructing loop body.
    """
    self._fallback_to_while_loop = fallback_to_while_loop
    self._pfor_config = pfor_config
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    assert isinstance(exit_node, ops.Tensor)
    self._while_context = exit_node.op._get_control_flow_context()
    assert isinstance(self._while_context, control_flow_ops.WhileContext)
    self._context_name = self._while_context.name
    self._condition = self._while_context.pivot.op.inputs[0]
    # Parts of an external while_loop could be created inside a pfor loop.
    # However for the purpose here, we declare such loops to be external. Also
    # note that we check if the condition was created inside or outside to
    # determine if the while_loop was first created inside or outside.
    # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
    self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
    if self._is_inside_loop:
      for e in self._while_context.loop_exits:
        assert self.op_is_inside_loop(e.op)

    # Note the code below tries to reverse engineer an existing while_loop graph
    # by assuming the following pattern of nodes.
    #
    #          NextIteration <---- Body <--- Enter
    #              |                ^
    #              V             ___|    Y
    #    Enter -> Merge -> Switch___
    #                       ^       | N
    #                       |       V
    #                   LoopCond    Exit

    # Note that elements in the list below correspond one-to-one with each
    # other. i.e. these lists are the same size, and the i_th entry corresponds
    # to different Operations/Tensors of a single cycle as illustrated above.
    # List of Switch ops (ops.Operation) that feed into an Exit Node.
    self._exit_switches = []
    # List of inputs (ops.Tensor) to NextIteration.
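    # These are the values computed by the loop body on each iteration; each
    # one feeds a NextIteration node and, through it, the corresponding Merge.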
    self._body_outputs = []
    # List of list of control inputs of the NextIteration nodes.
    self._next_iter_control_inputs = []
    # List of Merge ops (ops.Operation).
    self._enter_merges = []
    # List of output (ops.Tensor) of Exit nodes.
    self._outputs = []

    # List of Enter Tensors.
    # There are two types of Enter nodes:
    # - The Enter nodes that are used in the `loop_vars` argument to
    # `while_loop` (see
    # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
    # these Enter nodes immediately below by tracing backwards from the Exit
    # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
    # diagram above. This allows us to have a 1:1 correspondence between the
    # self._outputs and the first elements in self._enters.
    # - The Enter nodes that are used only by the body. They don't appear in the
    # `loop_vars` and are not returned from the `while_loop`. In Python code,
    # they are usually captured by the body lambda. We collect them below by
    # iterating over all the ops in the graph. They are appended to the end of
    # self._enters or self._direct_enters, and don't correspond to any outputs
    # in self._outputs. Note that we keep the resource/variant Enter nodes in
    # self._direct_enters and the constructed while_loop's body uses them
    # directly as opposed to passing them as loop variables. This is done
    # because the while_body cannot partition the resource/variant Tensors, so
    # it has to leave them unchanged.
    self._enters = []
    self._direct_enters = []

    for e in self._while_context.loop_exits:
      self._outputs.append(e.op.outputs[0])
      switch = e.op.inputs[0].op
      assert switch.type == "Switch", switch
      self._exit_switches.append(switch)
      merge = switch.inputs[0].op
      assert merge.type == "Merge", merge
      self._enter_merges.append(merge)
      enter = merge.inputs[0].op
      assert enter.type == "Enter", enter
      self._enters.append(enter.outputs[0])
      next_iter = merge.inputs[1].op
      assert next_iter.type == "NextIteration", next_iter
      self._body_outputs.append(next_iter.inputs[0])
      self._next_iter_control_inputs.append(next_iter.control_inputs)

    # Collect all the Enter nodes that are not part of `loop_vars`, the second
    # category described above.
    # Also track whether the loop body has any stateful ops.
    self._is_stateful = False
    for op in ops.get_default_graph().get_operations():
      # TODO(agarwal): make sure this works with nested case.
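      # Only ops whose innermost control flow context is this while_loop are
      # considered below; ops from other contexts are skipped.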
      control_flow_context = op._get_control_flow_context()
      if control_flow_context is None:
        continue
      if control_flow_context.name == self._context_name:
        self._is_stateful |= _is_stateful_pfor_op(op)
        if op.type == "Enter":
          output = op.outputs[0]
          if output not in self._enters:
            if output.dtype in (dtypes.resource, dtypes.variant):
              if output not in self._direct_enters:
                self._direct_enters.append(output)
            else:
              self._enters.append(output)

  def __str__(self):
    """String representation."""
    return "while_loop(%s)" % self.name

  @property
  def inputs(self):
    """Input to all the Enter nodes."""
    return [x.op.inputs[0] for x in self._enters + self._direct_enters]

  @property
  def control_inputs(self):
    """Control input to all the Enter nodes."""
    control_inputs = []
    for x in self._enters + self._direct_enters:
      control_inputs.extend(x.op.control_inputs)
    return control_inputs

  @property
  def outputs(self):
    """Outputs of all the Exit nodes."""
    return self._outputs

  @property
  def name(self):
    """Context name for the while loop."""
    return self._context_name

  @property
  def is_inside_loop(self):
    """Returns true if the while_loop was created inside the pfor."""
    return self._is_inside_loop

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
    assert isinstance(op, ops.Operation)
    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different python
    # objects representing the same Operation node.
    return op._id in self._pfor_op_ids

  @property
  def is_stateful(self):
    return self._is_stateful

  @property
  def pfor_converter(self):
    """Return a converter for the while loop."""
    return self

  def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
                 inputs_stacked):
    """Create a PFor object for converting parts of the while_loop.

    Args:
      parent_pfor: PFor object being used for converting the while_loop.
      indices: int32 Tensor of ids for the iterations that are still active
        (i.e. did not exit the while_loop).
      cond_stacked: True if the while_loop condition is stacked.
      inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
        that these Tensors are a subset of the loop variables for the generated
        while_loop.
      inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
        indicating if the value is stacked or not.

    Returns:
      A PFor instance. The instance is initialized by adding conversion mappings
      of nodes that will be external to the conversion that the returned
      instance will be used for. e.g. Enter nodes as well as Merge and Switch
      outputs are mapped to converted values.
    """
    num_outputs = len(self._outputs)
    assert len(inputs) == len(self._enters)
    assert len(inputs_stacked) == len(self._enters)
    loop_var = parent_pfor.loop_var
    loop_len = array_ops.size(indices)
    pfor = PFor(
        loop_var,
        loop_len,
        pfor_ops=self._pfor_ops,
        all_indices=indices,
        all_indices_partitioned=cond_stacked,
        fallback_to_while_loop=self._fallback_to_while_loop,
        pfor_config=self._pfor_config)
    # Map all inputs of Enter nodes in self._direct_enters to their converted
    # values.
    for enter in self._direct_enters:
      enter_input = enter.op.inputs[0]
      converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
          enter_input)
      # Since these are resources / variants, they should be unstacked.
      assert not stacked and not is_sparse_stacked, (enter, converted_enter)
      pfor._add_conversion(enter, wrap(converted_enter, False))

    # Map all Enter nodes to the inputs.
    for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
      pfor._add_conversion(enter, wrap(inp, stacked))
    # Map outputs of Switch and Merge.
    for i in range(num_outputs):
      wrapped_inp = wrap(inputs[i], inputs_stacked[i])
      merge = self._enter_merges[i]
      pfor._add_conversion(merge.outputs[0], wrapped_inp)
      # Note that second output of Merge is typically not used, except possibly
      # as a control dependency. To avoid trying to output the correct value, we
      # employ a hack here. We output a dummy invalid value with an incorrect
      # dtype. This will allow control dependency to work but if using it as an
      # input, it should typically lead to errors during graph construction due
      # to dtype mismatch.
      # TODO(agarwal): Check in the original graph to see if there are any
      # consumers of this Tensor that use it as an input.
      pfor._add_conversion(merge.outputs[1],
                           wrap(constant_op.constant(-1.0), False))
      switch = self._exit_switches[i]
      # Don't need to worry about switch.output[0] which will feed to Exit node.
      pfor._add_conversion(switch.outputs[1], wrapped_inp)
    return pfor

  def _convert_enter(self, parent_pfor, enter):
    """Converts an Enter node."""
    inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
    control_inputs = []
    for x in enter.op.control_inputs:
      converted = parent_pfor._convert_helper(x)
      if not isinstance(converted, ops.Operation):
        converted = converted.t
      control_inputs.append(converted)
    if control_inputs:
      with ops.control_dependencies(control_inputs):
        inp = array_ops.identity(inp)
    return inp, stacked

  def _maybe_stacked(self, cache, inp):
    """Heuristic to figure out if converting `inp` leads to a stacked value.

    Args:
      cache: map from Tensor to boolean indicating stacked/unstacked.
      inp: input Tensor.

    Returns:
      True if `inp` could get stacked. If the function returns False, the
      converted value should be guaranteed to be unstacked. If returning True,
      it may or may not be stacked.
    """
    if inp in cache:
      return cache[inp]
    if not self.op_is_inside_loop(inp.op):
      return False
    op = inp.op
    output = False
    if op.type in [
        "Shape",
        "Rank",
        "ShapeN",
        "ZerosLike",
        "TensorArrayV3",
        "TensorArraySizeV3",
    ]:
      output = False
    elif _is_stateful_pfor_op(op):
      # This may be fairly aggressive.
      output = True
    elif op.type == "Exit":
      # This may be fairly aggressive.
      output = True
    else:
      for t in op.inputs:
        if self._maybe_stacked(cache, t):
          output = True
          break
    cache[inp] = output
    return output

  def _create_init_values(self, pfor_input):
    """Create arguments passed to converted while_loop."""
    with ops.name_scope("while_init"):
      loop_len_vector = pfor_input.pfor.loop_len_vector
      loop_len = loop_len_vector[0]
      num_outputs = len(self._outputs)

      inputs = []
      maybe_stacked_cache = {}
      # Convert all the Enters. Need to do this before checking for stacking
      # below.
      for i, enter in enumerate(self._enters):
        inp, stacked = self._convert_enter(pfor_input.pfor, enter)
        inputs.append(inp)
        maybe_stacked_cache[enter] = stacked
        # Since this enter node is part of the `loop_vars`, it corresponds to an
        # output and its preceding switch. We mark this switch's output with the
        # same stackedness, to act as the base case for the logic below. Below,
        # we will be going through the body figuring out which inputs might need
        # to be stacked and which inputs can safely remain unstacked.
        if i < num_outputs:
          maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked

      # Shape invariants for init_values corresponding to self._enters.
      input_shape_invariants = []
      # TensorArrays for outputs of converted while loop
      output_tas = []
      # Shape invariants for output TensorArrays.
      ta_shape_invariants = []
      # List of booleans indicating stackedness of inputs, i.e. tensors
      # corresponding to self._enters.
      inputs_stacked = []
      for i, inp in enumerate(inputs):
        enter = self._enters[i]
        inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
        # Note that even when an input is unstacked, the body could make it
        # stacked. We use a heuristic below to figure out if body may be making
        # it stacked.
        if i < num_outputs:
          body_output = self._body_outputs[i]
          if enter.op in self._pfor_ops:
            body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
                                                      body_output)
          else:
            # If constructed outside of pfor loop, then the output would not be
            # stacked.
            body_output_stacked = False
          if body_output_stacked and not inp_stacked:
            inp = _stack(inp, loop_len_vector).t
            inputs[i] = inp
            inp_stacked = True
          # TODO(agarwal): other attributes for the TensorArray ?
          output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
          ta_shape_invariants.append(tensor_shape.TensorShape(None))

        inputs_stacked.append(inp_stacked)
        input_shape_invariants.append(tensor_shape.TensorShape(None))

      # See documentation for __call__ for the structure of init_values.
      init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
      # TODO(agarwal): try stricter shape invariants
      shape_invariants = (
          [tensor_shape.TensorShape(None),
           tensor_shape.TensorShape(None)] + input_shape_invariants +
          ta_shape_invariants)

      return init_values, inputs_stacked, shape_invariants

  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
    """Handles case when condition is unstacked.

    Note that all iterations end together. So we don't need to partition the
    inputs. When all iterations are done, we write the inputs to the
    TensorArrays. Note that we only write to index 0 of output_tas. Since all
    iterations end together, they can all be output together.
    """
    not_all_done = array_ops.reshape(conditions, [])
    new_output_tas = []
    # pylint: disable=cell-var-from-loop
    for i, out_ta in enumerate(output_tas):
      inp = inputs[i]
      new_output_tas.append(
          control_flow_ops.cond(not_all_done, lambda: out_ta,
                                lambda: out_ta.write(0, inp)))
    # pylint: enable=cell-var-from-loop
    return not_all_done, indices, inputs, new_output_tas

  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
                            output_tas):
    num_outputs = len(self._outputs)
    # Compute if all iterations are done.
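    # `conditions` holds one boolean per active pfor iteration, so the outer
    # loop keeps running as long as any entry is still True.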
    not_all_done = math_ops.reduce_any(conditions)
    conditions_int = math_ops.cast(conditions, dtypes.int32)
    # Partition the indices.
    done_indices, new_indices = data_flow_ops.dynamic_partition(
        indices, conditions_int, 2)

    new_inputs = []
    new_output_tas = []
    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
      # Partition the inputs.
      if stacked:
        done_inp, new_inp = data_flow_ops.dynamic_partition(
            inp, conditions_int, 2)
      else:
        # TODO(agarwal): avoid this stacking. See TODO earlier in
        # _process_cond_unstacked.
        done_inp = _stack(inp, [array_ops.size(done_indices)]).t
        new_inp = inp
      new_inputs.append(new_inp)
      # For iterations that are done, write them to TensorArrays.
      if i < num_outputs:
        out_ta = output_tas[i]
        # Note that done_indices can be empty. done_inp should also be empty in
        # that case.
        new_output_tas.append(out_ta.scatter(done_indices, done_inp))
    return not_all_done, new_indices, new_inputs, new_output_tas

  def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked,
                    new_inputs, not_all_done):
    """Convert the body function."""

    def true_fn(control_inputs, body_pfor, body_output, stacked):
      """Converts the body function for all but the last iteration.

      This essentially converts body_output. Additionally, it needs to handle
      any control dependencies on the NextIteration node. So it creates another
      Identity node with the converted dependencies.
      """
      converted_control_inp = []
      for x in control_inputs:
        for t in x.outputs:
          converted_control_inp.append(body_pfor._convert_helper(t).t)
      if stacked:
        # Note convert always does the stacking.
        output = body_pfor.convert(body_output)
      else:
        output, convert_stacked, _ = body_pfor._convert_helper(body_output)
        assert convert_stacked == stacked, body_output
      with ops.control_dependencies(converted_control_inp):
        return array_ops.identity(output)

    body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked,
                                new_inputs, inputs_stacked)
    new_outputs = []

    for i, (body_output,
            stacked) in enumerate(zip(self._body_outputs, inputs_stacked)):
      control_inp = self._next_iter_control_inputs[i]
      out_dtype = body_output.dtype
      # Note that we want to run the body only if not all pfor iterations are
      # done. If all are done, we return empty tensors since these values will
      # not be used. Notice that the value returned by the loop is based on
      # TensorArrays and not directly on these returned values.
      # pylint: disable=cell-var-from-loop
      new_output = control_flow_ops.cond(
          not_all_done,
          lambda: true_fn(control_inp, body_pfor, body_output, stacked),
          lambda: constant_op.constant([], dtype=out_dtype))
      # pylint: enable=cell-var-from-loop
      new_outputs.append(new_output)
    return new_outputs

  def __call__(self, pfor_input):
    """Converter for the while_loop.

    The conversion of a while_loop is another while_loop.

    The arguments to this converted while_loop are as follows:
    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
      are done.
    indices: int32 1-D Tensor storing the id of the iterations that are not
      done.
    args: Remaining arguments. These can be divided into 3 categories:
      - The first set of arguments are the tensors that correspond to the
        initial elements of self._enters.
        These are the elements that appear in the original while_loop's
        `loop_vars`.
      - The second set of arguments are the tensors that correspond to the
        remaining elements of self._enters. These are the tensors that directly
        enter the original while loop body.
      - Finally, the last set of arguments are TensorArrays. These TensorArrays
        correspond to the outputs of the original while_loop, i.e. to the
        elements in self._outputs. Each TensorArray has `PFor.loop_len`
        elements, i.e. the number of pfor iterations. At the end, the i'th
        element of each TensorArray will contain the output computed by the
        i'th iteration of pfor. Note that elements can be written into these
        tensor arrays in any order, depending on when the corresponding pfor
        iteration is done.
    If the original while_loop had `k` tensors in its `loop_vars` and its body
    directly captured `m` tensors, the `args` will contain `2 * k + m` values.

    In each iteration, the while_loop body recomputes the condition for all
    active pfor iterations to see which of them are now done. It then partitions
    all the inputs and passes them along to the converted body. Values for all
    the iterations that are done are written to TensorArrays indexed by the pfor
    iteration number. When all iterations are done, the TensorArrays are stacked
    to get the final value.

    Args:
      pfor_input: A PForInput object corresponding to the output of any Exit
        node from this while loop.

    Returns:
      List of converted outputs.
    """
    # Create init_values that will be passed to the while_loop.
    init_values, inputs_stacked, shape_invariants = self._create_init_values(
        pfor_input)
    # Note that we use a list as a hack since we need the nested function body
    # to set the value of cond_is_stacked. Python 2.x doesn't support nonlocal
    # variables.
    cond_is_stacked = [None]

    def cond(not_all_done, *_):
      return not_all_done

    def body(not_all_done, indices, *args):
      # See documentation for __call__ for the structure of *args.
      num_enters = len(self._enters)
      inputs = args[:num_enters]
      output_tas = args[num_enters:]
      # TODO(agarwal): see which outputs have consumers and only populate the
      # TensorArrays corresponding to those. Or do those paths get trimmed out
      # from inside the while_loop body?
      assert len(inputs) >= len(output_tas)
      assert len(inputs) == len(inputs_stacked)

      # Convert condition
      with ops.name_scope("while_cond"):
        # Note that we set cond_stacked to True here. At this point we don't
        # know if it could be loop invariant, hence the conservative value is
        # to assume stacked.
        cond_pfor = self._init_pfor(
            pfor_input.pfor,
            indices,
            cond_stacked=True,
            inputs=inputs,
            inputs_stacked=inputs_stacked)
        conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
        cond_is_stacked[0] = cond_stacked

      # Recompute the new condition, write outputs of done iterations, and
      # partition the inputs if needed.
      if not cond_stacked:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_unstacked(conditions, indices,
                                                        inputs, output_tas)
      else:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_stacked(conditions, indices,
                                                      inputs, inputs_stacked,
                                                      output_tas)

      # Convert body
      with ops.name_scope("while_body"):
        # Compute the outputs from the body.
        new_outputs = self._process_body(pfor_input, inputs_stacked,
                                         new_indices, cond_stacked, new_inputs,
                                         not_all_done)

      # Note that the first num_outputs new values of inputs are computed using
      # the body. Rest of them were direct Enters into the condition/body and
      # the partitioning done earlier is sufficient to give the new value.
      num_outputs = len(self._outputs)
      new_args = ([not_all_done, new_indices] + new_outputs +
                  list(new_inputs[num_outputs:]) + new_output_tas)
      return tuple(new_args)

    while_outputs = control_flow_ops.while_loop(
        cond, body, init_values, shape_invariants=shape_invariants)
    output_tas = while_outputs[-len(self._outputs):]
    outputs = []
    assert cond_is_stacked[0] is not None
    for inp_stacked, ta in zip(inputs_stacked, output_tas):
      if cond_is_stacked[0]:
        outputs.append(wrap(ta.stack(), True))
      else:
        # Note that if while_loop condition is unstacked, all iterations exit at
        # the same time and we wrote those outputs in index 0 of the tensor
        # array.
        outputs.append(wrap(ta.read(0), inp_stacked))
    return outputs


class ConversionNotImplementedError(Exception):
  pass


class _PforInput(object):
  """Input object passed to registered pfor converters."""

  __slots__ = ["pfor", "_op", "_inputs"]

  def __init__(self, pfor, op, inputs):
    """Creates a _PforInput object.

    Args:
      pfor: PFor converter object.
      op: the Operation object that is being converted.
      inputs: list of WrappedTensor objects representing converted values of the
        inputs of `op`.
    """
    self.pfor = pfor
    self._op = op
    self._inputs = inputs

  def stack_inputs(self, stack_indices=None, tile_variants=False):
    """Stacks unstacked inputs at `stack_indices`.

    Args:
      stack_indices: indices of inputs at which stacking is done. If None,
        stacking is done at all indices.
      tile_variants: If True, affected indices which have a variant dtype will
        be tiled after this operation to match the expected shape of a
        vectorized tensor. Variants generally need to be un-tiled when they are
        inputs to operations and tiled when returned.
    """
    if stack_indices is None:
      stack_indices = range(len(self._inputs))
    length = self.pfor.loop_len_vector
    for i in stack_indices:
      inp = self._inputs[i]
      is_variant = inp.t.dtype == dtypes.variant
      if not inp.is_stacked:
        self._inputs[i] = _stack(inp.t, length)
        if tile_variants and is_variant:
          self._inputs[i] = wrap(
              _tile_variant_with_length(self._inputs[i].t, length), True)
      elif not tile_variants and is_variant:
        self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True)

  def expanddim_inputs_for_broadcast(self):
    """Reshapes stacked inputs to prepare them for broadcast.

    Since stacked inputs have an extra leading dimension, automatic broadcasting
    rules could incorrectly try to expand dimensions before that leading
    dimension. To avoid that, we reshape these stacked inputs to the maximum
    rank they will need to be broadcasted to.
826 """ 827 if not self._inputs: 828 return 829 830 # Find max rank 831 def _get_rank(x): 832 rank = array_ops.rank(x.t) 833 if not x.is_stacked: 834 rank += 1 835 return rank 836 837 ranks = [_get_rank(x) for x in self._inputs] 838 max_rank = ranks[0] 839 for rank in ranks[1:]: 840 max_rank = math_ops.maximum(rank, max_rank) 841 842 for i, inp in enumerate(self._inputs): 843 if inp.is_stacked: 844 shape = array_ops.shape(inp.t) 845 rank_diff = array_ops.reshape(max_rank - ranks[i], [1]) 846 ones = array_ops.tile([1], rank_diff) 847 new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0) 848 self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True) 849 850 @property 851 def inputs(self): 852 return self._inputs 853 854 @property 855 def num_inputs(self): 856 return len(self._inputs) 857 858 def input(self, index): 859 assert len(self._inputs) > index, (index, self._inputs) 860 return self._inputs[index] 861 862 def stacked_input(self, index): 863 t, is_stacked, _ = self.input(index) 864 if not is_stacked: 865 op_type = self.op_type 866 op_def = getattr(self._op, "op_def", None) 867 if op_def is None: 868 input_name = "at index %d" % index 869 else: 870 input_name = "\"%s\"" % op_def.input_arg[index].name 871 raise ConversionNotImplementedError( 872 "Input %s of op \"%s\" expected to be not loop invariant" % 873 (input_name, op_type)) 874 return t 875 876 def unstacked_input(self, index): 877 t, is_stacked, _ = self.input(index) 878 if is_stacked: 879 op_type = self.op_type 880 op_def = getattr(self._op, "op_def", None) 881 if op_def is None: 882 input_name = "at index %d" % index 883 else: 884 input_name = "\"%s\"" % op_def.input_arg[index].name 885 raise ConversionNotImplementedError( 886 "Input %s of op \"%s\" expected to be loop invariant" % 887 (input_name, op_type)) 888 return t 889 890 @property 891 def op(self): 892 return self._op 893 894 @property 895 def op_type(self): 896 return self._op.type 897 898 def get_attr(self, attr): 899 return self._op.get_attr(attr) 900 901 @property 902 def outputs(self): 903 return self._op.outputs 904 905 def output(self, index): 906 assert index < len(self._op.outputs) 907 return self._op.outputs[index] 908 909 910_pfor_converter_registry = {} 911 912 913class RegisterPFor(object): 914 """Utility to register converters for pfor. 915 916 Usage: 917 @RegisterPFor(foo_op_type) 918 def _foo_converter(pfor_input): 919 ... 920 921 The above will register conversion function `_foo_converter` for handling 922 conversion of `foo_op_type`. These converters are called during vectorization 923 of a `pfor` loop body. For each operation node in this loop body, 924 the vectorization process will call the converter corresponding to the 925 operation type of the node. 926 927 During conversion, the registered function will be called with a single 928 argument `pfor_input`, of type `PForInput`, which will contain state needed 929 for the conversion. When the converter is called for a node, all its inputs 930 should already have been converted and these converted values are stored in 931 `pfor_input.inputs`. This registered function should output a list of 932 WrappedTensor objects with the same length as the number of outputs of the 933 node being converted. If the node had zero outputs, then it should return an 934 ops.Operation object. 
  This new set of nodes should implement the functionality of running that
  operation for the number of iterations specified by
  `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each
  iteration are picked from `pfor_inputs.inputs()`.

  One tricky aspect of the conversion process is keeping track of, and
  leveraging loop invariance of computation. Each converted input is a
  WrappedTensor which indicates whether the input was loop invariant or not. If
  the converted value is loop invariant, its rank should match the rank of the
  corresponding tensor in the loop body, else its rank is larger by 1. The
  converter should look at the loop invariance of the inputs and generate new
  nodes based on that. Note that the converter will not be called if all inputs
  are loop invariant and the operation is not stateful. The converter should
  determine if its own output is loop invariant and `wrap` its output
  accordingly.

  Example:

  Here, the converter is trying to convert a Reshape node in the loop body. This
  node will have two inputs: the tensor to reshape, and the new shape. The
  example here only handles the case where the shape is loop invariant.

  @RegisterPFor("Reshape")
  def _convert_reshape(pfor_input):
    # We assume that input is not loop invariant. Call to `stacked_input`
    # asserts that and returns the converted value. This value will have a rank
    # larger by 1 compared to the rank of the input in the loop body.
    t = pfor_input.stacked_input(0)

    # We assume that shape input is loop invariant. Call to `unstacked_input`
    # asserts that and returns the converted value.
    shape = pfor_input.unstacked_input(1)

    # We compute `new_shape` by prepending the number of iterations to the
    # original shape.
    new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
                                 axis=0)

    # The vectorized output involves reshaping the converted input `t` using
    # `new_shape`.
    new_output = array_ops.reshape(t, new_shape)

    # The converted output is marked as not loop invariant using the call to
    # wrap.
    return wrap(new_output, True)
  """

  def __init__(self, op_type):
    """Creates an object to register a converter for op with type `op_type`."""
    self.op_type = op_type

  def __call__(self, converter):
    name = self.op_type
    assert name not in _pfor_converter_registry, "Re-registering %s " % name
    _pfor_converter_registry[name] = converter
    return converter


class RegisterPForWithArgs(RegisterPFor):
  """Utility to register converters for pfor.

  Usage:
  @RegisterPForWithArgs(foo_op_type, foo=value, ....)
  def _foo_converter(pfor_input, op_type, foo=None, ....):
    ...

  See RegisterPFor for details on the conversion function.
  `RegisterPForWithArgs` allows binding extra arguments to the
  conversion function at registration time.
  """

  def __init__(self, op_type, *args, **kw_args):
    super(RegisterPForWithArgs, self).__init__(op_type)
    self._args = args
    self._kw_args = kw_args

  def __call__(self, converter):

    def _f(pfor_input):
      return converter(pfor_input, self.op_type, *self._args, **self._kw_args)

    super(RegisterPForWithArgs, self).__call__(_f)
    return converter


# TODO(agarwal): call raw_ops instead of calling these low level routines.
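# For reference, a minimal sketch of registering a converter through
# `RegisterPForWithArgs` above. The "Sigmoid" op type and the bound `op_func`
# argument are purely illustrative (no such registration is made here): the
# extra keyword arguments supplied at registration time are passed back to the
# converter on every call, right after the op type.
#
#   @RegisterPForWithArgs("Sigmoid", op_func=math_ops.sigmoid)
#   def _convert_unary(pfor_input, op_type, op_func):
#     del op_type
#     # Elementwise ops vectorize by simply applying the op to the stacked
#     # input and marking the result as stacked.
#     return wrap(op_func(pfor_input.stacked_input(0)), True)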
def _create_op(op_type, inputs, op_dtypes, attrs=None):
  """Utility to create an op."""
  op = ops.get_default_graph().create_op(
      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
  flat_attrs = []
  # The tape expects an alternating flat list of names and attribute values.
  for a in attrs:
    flat_attrs.append(str(a))
    flat_attrs.append(op.get_attr(str(a)))
  execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
  return op


WrappedTensor = collections.namedtuple("WrappedTensor",
                                       ["t", "is_stacked", "is_sparse_stacked"])
"""Wrapper around the result of a Tensor conversion.

The additional fields are useful for keeping track of the conversion state as
data flows through the ops in the loop body. For every op whose output is a
Tensor, its converter should return either a WrappedTensor or a list of
WrappedTensors.

Args:
  t: The converted tensor
  is_stacked: True if the tensor is stacked, i.e. represents the results of all
    the iterations of the loop, where each row i of the tensor corresponds to
    that op's output on iteration i of the loop. False if the tensor is not
    stacked, i.e. represents the result of the op for a single iteration of
    the loop, where the result does not vary between iterations.
  is_sparse_stacked: True if the tensor corresponds to a component tensor
    (indices, values, or dense_shape) of a sparse tensor, and has been logically
    stacked via a sparse conversion.
"""


def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
  """Helper to create a WrappedTensor object."""
  assert isinstance(is_stacked, bool)
  assert isinstance(is_sparse_stacked, bool)
  assert isinstance(tensor, ops.Tensor)
  assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
                                               "stacked via a sparse "
                                               "conversion, it must also be "
                                               "stacked.")
  return WrappedTensor(tensor, is_stacked, is_sparse_stacked)


def _wrap_and_tile_variants(tensor, length):
  if tensor.dtype == dtypes.variant:
    tensor = _tile_variant_with_length(tensor, length)
  return wrap(tensor)


def _fallback_converter(pfor_input, warn=True):
  if warn:
    logging.warn("Using a while_loop for converting %s", pfor_input.op_type)
  output_dtypes = [x.dtype for x in pfor_input.outputs]
  iters = pfor_input.pfor.loop_len_vector[0]

  def while_body(i, *ta_list):
    """Body of while loop."""
    inputs = [
        x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
    ]
    op_outputs = _create_op(
        pfor_input.op_type,
        inputs,
        output_dtypes,
        attrs=pfor_input.op.node_def.attr).outputs

    outputs = []
    # TODO(agarwal): Add tf.debugging asserts to check that the shapes across
    # the different iterations are the same.
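    # Each per-iteration output gets a leading unit dimension and is written to
    # slot i of its TensorArray; the `concat` below therefore stacks results
    # across iterations.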
    for out, ta in zip(op_outputs, ta_list):
      assert isinstance(out, ops.Tensor)
      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
    return tuple([i + 1] + outputs)

  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters, while_body, [0] +
      [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
      ])[1:]
  return tuple([wrap(ta.concat(), True) for ta in ta_list])


class PForConfig(object):
  """A configuration object used to communicate with loop body function."""

  def __init__(self):
    # This may be set to the number of iterations.
    self._maybe_iters = None
    # Map from reduction node, created by `reduce`, to the bundle of reduction
    # function and arguments.
    self._reduce_map = {}

  def _has_reductions(self):
    """True if some reductions were performed by loop body."""
    return len(self._reduce_map)

  def _set_iters(self, iters):
    """Set number of pfor iterations."""
    if isinstance(iters, ops.Tensor):
      iters = tensor_util.constant_value(iters)
    self._maybe_iters = iters

  def reduce(self, fn, *args):
    """Performs reduction `fn` on `args` vectorized across pfor iterations.

    Note that `fn` is traced once inside the loop function context. Hence any
    captures or side-effects will happen in that context. The call to the
    traced version of `fn` happens during the construction of the vectorized
    code.

    Note that this currently may not work inside a control flow construct.
    Args:
      fn: a reduction function. It will be called with arguments that have the
        same structure as *args but with individual values whose rank may be
        higher by 1 since they represent loop invariant vectorized versions of
        the corresponding Tensors in *args.
      *args: unvectorized Tensors.

    Returns:
      The result of running `fn` on the vectorized versions of `*args`. These
      outputs will be available as loop invariant values to all the iterations.
    """
    assert not context.executing_eagerly()
    # Creates a concrete function that will be used for reduction.
    tensor_specs = []
    for arg in args:
      if not isinstance(arg, ops.Tensor):
        raise ValueError("Got a non-Tensor argument %s in reduce" % arg)
      batched_shape = tensor_shape.TensorShape([self._maybe_iters
                                               ]).concatenate(arg.shape)
      tensor_specs.append(
          tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype))
    concrete_function = def_function.function(fn).get_concrete_function(
        *tensor_specs)

    # Creates PlaceholderWithDefault and IdentityN nodes corresponding to the
    # reduction.
    pl_outputs = []
    with ops.control_dependencies(args):
      for output in concrete_function.outputs:
        if not isinstance(output, ops.Tensor):
          raise ValueError("Got a non-Tensor output %s while running reduce" %
                           output)
        # Note that we use placeholder_with_default just to make XLA happy since
        # it does not like placeholder ops.
        if output.shape.is_fully_defined():
          dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype)
          pl_outputs.append(
              array_ops.placeholder_with_default(dummy, shape=output.shape))
        else:
          # TODO(agarwal): support case when under XLA and output.shape is not
          # fully defined.
          pl_outputs.append(
              array_ops.placeholder(output.dtype, shape=output.shape))

      reduction_op = array_ops.identity_n(pl_outputs)[0].op
    self._reduce_map[reduction_op] = (concrete_function, args)
    if len(reduction_op.outputs) == 1:
      return reduction_op.outputs[0]
    else:
      return tuple(reduction_op.outputs)

  # TODO(agarwal): handle reductions inside control flow constructs.
  def reduce_concat(self, x):
    """Performs a concat reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has rank one higher than `x`. The value is the vectorized
      version of `x`, i.e. stacking the value of `x` across different pfor
      iterations.
    """
    return self.reduce(lambda y: y, x)

  def reduce_mean(self, x):
    """Performs a mean reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the mean of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x)

  def reduce_sum(self, x):
    """Performs a sum reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the sum of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x)

  def _lookup_reduction(self, t):
    """Looks up Tensor `t` in the reduction maps."""
    assert isinstance(t, ops.Tensor), t
    return self._reduce_map.get(t.op)


class PFor(object):
  """Implementation of rewrite of parallel-for loops.

  This class takes a DAG or a set of DAGs representing the body of a
  parallel-for loop, and adds new operations to the graph that implement
  functionality equivalent to running that loop body for a specified number of
  iterations. This new set of nodes may or may not use a tensorflow loop
  construct.

  The process of conversion does not delete or change any existing operations.
  It only adds operations that efficiently implement the equivalent
  functionality. We refer to the added ops as "converted ops".

  The conversion process uses a simple greedy heuristic. It walks the loop body
  and tries to express the functionality of running each node in a loop with a
  new set of nodes. When converting an op several cases are possible:
  - The op is not inside the loop body. Hence it can be used as is.
  - The op does not depend on the iteration number and is stateless. In this
    case, it can be used as is.
  - The op is not stateful, and depends on iteration number only through control
    dependencies. In this case, we can create a single op with same inputs and
    attributes, but with "converted" control dependencies.
  - The op is not stateful, and all its inputs are loop invariant. In this
    case, similar to above, we can create a single op with same inputs and
    attributes, but with "converted" control dependencies.
  - The op is stateful or at least one of the inputs is not loop invariant.
    In this case, we run the registered converter for that op to create a set
    of converted ops. All nodes in the set will have converted control
    dependencies corresponding to control dependencies of the original op. If
    the op returned multiple outputs, "converted outputs" could be produced by
    different ops in this set.
  """

  def __init__(self,
               loop_var,
               loop_len,
               pfor_ops,
               fallback_to_while_loop,
               all_indices=None,
               all_indices_partitioned=False,
               pfor_config=None):
    """Creates an object to rewrite a parallel-for loop.

    Args:
      loop_var: ops.Tensor output of a Placeholder operation. The value should
        be an int32 scalar representing the loop iteration number.
      loop_len: A scalar or scalar Tensor representing the number of iterations
        the loop is run for.
      pfor_ops: List of all ops inside the loop body.
      fallback_to_while_loop: If True, on failure to vectorize an op, a while
        loop is used to sequentially execute that op.
      all_indices: If not None, an int32 vector with size `loop_len`
        representing the iteration ids that are still active. These values
        should be unique and sorted. However they may not be contiguous. This is
        typically the case when inside a control flow construct which has
        partitioned the indices of the iterations that are being converted.
      all_indices_partitioned: If True, this object is being constructed from a
        control flow construct where not all the pfor iterations are guaranteed
        to be active.
      pfor_config: PForConfig object used while constructing the loop body.
    """
    assert isinstance(loop_var, ops.Tensor)
    assert loop_var.op.type == "PlaceholderWithDefault"
    self._loop_var = loop_var
    loop_len_value = tensor_util.constant_value(loop_len)
    if loop_len_value is not None:
      loop_len = loop_len_value
    self._loop_len_vector = array_ops.reshape(loop_len, [1])
    self._all_indices_partitioned = all_indices_partitioned
    if all_indices_partitioned:
      assert all_indices is not None
    self.all_indices = (
        math_ops.range(loop_len) if all_indices is None else all_indices)

    self._conversion_map = object_identity.ObjectIdentityDictionary()
    self._conversion_map[loop_var] = wrap(self.all_indices, True)
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    self._fallback_to_while_loop = fallback_to_while_loop
    self._pfor_config = pfor_config

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
    assert isinstance(op, ops.Operation)
    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different python
    # objects representing the same Operation node.
    return op._id in self._pfor_op_ids

  def _convert_sparse(self, y):
    """Returns the converted value corresponding to SparseTensor y.

    For SparseTensors, instead of stacking the component tensors separately,
    resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
    rank) respectively for indices, values, and dense_shape (where N is the loop
    length and m is the number of sparse tensor values per loop iter), we want
    to logically stack the SparseTensors, to create a SparseTensor whose
    components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
    respectively.

    Here, we try to get the conversion of each component tensor.
    If the tensors are stacked via a sparse conversion, return the resulting
    SparseTensor composed of the converted components. Otherwise, the component
    tensors are either unstacked or stacked naively. In the latter case, we
    unstack the component tensors to reform loop_len SparseTensor elements,
    then correctly batch them.

    The unstacked tensors must have the same rank. Each dimension of each
    SparseTensor will expand to be the largest among all SparseTensor elements
    for that dimension. For example, if there are N SparseTensors of rank 3
    being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
    the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).

    Args:
      y: A tf.sparse.SparseTensor.

    Returns:
      A tf.sparse.SparseTensor that is the converted value corresponding to y.
    """
    outputs = [
        self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
    ]
    assert all(isinstance(o, WrappedTensor) for o in outputs)

    if all(w.is_sparse_stacked for w in outputs):
      return sparse_tensor.SparseTensor(*[w.t for w in outputs])

    assert not any(w.is_sparse_stacked for w in outputs), (
        "Error converting SparseTensor. All components should be logically "
        "stacked, or none.")

    # If component tensors were not sparsely stacked, they are either unstacked
    # or stacked without knowledge that they are components of sparse tensors.
    # In this case, we have to restack them.
    return self._restack_sparse_tensor_logically(
        *[self._unwrap_or_tile(w) for w in outputs])

  def _restack_sparse_tensor_logically(self, indices, values, shape):
    sparse_tensor_rank = indices.get_shape().dims[-1].value
    if sparse_tensor_rank is not None:
      sparse_tensor_rank += 1

    def fn(args):
      res = gen_sparse_ops.serialize_sparse(
          args[0], args[1], args[2], out_type=dtypes.variant)
      return res

    # Applies a map function to the component tensors to serialize each
    # sparse tensor element and batch them all, then deserializes the batch.
    # TODO(rachelim): Try to do this without map_fn -- add the right offsets
    # to shape and indices tensors instead.
    result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant)
    return sparse_ops.deserialize_sparse(
        result, dtype=values.dtype, rank=sparse_tensor_rank)

  def _unwrap_or_tile(self, wrapped_tensor):
    """Given a wrapped tensor, unwraps it if stacked; otherwise, tiles it."""
    output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
    if is_stacked:
      return output
    else:
      return _stack(output, self._loop_len_vector).t

  def convert(self, y):
    """Returns the converted value corresponding to y.

    Args:
      y: A ops.Tensor or a ops.Operation object. If the latter, y should not
        have any outputs.

    Returns:
      If y does not need to be converted, it returns y as is. Else it returns
      the "converted value" corresponding to y.
1404 """ 1405 if y is None: 1406 return None 1407 if isinstance(y, sparse_tensor.SparseTensor): 1408 return self._convert_sparse(y) 1409 assert isinstance(y, (ops.Tensor, ops.Operation)), y 1410 output = self._convert_helper(y) 1411 if isinstance(output, WrappedTensor): 1412 assert isinstance(y, ops.Tensor) 1413 return self._unwrap_or_tile(output) 1414 else: 1415 assert isinstance(y, ops.Operation) 1416 assert not y.outputs 1417 assert isinstance(output, ops.Operation) 1418 return output 1419 1420 def _was_converted(self, t): 1421 """True if t is not a conversion of itself.""" 1422 converted_t = self._conversion_map[t] 1423 return converted_t.t is not t 1424 1425 def _add_conversion(self, old_output, new_output): 1426 assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output 1427 assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output 1428 self._conversion_map[old_output] = new_output 1429 1430 def _convert_reduction(self, y): 1431 # Handle reductions. 1432 if self._pfor_config is None or isinstance(y, ops.Operation): 1433 return None 1434 reduction = self._pfor_config._lookup_reduction(y) 1435 if reduction is None: 1436 return None 1437 (reduction_fn, reduction_args) = reduction 1438 batched_args = [] 1439 for reduction_arg in reduction_args: 1440 assert isinstance(reduction_arg, ops.Tensor), reduction_arg 1441 # Tensor being reduced should already be converted due to a control 1442 # dependency on the created placeholder. 1443 # Note that in cases where reduction_arg is in an outer context, one 1444 # needs to locate the corresponding Enter node and use that to lookup 1445 # the conversion. 1446 # TODO(agarwal): handle reductions inside control flow constructs. 1447 assert reduction_arg in self._conversion_map, ( 1448 "Unable to handle reduction of %s, possibly as it was used " 1449 "inside a control flow construct. Note that reductions across " 1450 "pfor iterations are currently not supported inside control flow " 1451 "constructs." % reduction_arg) 1452 batched_arg = self._conversion_map[reduction_arg] 1453 batched_args.append(self._unwrap_or_tile(batched_arg)) 1454 outputs = reduction_fn(*batched_args) 1455 return [wrap(output, False) for output in nest.flatten(outputs)] 1456 1457 def _convert_helper(self, op_or_tensor): 1458 stack = collections.deque([op_or_tensor]) 1459 while stack: 1460 y = stack[0] 1461 if y in self._conversion_map: 1462 assert isinstance(self._conversion_map[y], 1463 (WrappedTensor, ops.Operation)) 1464 stack.popleft() 1465 continue 1466 if isinstance(y, ops.Operation): 1467 assert not y.outputs, ( 1468 "We only support converting Operation objects with no outputs. " 1469 "Got %s", y) 1470 y_op = y 1471 else: 1472 assert isinstance(y, ops.Tensor), y 1473 y_op = y.op 1474 1475 is_while_loop = y_op.type == "Exit" 1476 if is_while_loop: 1477 while_op = WhileOp( 1478 y, pfor_ops=self._pfor_ops, 1479 fallback_to_while_loop=self.fallback_to_while_loop, 1480 pfor_config=self._pfor_config) 1481 is_inside_loop = while_op.is_inside_loop 1482 # If all nodes in the while_loop graph were created inside the pfor, we 1483 # treat the whole loop subgraph as a single op (y_op) and try to convert 1484 # it. For while_loops that are created completely or partially outside, 1485 # we treat them as external and should be able to simply return the Exit 1486 # node output as is without needing any conversion. Note that for 1487 # while_loops that are partially constructed inside, we assume they will 1488 # be loop invariant. 
If that is not the case, it will create runtime 1489 # errors since the converted graph would depend on the self._loop_var 1490 # placeholder. 1491 if is_inside_loop: 1492 y_op = while_op 1493 else: 1494 is_inside_loop = self.op_is_inside_loop(y_op) 1495 1496 # If this op was not created inside the loop body, we will return as is. 1497 # 1. Convert inputs and control inputs. 1498 1499 def _add_to_stack(x): 1500 if x not in self._conversion_map: 1501 stack.appendleft(x) 1502 return True 1503 else: 1504 return False 1505 1506 if is_inside_loop: 1507 added_to_stack = False 1508 for inp in y_op.inputs: 1509 added_to_stack |= _add_to_stack(inp) 1510 for cinp in y_op.control_inputs: 1511 if cinp.outputs: 1512 for t in cinp.outputs: 1513 added_to_stack |= _add_to_stack(t) 1514 else: 1515 added_to_stack |= _add_to_stack(cinp) 1516 if added_to_stack: 1517 continue 1518 1519 converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs] 1520 some_input_converted = any(self._was_converted(x) for x in y_op.inputs) 1521 some_input_stacked = any(x.is_stacked for x in converted_inputs) 1522 1523 converted_control_ops = set() 1524 some_control_input_converted = False 1525 for cinp in y_op.control_inputs: 1526 if cinp.outputs: 1527 for t in cinp.outputs: 1528 converted_t = self._conversion_map[t] 1529 if self._was_converted(t): 1530 some_control_input_converted = True 1531 converted_control_ops.add(converted_t.t.op) 1532 else: 1533 converted_cinp = self._conversion_map[cinp] 1534 assert isinstance(converted_cinp, ops.Operation) 1535 if converted_cinp != cinp: 1536 some_control_input_converted = True 1537 converted_control_ops.add(converted_cinp) 1538 converted_control_ops = list(converted_control_ops) 1539 is_stateful = _is_stateful_pfor_op(y_op) 1540 else: 1541 converted_inputs = [] 1542 converted_control_ops = [] 1543 logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op, 1544 converted_inputs, converted_control_ops) 1545 1546 # 2. Convert y_op 1547 # If converting a while_loop, we let the while_loop convertor deal with 1548 # putting the control dependencies appropriately. 1549 control_dependencies = [] if is_while_loop else converted_control_ops 1550 with ops.control_dependencies(control_dependencies), ops.name_scope( 1551 y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op): 1552 # Op is a placeholder for a reduction. 1553 reduce_output = self._convert_reduction(y) 1554 if reduce_output is not None: 1555 new_outputs = reduce_output 1556 # None of the inputs and control inputs were converted. 1557 elif ((not is_inside_loop or 1558 (not is_stateful and not some_input_converted and 1559 not some_control_input_converted)) and 1560 y.graph == ops.get_default_graph()): 1561 if y is y_op: 1562 assert not isinstance(y_op, WhileOp) 1563 new_outputs = y_op 1564 else: 1565 new_outputs = [wrap(x, False) for x in y_op.outputs] 1566 elif not (is_stateful or is_while_loop or some_input_stacked): 1567 # All inputs are unstacked or unconverted but some control inputs are 1568 # converted. 1569 # TODO(rachelim): Handle the case where some inputs are sparsely 1570 # stacked (i.e. 
any(x.is_sparse_stacked for x in converted_inputs)) 1571 new_op = _create_op(y_op.type, [x.t for x in converted_inputs], 1572 [x.dtype for x in y_op.outputs], 1573 y_op.node_def.attr) 1574 if y is y_op: 1575 new_outputs = new_op 1576 else: 1577 new_outputs = [] 1578 for old_output, new_output in zip(y_op.outputs, new_op.outputs): 1579 custom_gradient.copy_handle_data(old_output, new_output) 1580 new_outputs.append(wrap(new_output, False)) 1581 else: 1582 # Either some inputs are not loop invariant or op is stateful. 1583 if hasattr(y_op, "pfor_converter"): 1584 converter = y_op.pfor_converter 1585 else: 1586 converter = _pfor_converter_registry.get(y_op.type, None) 1587 if converter is None: 1588 has_variant_outputs = any(x.dtype == dtypes.variant for x in 1589 y_op.outputs) 1590 if self._fallback_to_while_loop and not has_variant_outputs: 1591 converter = _fallback_converter 1592 else: 1593 message = ("No pfor vectorization defined for %s\n" 1594 "%s\n" 1595 "inputs: %s. " % 1596 (y_op.type, y_op, converted_inputs)) 1597 if not self._fallback_to_while_loop: 1598 message += ("Consider enabling the fallback_to_while_loop " 1599 "option to pfor, which may run slower.") 1600 raise ValueError(message) 1601 # TODO(rachelim): Handle the case where some inputs are sparsely 1602 # stacked. We should only call the converter if it supports handling 1603 # those inputs. 1604 pfor_inputs = _PforInput(self, y_op, converted_inputs) 1605 try: 1606 try: 1607 new_outputs = converter(pfor_inputs) 1608 except ConversionNotImplementedError as e: 1609 if self._fallback_to_while_loop: 1610 new_outputs = _fallback_converter(pfor_inputs) 1611 else: 1612 six.reraise(ValueError, ValueError(str(e)), sys.exc_info()[2]) 1613 except Exception as e: # pylint: disable=broad-except 1614 logging.error( 1615 "Got error while pfor was converting op %s" 1616 "with inputs %s\n, converted inputs %s\n" 1617 "%s\n" 1618 "Here are the pfor conversion stack traces:", y_op, 1619 y_op.inputs[:], pfor_inputs.inputs, str(e)) 1620 original_op = y_op 1621 while isinstance(original_op, ops.Operation): 1622 logging.error( 1623 "%s\ncreated at:\n %s", original_op, 1624 " ".join(traceback.format_list(original_op.traceback))) 1625 original_op = original_op._original_op 1626 six.reraise(e.__class__, e, sys.exc_info()[2]) 1627 1628 if isinstance(new_outputs, WrappedTensor): 1629 new_outputs = [new_outputs] 1630 assert isinstance(new_outputs, 1631 (list, tuple, ops.Operation)), new_outputs 1632 logging.vlog(2, "converted %s %s", y_op, new_outputs) 1633 1634 # Insert into self._conversion_map 1635 if y is y_op: 1636 assert isinstance(new_outputs, ops.Operation) 1637 self._add_conversion(y_op, new_outputs) 1638 else: 1639 assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs, 1640 new_outputs) 1641 for old_output, new_output in zip(y_op.outputs, new_outputs): 1642 assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) 1643 assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op) 1644 # Set shape for converted output. 
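 # As a rough illustration (shapes here are assumptions, not taken from any
 # particular op): if the unconverted output had shape [3, 4] and the loop
 # runs 8 iterations, a stacked conversion is given the static shape
 # [8, 3, 4] below, while an unstacked (loop-invariant) conversion keeps
 # [3, 4].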
1645 output_shape = old_output.shape 1646 if not new_output.is_sparse_stacked: 1647 if new_output.is_stacked: 1648 loop_len = tensor_util.constant_value(self.loop_len_vector) 1649 if loop_len is None: 1650 batch_dim = tensor_shape.TensorShape([None]) 1651 else: 1652 batch_dim = tensor_shape.TensorShape(loop_len) 1653 output_shape = batch_dim.concatenate(output_shape) 1654 if _is_variant_with_internal_stacking(new_output.t): 1655 new_output.t.set_shape([]) 1656 else: 1657 new_output.t.set_shape(output_shape) 1658 self._add_conversion(old_output, new_output) 1659 stack.popleft() 1660 1661 return self._conversion_map[op_or_tensor] 1662 1663 @property 1664 def loop_len_vector(self): 1665 """Returns a single element vector whose value is number of iterations.""" 1666 return self._loop_len_vector 1667 1668 @property 1669 def loop_var(self): 1670 """Returns placeholder loop variable.""" 1671 return self._loop_var 1672 1673 @property 1674 def pfor_ops(self): 1675 return self._pfor_ops 1676 1677 @property 1678 def pfor_config(self): 1679 return self._pfor_config 1680 1681 @property 1682 def all_indices_partitioned(self): 1683 """all_indices_partitioned property. 1684 1685 Returns: 1686 True if we are inside a control flow construct and not all pfor iterations 1687 may be active. 1688 """ 1689 return self._all_indices_partitioned 1690 1691 @property 1692 def fallback_to_while_loop(self): 1693 return self._fallback_to_while_loop 1694 1695 1696# The code below defines converters for different operations. Please see comment 1697# for RegisterPFor to see how converters should be defined. 1698 1699 1700# image_ops 1701 1702 1703@RegisterPFor("AdjustContrastv2") 1704def _convert_adjust_contrastv2(pfor_input): 1705 images = pfor_input.stacked_input(0) 1706 contrast_factor = pfor_input.unstacked_input(1) 1707 return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True) 1708 1709 1710@RegisterPFor("AdjustHue") 1711def _convert_adjust_hue(pfor_input): 1712 images = pfor_input.stacked_input(0) 1713 delta = pfor_input.unstacked_input(1) 1714 return wrap(gen_image_ops.adjust_hue(images, delta), True) 1715 1716 1717@RegisterPFor("AdjustSaturation") 1718def _convert_adjust_saturation(pfor_input): 1719 images = pfor_input.stacked_input(0) 1720 scale = pfor_input.unstacked_input(1) 1721 return wrap(gen_image_ops.adjust_saturation(images, scale), True) 1722 1723 1724# nn_ops 1725 1726 1727def _flatten_first_two_dims(x): 1728 """Merges first two dimensions.""" 1729 old_shape = array_ops.shape(x) 1730 new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0) 1731 return array_ops.reshape(x, new_shape) 1732 1733 1734def _unflatten_first_dim(x, first_dim): 1735 """Splits first dimension into [first_dim, -1].""" 1736 old_shape = array_ops.shape(x) 1737 new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0) 1738 return array_ops.reshape(x, new_shape) 1739 1740 1741def _inputs_with_flattening(pfor_input, input_indices): 1742 """Stacks and flattens first dim of inputs at indices `input_indices`.""" 1743 if input_indices is None: 1744 input_indices = [] 1745 pfor_input.stack_inputs(stack_indices=input_indices) 1746 inputs = [] 1747 for i in range(pfor_input.num_inputs): 1748 if i in input_indices: 1749 inp = pfor_input.stacked_input(i) 1750 inp = _flatten_first_two_dims(inp) 1751 else: 1752 inp = pfor_input.unstacked_input(i) 1753 inputs.append(inp) 1754 return inputs 1755 1756 1757@RegisterPForWithArgs("Conv2D", dims=[0]) 1758@RegisterPForWithArgs("DepthToSpace", dims=[0]) 
1759@RegisterPForWithArgs("AvgPool", dims=[0]) 1760@RegisterPForWithArgs("AvgPool3D", dims=[0]) 1761@RegisterPForWithArgs("MaxPool", dims=[0]) 1762@RegisterPForWithArgs("MaxPoolV2", dims=[0]) 1763@RegisterPForWithArgs("MaxPool3D", dims=[0]) 1764@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2]) 1765@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) 1766@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2]) 1767@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2]) 1768@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2]) 1769@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2]) 1770@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) 1771@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1]) 1772@RegisterPForWithArgs("SpaceToDepth", dims=[0]) 1773def _convert_flatten_batch(pfor_input, op_type, dims): 1774 del op_type 1775 inputs = _inputs_with_flattening(pfor_input, dims) 1776 outputs = _create_op( 1777 pfor_input.op_type, 1778 inputs, [x.dtype for x in pfor_input.outputs], 1779 attrs=pfor_input.op.node_def.attr).outputs 1780 n = pfor_input.pfor.loop_len_vector 1781 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1782 return [wrap(x, True) for x in outputs] 1783 1784 1785_channel_flatten_input_cache = {} 1786 1787 1788@RegisterPFor("BatchToSpaceND") 1789def _convert_batch_to_space_nd(pfor_input): 1790 inp = pfor_input.stacked_input(0) 1791 block_shape = pfor_input.unstacked_input(1) 1792 crops = pfor_input.unstacked_input(2) 1793 1794 inp_shape = array_ops.shape(inp) 1795 n = pfor_input.pfor.loop_len_vector 1796 1797 # Reshape and transpose to move the vectorization axis inside the axes that 1798 # will move to space. 1799 # Reshape to 4D and transpose 1800 block_size = math_ops.reduce_prod(block_shape) 1801 new_shape = [n[0], block_size, inp_shape[1] // block_size, -1] 1802 inp = array_ops.reshape(inp, new_shape) 1803 inp = array_ops.transpose(inp, [1, 0, 2, 3]) 1804 # Reshape back to merge the block, vectorization and batch dimension, and 1805 # restore the other dimensions. 1806 new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0) 1807 inp = array_ops.reshape(inp, new_shape) 1808 # Call batch_to_space and then split the new batch axis. 1809 output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops) 1810 output = _unflatten_first_dim(output, n) 1811 return wrap(output, True) 1812 1813 1814@RegisterPFor("SpaceToBatchND") 1815def _convert_space_to_batch_nd(pfor_input): 1816 inp = pfor_input.stacked_input(0) 1817 block_shape = pfor_input.unstacked_input(1) 1818 paddings = pfor_input.unstacked_input(2) 1819 1820 n = pfor_input.pfor.loop_len_vector 1821 inp_shape = array_ops.shape(inp) 1822 inp = _flatten_first_two_dims(inp) 1823 output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings) 1824 output_shape = array_ops.shape(output) 1825 block_size = math_ops.reduce_prod(block_shape) 1826 new_shape = [block_size, n[0], -1] 1827 output = array_ops.reshape(output, new_shape) 1828 output = array_ops.transpose(output, [1, 0, 2]) 1829 new_shape = array_ops.concat( 1830 [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0) 1831 output = array_ops.reshape(output, new_shape) 1832 return wrap(output, True) 1833 1834 1835def _channel_flatten_input(x, data_format): 1836 """Merge the stack dimension with the channel dimension. 1837 1838 If S is pfor's stacking dimension, then, 1839 - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose 1840 should be cheap. 
1841 - for SNHWC, we transpose to NHWCS. 1842 We then merge the S and C dimension. 1843 1844 Args: 1845 x: ops.Tensor to transform. 1846 data_format: "NCHW" or "NHWC". 1847 1848 Returns: 1849 A 3-element tuple with the transformed value, along with the shape for 1850 reshape and order for transpose required to transform back. 1851 """ 1852 1853 graph = ops.get_default_graph() 1854 cache_key = (graph, x.ref(), data_format) 1855 if cache_key not in _channel_flatten_input_cache: 1856 x_shape = array_ops.shape(x) 1857 if data_format == b"NCHW": 1858 order = [1, 0, 2, 3, 4] 1859 shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0) 1860 reverse_order = order 1861 else: 1862 order = [1, 2, 3, 0, 4] 1863 shape = array_ops.concat([x_shape[1:4], [-1]], axis=0) 1864 reverse_order = [3, 0, 1, 2, 4] 1865 # Move S dimension next to C dimension. 1866 x = array_ops.transpose(x, order) 1867 reverse_shape = array_ops.shape(x) 1868 # Reshape to merge the S and C dimension. 1869 x = array_ops.reshape(x, shape) 1870 outputs = x, reverse_order, reverse_shape 1871 _channel_flatten_input_cache[cache_key] = outputs 1872 else: 1873 outputs = _channel_flatten_input_cache[cache_key] 1874 return outputs 1875 1876 1877# Note that with training=True, running FusedBatchNormV3 on individual examples 1878# is very different from running FusedBatchNormV3 on a batch of those examples. 1879# This is because, for the latter case, the operation can be considered as first 1880# computing the mean and variance over all the examples and then using these 1881# to scale all those examples. This creates a data dependency between these 1882# different "iterations" since the inputs to the scaling step depends on the 1883# statistics coming from all these inputs. 1884# As with other kernels, the conversion here effectively runs the kernel 1885# independently for each iteration, and returns outputs by stacking outputs from 1886# each of those iterations. 1887@RegisterPFor("FusedBatchNormV3") 1888def _convert_fused_batch_norm(pfor_input): 1889 is_training = pfor_input.get_attr("is_training") 1890 # When BatchNorm is used with training=False, mean and variance are provided 1891 # externally and used as is by the op. Thus, we can merge the S and N 1892 # dimensions as we do for regular operations. 1893 # When BatchNorm is used with training=True, mean and variance are computed 1894 # for each channel across the batch dimension (first one). If we merge S and N 1895 # dimensions, mean and variances will be computed over a larger set. So, we 1896 # merge the S and C dimensions instead. 1897 if not is_training: 1898 # We return zeros for batch_mean and batch_variance output. Note that CPU 1899 # and GPU seem to have different behavior for those two outputs. CPU outputs 1900 # zero because these values are not used during inference. GPU outputs 1901 # something, probably real means and variances. 1902 inputs = _inputs_with_flattening(pfor_input, [0]) 1903 outputs = _create_op( 1904 pfor_input.op_type, 1905 inputs, [x.dtype for x in pfor_input.outputs], 1906 attrs=pfor_input.op.node_def.attr).outputs 1907 y = outputs[0] 1908 n = pfor_input.pfor.loop_len_vector 1909 y = _unflatten_first_dim(y, n) 1910 mean = pfor_input.unstacked_input(3) 1911 zeros = array_ops.zeros_like(mean) 1912 return [wrap(y, True)] + [wrap(zeros, False)] * 5 1913 1914 pfor_input.stack_inputs() 1915 data_format = pfor_input.get_attr("data_format") 1916 # We merge the first dimension with the "C" dimension, run FusedBatchNormV3, 1917 # and then transpose back. 
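 # Shape sketch (sizes are illustrative and assume NHWC): a stacked input of
 # shape [S, N, H, W, C] is rearranged by _channel_flatten_input into
 # [N, H, W, S*C], so batch statistics are computed separately for every
 # (iteration, channel) pair; the reshape/transpose below restores
 # [S, N, H, W, C].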
1918 x = pfor_input.stacked_input(0) 1919 x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format) 1920 # Note that we stack all the other inputs as well so that they are the same 1921 # size as the new size of the channel dimension. 1922 inputs = [x] + [ 1923 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1924 for i in range(1, pfor_input.num_inputs) 1925 ] 1926 outputs = _create_op( 1927 pfor_input.op_type, 1928 inputs, [x.dtype for x in pfor_input.outputs], 1929 attrs=pfor_input.op.node_def.attr).outputs 1930 y = outputs[0] 1931 y = array_ops.reshape(y, reverse_shape) 1932 y = array_ops.transpose(y, reverse_order) 1933 n = pfor_input.pfor.loop_len_vector 1934 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1935 outputs = [y] + outputs 1936 return [wrap(x, True) for x in outputs] 1937 1938 1939@RegisterPFor("FusedBatchNormGradV3") 1940def _convert_fused_batch_norm_grad(pfor_input): 1941 pfor_input.stack_inputs() 1942 data_format = pfor_input.get_attr("data_format") 1943 y_backprop = pfor_input.stacked_input(0) 1944 y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) 1945 x = pfor_input.stacked_input(1) 1946 x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) 1947 inputs = [y_backprop, x] + [ 1948 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1949 for i in range(2, pfor_input.num_inputs) 1950 ] 1951 outputs = _create_op( 1952 pfor_input.op_type, 1953 inputs, [x.dtype for x in pfor_input.outputs], 1954 attrs=pfor_input.op.node_def.attr).outputs 1955 x_backprop = outputs[0] 1956 x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) 1957 x_backprop = array_ops.transpose(x_backprop, x_reverse_order) 1958 n = pfor_input.pfor.loop_len_vector 1959 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1960 outputs = [x_backprop] + outputs 1961 return [wrap(output, True) for output in outputs] 1962 1963 1964@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) 1965@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) 1966@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0) 1967def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, 1968 shape_dim): 1969 del op_type 1970 inputs = _inputs_with_flattening(pfor_input, flatten_dims) 1971 n = pfor_input.pfor.loop_len_vector 1972 # Adjust the `input_sizes` input. 1973 ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1], 1974 dtype=n.dtype) 1975 inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) 1976 outputs = _create_op( 1977 pfor_input.op_type, 1978 inputs, [x.dtype for x in pfor_input.outputs], 1979 attrs=pfor_input.op.node_def.attr).outputs 1980 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1981 return [wrap(x, True) for x in outputs] 1982 1983 1984@RegisterPFor("Conv2DBackpropFilter") 1985def _convert_conv2d_backprop_filter(pfor_input): 1986 pfor_input.stack_inputs(stack_indices=[2]) 1987 inputs, inputs_stacked, _ = pfor_input.input(0) 1988 filter_sizes = pfor_input.unstacked_input(1) 1989 grads = pfor_input.stacked_input(2) 1990 strides = pfor_input.get_attr("strides") 1991 padding = pfor_input.get_attr("padding") 1992 use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") 1993 data_format = pfor_input.get_attr("data_format") 1994 dilations = pfor_input.get_attr("dilations") 1995 if inputs_stacked: 1996 # TODO(agarwal): Implement this efficiently. 1997 logging.warn("Conv2DBackpropFilter uses a while_loop. 
Fix that!") 1998 1999 def while_body(i, ta): 2000 inp_i = inputs[i, ...] 2001 grad_i = grads[i, ...] 2002 output = nn_ops.conv2d_backprop_filter( 2003 inp_i, 2004 filter_sizes, 2005 grad_i, 2006 strides=strides, 2007 padding=padding, 2008 use_cudnn_on_gpu=use_cudnn_on_gpu, 2009 data_format=data_format, 2010 dilations=dilations) 2011 return i + 1, ta.write(i, array_ops.expand_dims(output, 0)) 2012 2013 n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) 2014 _, ta = control_flow_ops.while_loop( 2015 lambda i, ta: i < n, while_body, 2016 (0, tensor_array_ops.TensorArray(inputs.dtype, n))) 2017 output = ta.concat() 2018 return wrap(output, True) 2019 else: 2020 # We merge the stack dimension with the channel dimension of the gradients 2021 # and pretend we had a larger filter (see change to filter_sizes below). 2022 # Once the filter backprop is computed, we reshape and transpose back 2023 # appropriately. 2024 grads, _, _ = _channel_flatten_input(grads, data_format) 2025 n = pfor_input.pfor.loop_len_vector 2026 old_filter_sizes = filter_sizes 2027 filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) 2028 output = nn_ops.conv2d_backprop_filter( 2029 inputs, 2030 filter_sizes, 2031 grads, 2032 strides=strides, 2033 padding=padding, 2034 use_cudnn_on_gpu=use_cudnn_on_gpu, 2035 data_format=data_format, 2036 dilations=dilations) 2037 new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) 2038 output = array_ops.reshape(output, new_filter_shape) 2039 output = array_ops.transpose(output, [3, 0, 1, 2, 4]) 2040 return wrap(output, True) 2041 2042 2043@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax) 2044@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax) 2045def _convert_softmax(pfor_input, op_type, op_func): 2046 del op_type 2047 return wrap(op_func(pfor_input.stacked_input(0)), True) 2048 2049 2050# array_ops 2051 2052 2053@RegisterPForWithArgs("Identity", array_ops.identity) 2054@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) 2055@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag) 2056@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part) 2057def _convert_identity(pfor_input, op_type, op_func): 2058 del op_type 2059 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 2060 2061 2062@RegisterPFor("IdentityN") 2063def _convert_identity_n(pfor_input): 2064 outputs = array_ops.identity_n([x.t for x in pfor_input.inputs]) 2065 return [ 2066 wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs) 2067 ] 2068 2069 2070@RegisterPFor("Reshape") 2071def _convert_reshape(pfor_input): 2072 t = pfor_input.stacked_input(0) 2073 shape = pfor_input.unstacked_input(1) 2074 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2075 return wrap(array_ops.reshape(t, new_shape), True) 2076 2077 2078@RegisterPFor("Fill") 2079def _convert_fill(pfor_input): 2080 dims = pfor_input.unstacked_input(0) 2081 value = pfor_input.stacked_input(1) 2082 # Expand the rank of `value` 2083 new_shape = array_ops.concat( 2084 [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)], 2085 axis=0) 2086 value = array_ops.reshape(value, new_shape) 2087 # Compute the new output shape 2088 new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0) 2089 # Broadcast 2090 return wrap(array_ops.broadcast_to(value, new_dims), True) 2091 2092 2093@RegisterPFor("BroadcastTo") 2094def _convert_broadcast_to(pfor_input): 2095 t = pfor_input.stacked_input(0) 2096 shape = 
pfor_input.unstacked_input(1) 2097 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2098 2099 # Expand dims of stacked t to broadcast against the new shape. 2100 # TODO(davmre): consider factoring out common code with 2101 # `expanddim_inputs_for_broadcast`, which has similar logic but with 2102 # implicit shapes (of input Tensors) rather than explicit shapes. 2103 rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t) 2104 ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1])) 2105 t_shape = array_ops.shape(t) 2106 t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0) 2107 2108 return wrap( 2109 array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape), 2110 True) 2111 2112 2113@RegisterPFor("ExpandDims") 2114def _convert_expanddims(pfor_input): 2115 t = pfor_input.stacked_input(0) 2116 dim = pfor_input.unstacked_input(1) 2117 dim += math_ops.cast(dim >= 0, dim.dtype) 2118 return wrap(array_ops.expand_dims(t, axis=dim), True) 2119 2120 2121@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound) 2122@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound) 2123def _convert_searchsorted(pfor_input, _, op_func): 2124 pfor_input.stack_inputs() 2125 sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0)) 2126 values = _flatten_first_two_dims(pfor_input.stacked_input(1)) 2127 out_type = pfor_input.get_attr("out_type") 2128 output = op_func(sorted_inputs, values, out_type) 2129 return wrap( 2130 _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True) 2131 2132 2133@RegisterPFor("MatrixBandPart") 2134def _convert_matrix_band_part(pfor_input): 2135 t = pfor_input.stacked_input(0) 2136 num_lower = pfor_input.unstacked_input(1) 2137 num_upper = pfor_input.unstacked_input(2) 2138 return wrap( 2139 array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper), 2140 True) 2141 2142 2143@RegisterPFor("MatrixSetDiag") 2144def _convert_matrix_set_diag(pfor_input): 2145 pfor_input.stack_inputs() 2146 t = pfor_input.stacked_input(0) 2147 diag = pfor_input.stacked_input(1) 2148 return wrap(array_ops.matrix_set_diag(t, diag), True) 2149 2150 2151# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3. 2152# The input orders defined in the OpKernel and the actual python API are 2153# different (for compatibility with V1), so we cannot use _convert_identity. 2154# v2 is not compatible with v3 and is never exposed on the public API. 2155@RegisterPFor("MatrixDiagV2") 2156@RegisterPFor("MatrixDiagV3") 2157def _convert_matrix_diag_v2(pfor_input): 2158 params = { 2159 "diagonal": pfor_input.stacked_input(0), 2160 "k": pfor_input.unstacked_input(1), 2161 "num_rows": pfor_input.unstacked_input(2), 2162 "num_cols": pfor_input.unstacked_input(3), 2163 "padding_value": pfor_input.unstacked_input(4) 2164 } 2165 if pfor_input.op_type == "MatrixDiagV2": 2166 return wrap(array_ops.matrix_diag_v2(**params), True) 2167 params["align"] = pfor_input.get_attr("align") 2168 return wrap(array_ops.matrix_diag(**params), True) 2169 2170 2171@RegisterPFor("Diag") 2172def _convert_diag(pfor_input): 2173 diag = pfor_input.stacked_input(0) 2174 if diag.shape.ndims == 2: 2175 # We can use matrix_diag. 2176 return wrap(array_ops.matrix_diag(diag), True) 2177 else: 2178 # It is not clear if we can do better than a while loop here with existing 2179 # kernels. 
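 # (For instance, with purely illustrative shapes, a stacked input of shape
 # [loop_len, 2, 3] would need a [loop_len, 2, 3, 2, 3] output, which no
 # single existing kernel produces directly.)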
2180 return _fallback_converter(pfor_input, warn=False) 2181 2182 2183# See notes for MatrixDiagV2 2184@RegisterPFor("MatrixDiagPartV2") 2185@RegisterPFor("MatrixDiagPartV3") 2186def _convert_matrix_diag_part_v2(pfor_input): 2187 params = { 2188 "input": pfor_input.stacked_input(0), 2189 "k": pfor_input.unstacked_input(1), 2190 "padding_value": pfor_input.unstacked_input(2) 2191 } 2192 if pfor_input.op_type == "MatrixDiagPartV2": 2193 return wrap(array_ops.matrix_diag_part_v2(**params), True) 2194 params["align"] = pfor_input.get_attr("align") 2195 return wrap(array_ops.matrix_diag_part(**params), True) 2196 2197 2198# See notes for MatrixDiagV2 2199@RegisterPFor("MatrixSetDiagV2") 2200@RegisterPFor("MatrixSetDiagV3") 2201def _convert_matrix_set_diag_v2(pfor_input): 2202 pfor_input.stack_inputs([0, 1]) 2203 params = { 2204 "input": pfor_input.stacked_input(0), 2205 "diagonal": pfor_input.stacked_input(1), 2206 "k": pfor_input.unstacked_input(2) 2207 } 2208 if pfor_input.op_type == "MatrixSetDiagV2": 2209 return wrap(array_ops.matrix_set_diag_v2(**params), True) 2210 params["align"] = pfor_input.get_attr("align") 2211 return wrap(array_ops.matrix_set_diag(**params), True) 2212 2213 2214@RegisterPFor("DiagPart") 2215def _convert_diag_part(pfor_input): 2216 inp = pfor_input.stacked_input(0) 2217 if inp.shape.ndims == 3: 2218 # We can use matrix_diag_part. 2219 return wrap(array_ops.matrix_diag_part(inp), True) 2220 else: 2221 # It is not clear if we can do better than a while loop here with existing 2222 # kernels. 2223 return _fallback_converter(pfor_input, warn=False) 2224 2225 2226@RegisterPFor("OneHot") 2227def _convert_one_hot(pfor_input): 2228 indices = pfor_input.stacked_input(0) 2229 depth = pfor_input.unstacked_input(1) 2230 on_value = pfor_input.unstacked_input(2) 2231 off_value = pfor_input.unstacked_input(3) 2232 axis = pfor_input.get_attr("axis") 2233 if axis >= 0: 2234 axis += 1 2235 return wrap( 2236 array_ops.one_hot(indices, depth, on_value, off_value, axis), True) 2237 2238 2239@RegisterPFor("Slice") 2240def _convert_slice(pfor_input): 2241 t = pfor_input.stacked_input(0) 2242 begin = pfor_input.unstacked_input(1) 2243 size = pfor_input.unstacked_input(2) 2244 begin = array_ops.concat([[0], begin], axis=0) 2245 size = array_ops.concat([[-1], size], axis=0) 2246 return wrap(array_ops.slice(t, begin, size), True) 2247 2248 2249@RegisterPFor("Tile") 2250def _convert_tile(pfor_input): 2251 t = pfor_input.stacked_input(0) 2252 multiples = pfor_input.unstacked_input(1) 2253 multiples = array_ops.concat([[1], multiples], 0) 2254 return wrap(array_ops.tile(t, multiples), True) 2255 2256 2257@RegisterPFor("Pack") 2258def _convert_pack(pfor_input): 2259 pfor_input.stack_inputs() 2260 axis = pfor_input.get_attr("axis") 2261 if axis >= 0: 2262 axis += 1 2263 return wrap( 2264 array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True) 2265 2266 2267@RegisterPFor("Unpack") 2268def _convert_unpack(pfor_input): 2269 value = pfor_input.stacked_input(0) 2270 axis = pfor_input.get_attr("axis") 2271 if axis >= 0: 2272 axis += 1 2273 num = pfor_input.get_attr("num") 2274 return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)] 2275 2276 2277@RegisterPFor("Pad") 2278def _convert_pad(pfor_input): 2279 t = pfor_input.stacked_input(0) 2280 paddings = pfor_input.unstacked_input(1) 2281 paddings = array_ops.concat([[[0, 0]], paddings], 0) 2282 return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) 2283 2284 2285@RegisterPFor("Split") 2286def 
_convert_split(pfor_input): 2287 split_dim = pfor_input.unstacked_input(0) 2288 t = pfor_input.stacked_input(1) 2289 num_split = pfor_input.get_attr("num_split") 2290 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 2291 return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] 2292 2293 2294@RegisterPFor("SplitV") 2295def _convert_split_v(pfor_input): 2296 t = pfor_input.stacked_input(0) 2297 splits = pfor_input.unstacked_input(1) 2298 split_dim = pfor_input.unstacked_input(2) 2299 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 2300 return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)] 2301 2302 2303@RegisterPFor("Squeeze") 2304def _convert_squeeze(pfor_input): 2305 t = pfor_input.stacked_input(0) 2306 squeeze_dims = pfor_input.get_attr("squeeze_dims") 2307 squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims] 2308 return wrap(array_ops.squeeze(t, axis=squeeze_dims), True) 2309 2310 2311@RegisterPFor("ReverseV2") 2312def _convert_reverse(pfor_input): 2313 value = pfor_input.stacked_input(0) 2314 axis = pfor_input.unstacked_input(1) 2315 new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis) 2316 return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True) 2317 2318 2319@RegisterPForWithArgs("Transpose", gen_array_ops.transpose) 2320@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose) 2321def _convert_transpose(pfor_input, _, op_func): 2322 t = pfor_input.stacked_input(0) 2323 perm = pfor_input.unstacked_input(1) 2324 new_perm = array_ops.concat([[0], perm + 1], axis=0) 2325 return wrap(op_func(t, new_perm), True) 2326 2327 2328@RegisterPFor("ZerosLike") 2329def _convert_zeroslike(pfor_input): 2330 t = pfor_input.stacked_input(0) 2331 shape = array_ops.shape(t)[1:] 2332 return wrap(array_ops.zeros(shape, dtype=t.dtype), False) 2333 2334 2335@RegisterPFor("Gather") 2336@RegisterPFor("GatherV2") 2337def _convert_gather(pfor_input): 2338 param, param_stacked, _ = pfor_input.input(0) 2339 indices, indices_stacked, _ = pfor_input.input(1) 2340 batch_dims = pfor_input.get_attr("batch_dims") 2341 2342 op_type = pfor_input.op_type 2343 if op_type == "Gather": 2344 validate_indices = pfor_input.get_attr("validate_indices") 2345 axis = 0 2346 else: 2347 validate_indices = None 2348 # Assume we will never have a Tensor with rank > 2**32. 2349 axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32) 2350 axis_value = tensor_util.constant_value(axis) 2351 if axis_value is not None: 2352 axis = axis_value 2353 if indices_stacked and not param_stacked: 2354 if indices is pfor_input.pfor.all_indices and axis == 0: 2355 param_shape0 = tensor_shape.dimension_value(param.shape[0]) 2356 indices_shape0 = tensor_shape.dimension_value(indices.shape[0]) 2357 if param_shape0 is not None and indices_shape0 == param_shape0: 2358 # Note that with loops and conditionals, indices may not be contiguous. 2359 # However they will be sorted and unique. So if the shape matches, then 2360 # it must be picking up all the rows of param. 2361 return wrap(param, True) 2362 2363 if batch_dims != 0: 2364 # Convert `batch_dims` to its positive equivalent if necessary. 2365 batch_dims_pos = batch_dims 2366 if batch_dims < 0: 2367 batch_dims_pos += array_ops.rank(indices) 2368 # In order to maintain 2369 # indices.shape[:batch_dims] == params.shape[:batch_dims] 2370 # with stacked indices, we move the first dimension of `indices` to the 2371 # `batch_dims + 1`th position. 
The (non-batch) index dimensions will be 2372 # inserted into the shape of `output` at the `axis` dimension, which is 2373 # then transposed to the front (below). 2374 order = array_ops.concat([ 2375 math_ops.range(1, batch_dims_pos + 1), 2376 [0], 2377 math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0) 2378 indices = array_ops.transpose(indices, order) 2379 2380 output = array_ops.gather( 2381 param, indices, validate_indices=validate_indices, axis=axis, 2382 batch_dims=batch_dims) 2383 if axis != 0: 2384 axis = control_flow_ops.cond(axis < 0, 2385 lambda: axis + array_ops.rank(param), 2386 lambda: axis) 2387 order = array_ops.concat( 2388 [[axis], 2389 math_ops.range(axis), 2390 math_ops.range(axis + 1, array_ops.rank(output))], 2391 axis=0) 2392 output = control_flow_ops.cond( 2393 math_ops.equal(axis, 0), lambda: output, 2394 lambda: array_ops.transpose(output, order)) 2395 return wrap(output, True) 2396 if param_stacked: 2397 pfor_input.stack_inputs(stack_indices=[1]) 2398 indices = pfor_input.stacked_input(1) 2399 2400 output = array_ops.gather( 2401 param, indices, 2402 axis=array_ops.where(axis >= 0, axis + 1, axis), 2403 batch_dims=(batch_dims + 1 if batch_dims >= 0 else batch_dims)) 2404 return wrap(output, True) 2405 2406 2407@RegisterPFor("GatherNd") 2408def _convert_gather_nd(pfor_input): 2409 # TODO(jmenick): Add support for unstacked params. 2410 pfor_input.stack_inputs(stack_indices=[1]) 2411 params = pfor_input.stacked_input(0) 2412 indices = pfor_input.stacked_input(1) 2413 stacked_result = array_ops.gather_nd(params, indices, batch_dims=1) 2414 return wrap(stacked_result, True) 2415 2416 2417@RegisterPFor("ConcatV2") 2418def _convert_concatv2(pfor_input): 2419 n = pfor_input.num_inputs 2420 pfor_input.stack_inputs(stack_indices=range(n - 1)) 2421 axis = pfor_input.unstacked_input(n - 1) 2422 axis += math_ops.cast(axis >= 0, axis.dtype) 2423 return wrap( 2424 array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), 2425 True) 2426 2427 2428@RegisterPFor("StridedSlice") 2429def _convert_strided_slice(pfor_input): 2430 inp = pfor_input.stacked_input(0) 2431 begin = pfor_input.unstacked_input(1) 2432 end = pfor_input.unstacked_input(2) 2433 strides = pfor_input.unstacked_input(3) 2434 begin_mask = pfor_input.get_attr("begin_mask") 2435 end_mask = pfor_input.get_attr("end_mask") 2436 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 2437 new_axis_mask = pfor_input.get_attr("new_axis_mask") 2438 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 2439 2440 begin = array_ops.concat([[0], begin], axis=0) 2441 end = array_ops.concat([[0], end], axis=0) 2442 strides = array_ops.concat([[1], strides], axis=0) 2443 begin_mask = begin_mask << 1 | 1 2444 end_mask = end_mask << 1 | 1 2445 ellipsis_mask <<= 1 2446 new_axis_mask <<= 1 2447 shrink_axis_mask <<= 1 2448 return wrap( 2449 array_ops.strided_slice( 2450 inp, 2451 begin, 2452 end, 2453 strides, 2454 begin_mask=begin_mask, 2455 end_mask=end_mask, 2456 ellipsis_mask=ellipsis_mask, 2457 new_axis_mask=new_axis_mask, 2458 shrink_axis_mask=shrink_axis_mask), True) 2459 2460 2461@RegisterPFor("StridedSliceGrad") 2462def _convert_strided_slice_grad(pfor_input): 2463 shape = pfor_input.unstacked_input(0) 2464 begin = pfor_input.unstacked_input(1) 2465 end = pfor_input.unstacked_input(2) 2466 strides = pfor_input.unstacked_input(3) 2467 dy = pfor_input.stacked_input(4) 2468 begin_mask = pfor_input.get_attr("begin_mask") 2469 end_mask = pfor_input.get_attr("end_mask") 2470 ellipsis_mask = 
pfor_input.get_attr("ellipsis_mask") 2471 new_axis_mask = pfor_input.get_attr("new_axis_mask") 2472 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 2473 2474 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2475 begin = array_ops.concat([[0], begin], axis=0) 2476 end = array_ops.concat([[0], end], axis=0) 2477 strides = array_ops.concat([[1], strides], axis=0) 2478 begin_mask = begin_mask << 1 | 1 2479 end_mask = end_mask << 1 | 1 2480 ellipsis_mask <<= 1 2481 new_axis_mask <<= 1 2482 shrink_axis_mask <<= 1 2483 return wrap( 2484 array_ops.strided_slice_grad( 2485 shape, 2486 begin, 2487 end, 2488 strides, 2489 dy, 2490 begin_mask=begin_mask, 2491 end_mask=end_mask, 2492 ellipsis_mask=ellipsis_mask, 2493 new_axis_mask=new_axis_mask, 2494 shrink_axis_mask=shrink_axis_mask), True) 2495 2496 2497@RegisterPFor("CheckNumerics") 2498def _convert_check_numerics(pfor_input): 2499 t = pfor_input.stacked_input(0) 2500 message = pfor_input.get_attr("message") 2501 return wrap(gen_array_ops.check_numerics(t, message), True) 2502 2503 2504# math_ops 2505 2506 2507@RegisterPFor("MatMul") 2508def _convert_matmul(pfor_input): 2509 # TODO(agarwal): Check if tiling is faster than two transposes. 2510 a, a_stacked, _ = pfor_input.input(0) 2511 b, b_stacked, _ = pfor_input.input(1) 2512 tr_a = pfor_input.get_attr("transpose_a") 2513 tr_b = pfor_input.get_attr("transpose_b") 2514 if a_stacked and b_stacked: 2515 output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) 2516 return output 2517 elif a_stacked: 2518 if tr_a: 2519 a = array_ops.transpose(a, [0, 2, 1]) 2520 if a.shape.is_fully_defined(): 2521 x, y, z = a.shape 2522 else: 2523 x, y, z = [ 2524 array_ops.reshape(i, []) 2525 for i in array_ops.split(array_ops.shape(a), 3) 2526 ] 2527 a = array_ops.reshape(a, [x * y, z]) 2528 prod = math_ops.matmul(a, b, transpose_b=tr_b) 2529 return wrap(array_ops.reshape(prod, [x, y, -1]), True) 2530 else: 2531 assert b_stacked 2532 if tr_b: 2533 perm = [2, 0, 1] 2534 b = array_ops.transpose(b, perm) 2535 else: 2536 # As an optimization, if one of the first two dimensions is 1, then we can 2537 # reshape instead of transpose. 2538 # TODO(agarwal): This check can be done inside Transpose kernel. 2539 b_shape = array_ops.shape(b) 2540 min_dim = math_ops.minimum(b_shape[0], b_shape[1]) 2541 perm = control_flow_ops.cond( 2542 math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2]) 2543 new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]]) 2544 b = array_ops.transpose(b, perm) 2545 b = array_ops.reshape(b, new_shape) 2546 2547 if b.shape.is_fully_defined(): 2548 x, y, z = b.shape 2549 else: 2550 x, y, z = [ 2551 array_ops.reshape(i, []) 2552 for i in array_ops.split(array_ops.shape(b), 3) 2553 ] 2554 b = array_ops.reshape(b, [x, y * z]) 2555 prod = math_ops.matmul(a, b, transpose_a=tr_a) 2556 prod = array_ops.reshape(prod, [-1, y, z]) 2557 prod = array_ops.transpose(prod, [1, 0, 2]) 2558 return wrap(prod, True) 2559 2560 2561# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window 2562# is met. 2563@RegisterPFor("BatchMatMul") 2564def _convert_batch_mat_mul(pfor_input): 2565 # TODO(agarwal): There may be a more efficient way to do this instead of 2566 # stacking the inputs. 
2567 pfor_input.stack_inputs() 2568 x = pfor_input.stacked_input(0) 2569 y = pfor_input.stacked_input(1) 2570 adj_x = pfor_input.get_attr("adj_x") 2571 adj_y = pfor_input.get_attr("adj_y") 2572 2573 x = _flatten_first_two_dims(x) 2574 y = _flatten_first_two_dims(y) 2575 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 2576 output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) 2577 return wrap(output, True) 2578 2579 2580@RegisterPFor("BatchMatMulV2") 2581def _convert_batch_mat_mul_v2(pfor_input): 2582 pfor_input.expanddim_inputs_for_broadcast() 2583 x = pfor_input.input(0)[0] 2584 y = pfor_input.input(1)[0] 2585 adj_x = pfor_input.get_attr("adj_x") 2586 adj_y = pfor_input.get_attr("adj_y") 2587 2588 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 2589 return wrap(output, True) 2590 2591 2592@RegisterPForWithArgs("Sum", math_ops.reduce_sum) 2593@RegisterPForWithArgs("Prod", math_ops.reduce_prod) 2594@RegisterPForWithArgs("Max", math_ops.reduce_max) 2595@RegisterPForWithArgs("Min", math_ops.reduce_min) 2596@RegisterPForWithArgs("Mean", math_ops.reduce_mean) 2597@RegisterPForWithArgs("All", math_ops.reduce_all) 2598@RegisterPForWithArgs("Any", math_ops.reduce_any) 2599def _convert_reduction(pfor_input, _, op_func): 2600 t = pfor_input.stacked_input(0) 2601 indices = pfor_input.unstacked_input(1) 2602 # Shift positive indices by one to account for the extra dimension. 2603 indices += math_ops.cast(indices >= 0, indices.dtype) 2604 keep_dims = pfor_input.get_attr("keep_dims") 2605 return wrap(op_func(t, indices, keepdims=keep_dims), True) 2606 2607 2608@RegisterPForWithArgs("ArgMax", math_ops.argmax) 2609@RegisterPForWithArgs("ArgMin", math_ops.argmin) 2610def _convert_argmax_argmin(pfor_input, _, op_func): 2611 t = pfor_input.stacked_input(0) 2612 dimension = pfor_input.unstacked_input(1) 2613 dimension += math_ops.cast(dimension >= 0, dimension.dtype) 2614 output_type = pfor_input.get_attr("output_type") 2615 return wrap(op_func(t, axis=dimension, output_type=output_type), True) 2616 2617 2618@RegisterPFor("Bucketize") 2619def _convert_bucketize(pfor_input): 2620 t = pfor_input.stacked_input(0) 2621 boundaries = pfor_input.get_attr("boundaries") 2622 return wrap(math_ops.bucketize(t, boundaries), True) 2623 2624 2625@RegisterPFor("ClipByValue") 2626def _convert_clip_by_value(pfor_input): 2627 t = pfor_input.stacked_input(0) 2628 clip_value_min = pfor_input.unstacked_input(1) 2629 clip_value_max = pfor_input.unstacked_input(2) 2630 return wrap(gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max), 2631 True) 2632 2633 2634@RegisterPForWithArgs("Cumsum", math_ops.cumsum) 2635@RegisterPForWithArgs("Cumprod", math_ops.cumprod) 2636def _convert_cumfoo(pfor_input, _, op_func): 2637 t = pfor_input.stacked_input(0) 2638 axis = pfor_input.unstacked_input(1) 2639 # Shift positive indices by one to account for the extra dimension. 2640 axis += math_ops.cast(axis >= 0, axis.dtype) 2641 exclusive = pfor_input.get_attr("exclusive") 2642 reverse = pfor_input.get_attr("reverse") 2643 return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) 2644 2645 2646@RegisterPFor("BiasAdd") 2647def _convert_biasadd(pfor_input): 2648 t, t_stacked, _ = pfor_input.input(0) 2649 bias, bias_stacked, _ = pfor_input.input(1) 2650 data_format = pfor_input.get_attr("data_format").decode() 2651 if bias_stacked: 2652 # BiasAdd only supports 1-D biases, so cast bias to match value and use Add. 
2653 pfor_input.expanddim_inputs_for_broadcast() 2654 t, _, _ = pfor_input.input(0) 2655 bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype) 2656 if compat.as_bytes(data_format) == b"NCHW": 2657 b_shape = array_ops.shape(bias) 2658 new_b_shape = array_ops.concat( 2659 [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0) 2660 bias = array_ops.reshape(bias, new_b_shape) 2661 return wrap(math_ops.add(t, bias), True) 2662 else: 2663 assert t_stacked, "At least one input to BiasAdd should be loop variant." 2664 if compat.as_bytes(data_format) == b"NCHW": 2665 shape = array_ops.shape(t) 2666 flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) 2667 t = array_ops.reshape(t, flattened_shape) 2668 t = nn_ops.bias_add(t, bias, data_format="NCHW") 2669 t = array_ops.reshape(t, shape) 2670 return wrap(t, True) 2671 return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) 2672 2673 2674@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum) 2675@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max) 2676@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min) 2677@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod) 2678def _convert_unsortedsegmentsum(pfor_input, _, op_func): 2679 pfor_input.stack_inputs([0, 1]) 2680 data = pfor_input.stacked_input(0) 2681 segment_ids = pfor_input.stacked_input(1) 2682 # TODO(agarwal): handle stacked? 2683 num_segments = pfor_input.unstacked_input(2) 2684 if segment_ids.dtype != num_segments.dtype: 2685 segment_ids = math_ops.cast(segment_ids, dtypes.int64) 2686 num_segments = math_ops.cast(num_segments, dtypes.int64) 2687 dtype = segment_ids.dtype 2688 segment_shape = array_ops.shape(segment_ids, out_type=dtype) 2689 n = segment_shape[0] 2690 ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:] 2691 segment_offset = num_segments * math_ops.range(n, dtype=dtype) 2692 segment_offset = array_ops.reshape(segment_offset, 2693 array_ops.concat([[n], ones], axis=0)) 2694 segment_ids += segment_offset 2695 num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast( 2696 n, dtypes.int64) 2697 output = op_func(data, segment_ids, num_segments) 2698 new_output_shape = array_ops.concat( 2699 [[n, -1], array_ops.shape(output)[1:]], axis=0) 2700 output = array_ops.reshape(output, new_output_shape) 2701 return wrap(output, True) 2702 2703 2704def _flatten_array_with_offset(ids, offset_delta, num_rows): 2705 """Flattens a rank 2 tensor, adding an offset to each row.""" 2706 # Note that if `ids` is rank 1, it is broadcast to rank 2. 
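 # Worked example (numbers are illustrative): with num_rows=3 and
 # offset_delta=10, offsets is [[0], [10], [20]], so ids [[1, 2], [0, 1],
 # [2, 2]] becomes [[1, 2], [10, 11], [22, 22]] and is flattened to
 # [1, 2, 10, 11, 22, 22].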
2707 offset_delta = math_ops.cast(offset_delta, ids.dtype) 2708 n = math_ops.cast(num_rows, dtype=ids.dtype) 2709 offsets = math_ops.range( 2710 start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype) 2711 offsets = array_ops.expand_dims(offsets, -1) 2712 ids += offsets 2713 return array_ops.reshape(ids, [-1]) 2714 2715 2716@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2) 2717@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2) 2718@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2) 2719@RegisterPForWithArgs("SparseSegmentSumWithNumSegments", 2720 math_ops.sparse_segment_sum_v2) 2721@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments", 2722 math_ops.sparse_segment_mean_v2) 2723@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments", 2724 math_ops.sparse_segment_sqrt_n_v2) 2725def _convert_sparse_segment(pfor_input, _, op_func): 2726 _, segment_ids_stacked, _ = pfor_input.input(2) 2727 if segment_ids_stacked: 2728 pfor_input.stack_inputs([1]) 2729 data, data_stacked, _ = pfor_input.input(0) 2730 indices, _, _ = pfor_input.input(1) 2731 num_inputs = len(pfor_input.inputs) 2732 assert num_inputs in (3, 4) 2733 if num_inputs == 3: 2734 # `segment_ids` needs to be unstacked since otherwise output sizes could 2735 # differ across pfor iterations. 2736 segment_ids = pfor_input.unstacked_input(2) 2737 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 2738 else: 2739 segment_ids, _, _ = pfor_input.input(2) 2740 num_segments = pfor_input.unstacked_input(3) 2741 2742 n = pfor_input.pfor.loop_len_vector[0] 2743 if data_stacked: 2744 indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n) 2745 data = _flatten_first_two_dims(data) 2746 else: 2747 indices = array_ops.reshape(indices, [-1]) 2748 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 2749 2750 if num_inputs == 3: 2751 num_segments = None 2752 else: 2753 num_segments *= n 2754 output = op_func(data, indices, segment_ids, num_segments=num_segments) 2755 output = _unflatten_first_dim(output, [n]) 2756 return wrap(output, True) 2757 2758 2759@RegisterPForWithArgs("SparseSegmentMeanGrad", 2760 math_ops.sparse_segment_mean_grad) 2761@RegisterPForWithArgs("SparseSegmentSqrtNGrad", 2762 math_ops.sparse_segment_sqrt_n_grad) 2763def _convert_sparse_segment_grad(pfor_input, _, op_func): 2764 grad = pfor_input.stacked_input(0) 2765 indices = pfor_input.unstacked_input(1) 2766 segment_ids = pfor_input.unstacked_input(2) 2767 dim0 = pfor_input.unstacked_input(3) 2768 2769 n = pfor_input.pfor.loop_len_vector[0] 2770 indices = _flatten_array_with_offset(indices, dim0, n) 2771 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 2772 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 2773 grad = _flatten_first_two_dims(grad) 2774 dim0 *= n 2775 output = op_func(grad, indices, segment_ids, dim0) 2776 output = _unflatten_first_dim(output, [n]) 2777 return wrap(output, True) 2778 2779 2780@RegisterPFor("Cast") 2781def _convert_cast(pfor_input): 2782 inp = pfor_input.stacked_input(0) 2783 dtype = pfor_input.get_attr("DstT") 2784 return wrap(math_ops.cast(inp, dtype), True) 2785 2786 2787@RegisterPForWithArgs("Abs", math_ops.abs) 2788@RegisterPForWithArgs("Acos", math_ops.acos) 2789@RegisterPForWithArgs("Acosh", math_ops.acosh) 2790@RegisterPForWithArgs("Add", math_ops.add) 2791@RegisterPForWithArgs("AddV2", math_ops.add_v2) 2792@RegisterPForWithArgs("Angle", math_ops.angle) 
2793@RegisterPForWithArgs("Asin", math_ops.asin) 2794@RegisterPForWithArgs("Asinh", math_ops.asinh) 2795@RegisterPForWithArgs("Atan", math_ops.atan) 2796@RegisterPForWithArgs("Atan2", math_ops.atan2) 2797@RegisterPForWithArgs("Atanh", math_ops.atanh) 2798@RegisterPForWithArgs("BesselI0", special_math_ops.bessel_i0) 2799@RegisterPForWithArgs("BesselI1", special_math_ops.bessel_i1) 2800@RegisterPForWithArgs("BesselI0e", special_math_ops.bessel_i0e) 2801@RegisterPForWithArgs("BesselI1e", special_math_ops.bessel_i1e) 2802@RegisterPForWithArgs("BesselK0", special_math_ops.bessel_k0) 2803@RegisterPForWithArgs("BesselK1", special_math_ops.bessel_k1) 2804@RegisterPForWithArgs("BesselK0e", special_math_ops.bessel_k0e) 2805@RegisterPForWithArgs("BesselK1e", special_math_ops.bessel_k1e) 2806@RegisterPForWithArgs("BesselJ0", special_math_ops.bessel_j0) 2807@RegisterPForWithArgs("BesselJ1", special_math_ops.bessel_j1) 2808@RegisterPForWithArgs("BesselY0", special_math_ops.bessel_y0) 2809@RegisterPForWithArgs("BesselY1", special_math_ops.bessel_y1) 2810@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and) 2811@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or) 2812@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor) 2813@RegisterPForWithArgs("Ceil", math_ops.ceil) 2814@RegisterPForWithArgs("Complex", math_ops.complex) 2815@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs) 2816@RegisterPForWithArgs("Conj", math_ops.conj) 2817@RegisterPForWithArgs("Cos", math_ops.cos) 2818@RegisterPForWithArgs("Cosh", math_ops.cosh) 2819@RegisterPForWithArgs("Dawsn", special_math_ops.dawsn) 2820@RegisterPForWithArgs("Digamma", math_ops.digamma) 2821@RegisterPForWithArgs("Div", math_ops.div) 2822@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan) 2823@RegisterPForWithArgs("Elu", nn_ops.elu) 2824@RegisterPForWithArgs("Erf", math_ops.erf) 2825@RegisterPForWithArgs("Erfc", math_ops.erfc) 2826@RegisterPForWithArgs("Erfinv", math_ops.erfinv) 2827@RegisterPForWithArgs("Exp", math_ops.exp) 2828@RegisterPForWithArgs("Expint", special_math_ops.expint) 2829@RegisterPForWithArgs("Expm1", math_ops.expm1) 2830@RegisterPForWithArgs("Floor", math_ops.floor) 2831@RegisterPForWithArgs("FloorDiv", math_ops.floor_div) 2832@RegisterPForWithArgs("FloorMod", math_ops.floor_mod) 2833@RegisterPForWithArgs("FresnelCos", special_math_ops.fresnel_cos) 2834@RegisterPForWithArgs("FresnelSin", special_math_ops.fresnel_sin) 2835@RegisterPForWithArgs("Greater", math_ops.greater) 2836@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal) 2837@RegisterPForWithArgs("Igamma", math_ops.igamma) 2838@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a) 2839@RegisterPForWithArgs("Igammac", math_ops.igammac) 2840@RegisterPForWithArgs("Imag", math_ops.imag) 2841@RegisterPForWithArgs("Inv", math_ops.inv) 2842@RegisterPForWithArgs("Invert", bitwise_ops.invert) 2843@RegisterPForWithArgs("IsFinite", math_ops.is_finite) 2844@RegisterPForWithArgs("IsInf", math_ops.is_inf) 2845@RegisterPForWithArgs("IsNan", math_ops.is_nan) 2846@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift) 2847@RegisterPForWithArgs("Less", math_ops.less) 2848@RegisterPForWithArgs("LessEqual", math_ops.less_equal) 2849@RegisterPForWithArgs("Lgamma", math_ops.lgamma) 2850@RegisterPForWithArgs("Log", math_ops.log) 2851@RegisterPForWithArgs("Log1p", math_ops.log1p) 2852@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and) 2853@RegisterPForWithArgs("LogicalNot", math_ops.logical_not) 2854@RegisterPForWithArgs("LogicalOr", math_ops.logical_or) 
2855@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor) 2856@RegisterPForWithArgs("Maximum", math_ops.maximum) 2857@RegisterPForWithArgs("Minimum", math_ops.minimum) 2858@RegisterPForWithArgs("Mod", math_ops.mod) 2859@RegisterPForWithArgs("Mul", math_ops.multiply) 2860@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan) 2861@RegisterPForWithArgs("Ndtri", math_ops.ndtri) 2862@RegisterPForWithArgs("Neg", math_ops.negative) 2863@RegisterPForWithArgs("Polygamma", math_ops.polygamma) 2864@RegisterPForWithArgs("Pow", math_ops.pow) 2865@RegisterPForWithArgs("Real", math_ops.real) 2866@RegisterPForWithArgs("RealDiv", math_ops.divide) 2867@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal) 2868@RegisterPForWithArgs("Relu", nn_ops.relu) 2869@RegisterPForWithArgs("Relu6", nn_ops.relu6) 2870@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift) 2871@RegisterPForWithArgs("Rint", math_ops.rint) 2872@RegisterPForWithArgs("Round", math_ops.round) 2873@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt) 2874@RegisterPForWithArgs("Selu", nn_ops.selu) 2875@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid) 2876@RegisterPForWithArgs("Sign", math_ops.sign) 2877@RegisterPForWithArgs("Sin", math_ops.sin) 2878@RegisterPForWithArgs("Sinh", math_ops.sinh) 2879@RegisterPForWithArgs("Softplus", nn_ops.softplus) 2880@RegisterPForWithArgs("Softsign", nn_ops.softsign) 2881@RegisterPForWithArgs("Spence", special_math_ops.spence) 2882@RegisterPForWithArgs("Sqrt", math_ops.sqrt) 2883@RegisterPForWithArgs("Square", math_ops.square) 2884@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference) 2885@RegisterPForWithArgs("Sub", math_ops.subtract) 2886@RegisterPForWithArgs("Tan", math_ops.tan) 2887@RegisterPForWithArgs("Tanh", math_ops.tanh) 2888@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div) 2889@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod) 2890@RegisterPForWithArgs("Xdivy", math_ops.xdivy) 2891@RegisterPForWithArgs("Xlogy", math_ops.xlogy) 2892@RegisterPForWithArgs("Xlog1py", math_ops.xlog1py) 2893@RegisterPForWithArgs("Zeta", math_ops.zeta) 2894def _convert_cwise(pfor_input, op_type, op_func): 2895 # Note that ops handled here do not have attributes except those listed below 2896 # and hence don't need extra arguments passed to the cwise_op call below. 
2897 for attr in pfor_input.op.node_def.attr.keys(): 2898 assert attr in [u"T", u"Tout", u"_xla_compile_id"], (op_type, attr) 2899 if pfor_input.num_inputs > 1: 2900 pfor_input.expanddim_inputs_for_broadcast() 2901 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 2902 2903 2904@RegisterPFor("LeakyRelu") 2905def _convert_leaky_relu(pfor_input): 2906 t = pfor_input.stacked_input(0) 2907 alpha = pfor_input.get_attr("alpha") 2908 return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True) 2909 2910 2911@RegisterPFor("Equal") 2912def _convert_equal(pfor_input): 2913 pfor_input.expanddim_inputs_for_broadcast() 2914 x = pfor_input.input(0)[0] 2915 y = pfor_input.input(1)[0] 2916 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 2917 return wrap(gen_math_ops.equal( 2918 x, y, incompatible_shape_error=incompatible_shape_error), True) 2919 2920 2921@RegisterPFor("NotEqual") 2922def _convert_not_equal(pfor_input): 2923 pfor_input.expanddim_inputs_for_broadcast() 2924 x = pfor_input.input(0)[0] 2925 y = pfor_input.input(1)[0] 2926 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 2927 return wrap(gen_math_ops.not_equal( 2928 x, y, incompatible_shape_error=incompatible_shape_error), True) 2929 2930 2931@RegisterPFor("ApproximateEqual") 2932def _convert_approximate_equal(pfor_input): 2933 pfor_input.expanddim_inputs_for_broadcast() 2934 x = pfor_input.input(0)[0] 2935 y = pfor_input.input(1)[0] 2936 tolerance = pfor_input.get_attr("tolerance") 2937 return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True) 2938 2939 2940@RegisterPFor("Shape") 2941def _convert_shape(pfor_input): 2942 out_type = pfor_input.get_attr("out_type") 2943 return wrap( 2944 array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], 2945 False) 2946 2947 2948@RegisterPFor("ShapeN") 2949def _convert_shape_n(pfor_input): 2950 out_type = pfor_input.get_attr("out_type") 2951 shapes = [ 2952 array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape( 2953 x, out_type=out_type) for x, stacked, _ in pfor_input.inputs 2954 ] 2955 return [wrap(x, False) for x in shapes] 2956 2957 2958@RegisterPFor("Size") 2959def _convert_size(pfor_input): 2960 out_type = pfor_input.get_attr("out_type") 2961 n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) 2962 return wrap( 2963 array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, 2964 False) 2965 2966 2967@RegisterPFor("Rank") 2968def _convert_rank(pfor_input): 2969 return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) 2970 2971 2972@RegisterPFor("AddN") 2973def _convert_addn(pfor_input): 2974 # AddN does not support broadcasting. 
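 # Hence all inputs are stacked below; e.g. (illustrative shapes) a
 # loop-invariant operand of shape [3] is tiled to [loop_len, 3] so it lines
 # up with the stacked operands before the add_n.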
  pfor_input.stack_inputs(tile_variants=False)
  return _wrap_and_tile_variants(
      math_ops.add_n([x.t for x in pfor_input.inputs]),
      pfor_input.pfor.loop_len_vector)


@RegisterPFor("Cross")
def _convert_cross(pfor_input):
  pfor_input.stack_inputs()
  a = pfor_input.stacked_input(0)
  b = pfor_input.stacked_input(1)
  return wrap(math_ops.cross(a, b), True)


@RegisterPFor("BiasAddGrad")
def _convert_biasaddgrad(pfor_input):
  grad = pfor_input.stacked_input(0)
  fmt = pfor_input.get_attr("data_format")
  if fmt == b"NCHW":
    output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
  else:
    grad_shape = array_ops.shape(grad)
    last_dim_shape = grad_shape[-1]
    first_dim_shape = grad_shape[0]
    output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
    output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
  return wrap(output, True)


# Some required ops are not exposed under the tf namespace. Hence relying on
# _create_op to create them.
@RegisterPForWithArgs("EluGrad")
@RegisterPForWithArgs("LeakyReluGrad")
@RegisterPForWithArgs("ReciprocalGrad")
@RegisterPForWithArgs("Relu6Grad")
@RegisterPForWithArgs("ReluGrad")
@RegisterPForWithArgs("RsqrtGrad")
@RegisterPForWithArgs("SeluGrad")
@RegisterPForWithArgs("SigmoidGrad")
@RegisterPForWithArgs("SoftplusGrad")
@RegisterPForWithArgs("SoftsignGrad")
@RegisterPForWithArgs("SqrtGrad")
@RegisterPForWithArgs("TanhGrad")
def _convert_grads(pfor_input, op_type, *args, **kw_args):
  del args
  del kw_args
  # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
  # have to use tiling here.
3023 pfor_input.stack_inputs() 3024 outputs = _create_op( 3025 op_type, [x.t for x in pfor_input.inputs], 3026 [x.dtype for x in pfor_input.outputs], 3027 attrs=pfor_input.op.node_def.attr).outputs 3028 return [wrap(x, True) for x in outputs] 3029 3030 3031@RegisterPFor("Select") 3032def _convert_select(pfor_input): 3033 pfor_input.stack_inputs() 3034 cond = pfor_input.stacked_input(0) 3035 t = pfor_input.stacked_input(1) 3036 e = pfor_input.stacked_input(2) 3037 cond_rank = array_ops.rank(cond) 3038 cond, t, e = control_flow_ops.cond( 3039 cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), 3040 lambda: [cond, t, e]) 3041 outputs = _create_op( 3042 pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], 3043 attrs=pfor_input.op.node_def.attr).outputs 3044 n = pfor_input.pfor.loop_len_vector 3045 out = control_flow_ops.cond(cond_rank > 1, 3046 lambda: _unflatten_first_dim(outputs[0], n), 3047 lambda: outputs[0]) 3048 return [wrap(out, True) for x in outputs] 3049 3050 3051@RegisterPFor("SelectV2") 3052def _convert_selectv2(pfor_input): 3053 pfor_input.expanddim_inputs_for_broadcast() 3054 cond = pfor_input.input(0)[0] 3055 t = pfor_input.input(1)[0] 3056 e = pfor_input.input(2)[0] 3057 out = array_ops.where_v2(cond, t, e) 3058 return wrap(out, True) 3059 3060 3061# random_ops 3062 3063 3064def _transpose_dim_to_front(x, dim): 3065 rank = array_ops.rank(x) 3066 return array_ops.transpose( 3067 x, 3068 perm=array_ops.concat( 3069 [[dim], math_ops.range(0, dim), 3070 math_ops.range(dim + 1, rank)], 3071 axis=0)) 3072 3073 3074@RegisterPForWithArgs("RandomUniform") 3075@RegisterPForWithArgs("RandomUniformInt") 3076@RegisterPForWithArgs("RandomStandardNormal") 3077@RegisterPForWithArgs("TruncatedNormal") 3078def _convert_random(pfor_input, op_type, *args, **kw_args): 3079 del args 3080 del kw_args 3081 inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] 3082 # inputs[0] is "shape" 3083 inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]], 3084 axis=0) 3085 logging.warning( 3086 "Note that %s inside pfor op may not give same output as " 3087 "inside a sequential loop.", op_type) 3088 outputs = _create_op( 3089 op_type, 3090 inputs, [x.dtype for x in pfor_input.outputs], 3091 attrs=pfor_input.op.node_def.attr).outputs 3092 return [wrap(x, True) for x in outputs] 3093 3094 3095@RegisterPFor("RandomGamma") 3096@RegisterPFor("RandomPoissonV2") 3097def _convert_random_with_param(pfor_input): 3098 shape = pfor_input.unstacked_input(0) 3099 # param is lam (Poisson rate) or alpha (Gamma shape). 
  param, param_stacked, _ = pfor_input.input(1)
  logging.warning(
      "Note that %s inside pfor op may not give the same output as "
      "inside a sequential loop.", pfor_input.op_type)

  if param_stacked:
    samples = _create_op(
        pfor_input.op_type,
        inputs=[shape, param],
        op_dtypes=[x.dtype for x in pfor_input.outputs],
        attrs=pfor_input.op.node_def.attr).outputs[0]
    loop_dim = array_ops.shape(shape)[0]
    stacked_samples = _transpose_dim_to_front(samples, loop_dim)
  else:
    shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
    stacked_samples = _create_op(
        pfor_input.op_type,
        inputs=[shape, param],
        op_dtypes=[x.dtype for x in pfor_input.outputs],
        attrs=pfor_input.op.node_def.attr).outputs[0]

  return wrap(stacked_samples, True)


@RegisterPFor("Multinomial")
def _convert_multinomial(pfor_input):
  logits, logits_stacked, _ = pfor_input.input(0)
  num_samples = pfor_input.unstacked_input(1)
  seed = pfor_input.get_attr("seed")
  seed2 = pfor_input.get_attr("seed2")
  output_dtype = pfor_input.get_attr("output_dtype")
  logging.warning(
      "Note that Multinomial inside pfor op may not give the same output as "
      "inside a sequential loop.")

  n = pfor_input.pfor.loop_len_vector[0]
  if logits_stacked:
    flattened_logits = _flatten_first_two_dims(logits)
    samples = gen_random_ops.multinomial(
        flattened_logits,
        num_samples,
        seed=seed,
        seed2=seed2,
        output_dtype=output_dtype)
    stacked_samples = _unflatten_first_dim(samples, [n])
  else:
    samples = gen_random_ops.multinomial(
        logits,
        num_samples * n,
        seed=seed,
        seed2=seed2,
        output_dtype=output_dtype)
    stacked_samples = array_ops.transpose(
        array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2])

  return wrap(stacked_samples, True)


@RegisterPFor("StatelessMultinomial")
@RegisterPFor("StatelessParameterizedTruncatedNormal")
@RegisterPFor("StatelessRandomBinomial")
@RegisterPFor("StatelessRandomGammaV2")
@RegisterPFor("StatelessRandomNormal")
@RegisterPFor("StatelessRandomPoisson")
@RegisterPFor("StatelessRandomUniform")
@RegisterPFor("StatelessRandomUniformInt")
@RegisterPFor("StatelessRandomUniformFullInt")
@RegisterPFor("StatelessTruncatedNormal")
def _convert_stateless_multinomial(pfor_input):
  # Unlike stateful random ops, stateless ones are expected to be reproducible
  # given the seed. Hence we avoid the strategy used for stateful ops, where
  # vectorization may generate a different set of random numbers.
  # Unfortunately, the kernels are not necessarily set up to do this
  # efficiently, so we fall back to a sequential loop for vectorization.
  return _fallback_converter(pfor_input, warn=False)


# linalg_ops


@RegisterPForWithArgs("XlaEinsum")
@RegisterPForWithArgs("Einsum")
def _convert_einsum(pfor_input, op_type):
  first_input, first_input_stacked, _ = pfor_input.input(0)
  second_input, second_input_stacked, _ = pfor_input.input(1)

  # Parse the einsum equation.
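  # The rewrite below picks a symbol not used in the equation to stand for the
  # loop dimension, and prepends it to the subscripts of each stacked operand
  # and of the output. Illustrative sketch (equation and shapes chosen
  # arbitrarily): "ij,jk->ik" with both operands stacked becomes "aij,ajk->aik":
  #
  #   a = array_ops.ones([3, 2, 4])                       # stacked, loop_len 3
  #   b = array_ops.ones([3, 4, 5])
  #   c = special_math_ops.einsum("aij,ajk->aik", a, b)   # shape [3, 2, 5]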
3188 equation = pfor_input.get_attr("equation").decode("utf-8") 3189 input_expr, output_expr = equation.split("->") 3190 input_a_expr, input_b_expr = input_expr.split(",") 3191 3192 # pick a placeholder symbol to use for the new axis 3193 chosen_symbol = None 3194 for s in string.ascii_letters: 3195 if s in equation: 3196 continue 3197 else: 3198 chosen_symbol = s 3199 break 3200 3201 if chosen_symbol is None: 3202 raise ValueError("Could not figure out what symbol to use for new axis.") 3203 3204 assert first_input_stacked or second_input_stacked 3205 if first_input_stacked: 3206 input_a_expr = "{}{}".format(chosen_symbol, input_a_expr) 3207 if second_input_stacked: 3208 input_b_expr = "{}{}".format(chosen_symbol, input_b_expr) 3209 output_expr = "{}{}".format(chosen_symbol, output_expr) 3210 3211 new_equation = "{},{}->{}".format(input_a_expr, input_b_expr, output_expr) 3212 if op_type == "XlaEinsum": 3213 result = xla.einsum(equation=new_equation, a=first_input, b=second_input) 3214 else: 3215 assert op_type == "Einsum" 3216 result = special_math_ops.einsum(new_equation, first_input, second_input) 3217 3218 return wrap(result, True) 3219 3220 3221@RegisterPFor("Cholesky") 3222def _convert_cholesky(pfor_input): 3223 t = pfor_input.stacked_input(0) 3224 return wrap(linalg_ops.cholesky(t), True) 3225 3226 3227@RegisterPFor("LogMatrixDeterminant") 3228def _convert_log_matrix_determinant(pfor_input): 3229 t = pfor_input.stacked_input(0) 3230 return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)] 3231 3232 3233@RegisterPFor("MatrixInverse") 3234def _convert_matrix_inverse(pfor_input): 3235 t = pfor_input.stacked_input(0) 3236 adjoint = pfor_input.get_attr("adjoint") 3237 return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True) 3238 3239 3240@RegisterPFor("MatrixSolve") 3241def _convert_matrix_solve(pfor_input): 3242 pfor_input.stack_inputs() 3243 matrix = pfor_input.stacked_input(0) 3244 rhs = pfor_input.stacked_input(1) 3245 adjoint = pfor_input.get_attr("adjoint") 3246 output = gen_linalg_ops.matrix_solve( 3247 matrix, rhs, adjoint=adjoint) 3248 return wrap(output, True) 3249 3250 3251@RegisterPFor("MatrixTriangularSolve") 3252def _convert_matrix_triangular_solve(pfor_input): 3253 pfor_input.expanddim_inputs_for_broadcast() 3254 matrix = pfor_input.input(0)[0] 3255 rhs = pfor_input.input(1)[0] 3256 lower = pfor_input.get_attr("lower") 3257 adjoint = pfor_input.get_attr("adjoint") 3258 output = linalg_ops.matrix_triangular_solve( 3259 matrix, rhs, lower=lower, adjoint=adjoint) 3260 return wrap(output, True) 3261 3262 3263@RegisterPFor("SelfAdjointEigV2") 3264def _convert_self_adjoint_eig(pfor_input): 3265 t = pfor_input.stacked_input(0) 3266 compute_v = pfor_input.get_attr("compute_v") 3267 e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v) 3268 # If compute_v is False, v will have shape [0]. 3269 return wrap(e, True), wrap(v, compute_v) 3270 3271 3272# logging_ops 3273 3274 3275@RegisterPFor("Assert") 3276def _convert_assert(pfor_input): 3277 cond, cond_stacked, _ = pfor_input.input(0) 3278 if cond_stacked: 3279 cond = math_ops.reduce_all(cond) 3280 3281 data_list = [x.t for x in pfor_input.inputs][1:] 3282 return _create_op( 3283 "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr) 3284 3285 3286@RegisterPFor("Print") 3287def _convert_print(pfor_input): 3288 # Note that we don't stack all the inputs. Hence unstacked values are printed 3289 # once here vs multiple times in a while_loop. 
  pfor_input.stack_inputs([0])
  outputs = _create_op(
      "Print", [x.t for x in pfor_input.inputs],
      [x.dtype for x in pfor_input.outputs],
      attrs=pfor_input.op.node_def.attr).outputs
  return [wrap(x, True) for x in outputs]


# data_flow_ops

# TensorArray conversion is tricky since we don't support arrays of
# TensorArrays. For converting them, we consider two distinct cases:
#
# 1. The array is constructed outside the pfor call, and read/written inside
# the loop.
# This is an easier case since we don't need to make an array of TensorArrays.
# A correctness requirement is that these parallel iterations shouldn't attempt
# to write to the same location. Hence at conversion time we disallow indices
# to be loop-invariant as that would guarantee a collision. Even if the indices
# are not loop-invariant, they could still conflict, which will trigger runtime
# errors.
#
# 2. The array is constructed and used entirely inside each pfor iteration.
# For simplicity, here we require that the indices used for write/scatter are
# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
# different pfor iterations. We consider two sub-cases:
#
# 2a. Elements written to the array are "stacked".
# To simulate multiple TensorArrays, we may increase the dimension of each
# element of the array, i.e. the i-th row of the j-th entry of the converted
# TensorArray corresponds to the j-th entry of the TensorArray in the i-th
# pfor iteration (see the illustrative example in _convert_tensor_array_v3
# below).
#
# 2b. Elements written to the array are "unstacked".
# In this case we don't increase the dimensions, to avoid redundant tiling.
# Each iteration is trying to write the same value, so we convert that to a
# single write.
#
# Here are some tricks used to implement the above:
# - The TensorArrayV3 constructor encodes the element shape as an attr. Instead
# of trying to trace whether future writes are stacked or unstacked in order to
# set this attr, we set it to correspond to an unknown shape.
# - We use the "flow" output of the different ops to track whether the array
# elements are stacked or unstacked. If a stacked write/scatter is done, we
# make the flow stacked as well.
# - We use some heuristic traversal of the graph to track whether the
# TensorArray handle was created inside or outside the pfor loop.


@RegisterPFor("TensorArrayV3")
def _convert_tensor_array_v3(pfor_input):
  size = pfor_input.unstacked_input(0)
  dtype = pfor_input.get_attr("dtype")
  dynamic_size = pfor_input.get_attr("dynamic_size")
  clear_after_read = pfor_input.get_attr("clear_after_read")
  identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
  tensor_array_name = pfor_input.get_attr("tensor_array_name")
  handle, flow = data_flow_ops.tensor_array_v3(
      size,
      dtype=dtype,
      # We don't set element shape since we don't know if writes are stacked or
      # not yet.
      element_shape=None,
      dynamic_size=dynamic_size,
      clear_after_read=clear_after_read,
      identical_element_shapes=identical_element_shapes,
      tensor_array_name=tensor_array_name)
  # Note that we keep flow unstacked for now since we don't know if writes will
  # be stacked or not.
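  # If writes later turn out to be stacked (case 2a in the block comment
  # above), each entry of this TensorArray holds a tensor with an extra leading
  # loop dimension. Illustrative example (shapes chosen arbitrarily): with a
  # loop length of 4 and per-iteration elements of shape [3], entry j of the
  # converted TensorArray has shape [4, 3], and its i-th row is the value that
  # pfor iteration i wrote at index j.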
3358 return wrap(handle, False), wrap(flow, False) 3359 3360 3361@RegisterPFor("TensorArraySizeV3") 3362def _convert_tensor_array_size_v3(pfor_input): 3363 handle = pfor_input.unstacked_input(0) 3364 flow, flow_stacked, _ = pfor_input.input(1) 3365 if flow_stacked: 3366 flow = _unstack_flow(flow) 3367 size = data_flow_ops.tensor_array_size_v3(handle, flow) 3368 return wrap(size, False) 3369 3370 3371def _handle_inside_pfor(pfor_input, handle): 3372 """Returns True if handle was created inside the pfor loop.""" 3373 # We use some heuristic to find the original TensorArray creation op. 3374 # The logic should handle the common cases (except cond based subgraphs). 3375 # In theory the user could perform different operations on the handle (like 3376 # Reshape, stack multiple handles, etc) which could break this logic. 3377 # TODO(agarwal): handle Switch/Merge. 3378 while handle.op.type in ("Enter", "Identity"): 3379 handle = handle.op.inputs[0] 3380 if handle.op.type not in [ 3381 "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape" 3382 ]: 3383 raise ValueError("Unable to find source for handle %s" % handle) 3384 else: 3385 return pfor_input.pfor.op_is_inside_loop(handle.op) 3386 3387 3388def _unstack_flow(value): 3389 # TODO(agarwal): consider looking if this is a Tile op then get its input. 3390 # This may avoid running the Tile operations. 3391 return array_ops.gather(value, 0) 3392 3393 3394@RegisterPFor("TensorArrayReadV3") 3395def _convert_tensor_array_read_v3(pfor_input): 3396 handle = pfor_input.unstacked_input(0) 3397 index, index_stacked, _ = pfor_input.input(1) 3398 dtype = pfor_input.get_attr("dtype") 3399 flow, flow_stacked, _ = pfor_input.input(2) 3400 if flow_stacked: 3401 flow = _unstack_flow(flow) 3402 3403 is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3404 if is_inside_pfor: 3405 # Note that if we are inside a control flow construct inside the pfor, and 3406 # only some of the iterations are doing the read (i.e. 3407 # `all_indices_partitioned` is True), then the read operation should only 3408 # return values for the currently active pfor iterations (`all_indices` 3409 # below). Hence, whenever the returned value is stacked (i.e. `flow` is 3410 # stacked), we may need to do an extra gather after reading the values. Also 3411 # note that if `is_inside` is false, then values in the tensor array are 3412 # unstacked. So the check is only needed in this branch. 3413 all_indices = pfor_input.pfor.all_indices 3414 all_indices_partitioned = pfor_input.pfor.all_indices_partitioned 3415 # Note: flow_stacked indicates if values in the TensorArray are stacked or 3416 # not. 3417 if index_stacked: 3418 if flow_stacked: 3419 raise ValueError( 3420 "It looks like TensorArrayReadV3 was called on a TensorArray whose" 3421 " values are not loop-invariant, and the read indices were also" 3422 " not loop invariant. This is currently unsupported.") 3423 value = data_flow_ops.tensor_array_gather_v3( 3424 handle, index, flow, dtype=dtype) 3425 return wrap(value, True) 3426 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) 3427 if flow_stacked and all_indices_partitioned: 3428 value = array_ops.gather(value, all_indices) 3429 return wrap(value, flow_stacked) 3430 # Values in the TensorArray should be unstacked (since different iterations 3431 # couldn't write to the same location). So whether output is stacked or not 3432 # depends on index_stacked. 
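  # Illustrative example (index values chosen arbitrarily): with a loop length
  # of 3 and stacked per-iteration indices [4, 7, 2], the gather below returns
  # element 4 for iteration 0, element 7 for iteration 1 and element 2 for
  # iteration 2, so the result is stacked. With an unstacked index every
  # iteration reads the same element, and the result stays unstacked.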
3433 if index_stacked: 3434 value = data_flow_ops.tensor_array_gather_v3( 3435 handle, index, flow, dtype=dtype) 3436 else: 3437 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) 3438 return wrap(value, index_stacked) 3439 3440 3441@RegisterPFor("TensorArrayWriteV3") 3442def _convert_tensor_array_write_v3(pfor_input): 3443 handle = pfor_input.unstacked_input(0) 3444 index, index_stacked, _ = pfor_input.input(1) 3445 value, value_stacked, _ = pfor_input.input(2) 3446 flow, flow_stacked, _ = pfor_input.input(3) 3447 if value_stacked and pfor_input.pfor.all_indices_partitioned: 3448 # Looks like we are in a control flow in a pfor where not all iterations are 3449 # active now. We don't allow that since that could lead to different indices 3450 # having different shapes which will be hard to merge later. 3451 raise ValueError("Writing non loop invariant values to TensorArray from " 3452 "inside a while_loop/cond not supported.") 3453 if flow_stacked: 3454 flow = _unstack_flow(flow) 3455 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3456 if is_inside: 3457 if index_stacked: 3458 raise ValueError("Need indices for %s to be loop invariant" % handle) 3459 if not flow_stacked and not value_stacked: 3460 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 3461 return wrap(flow_out, False) 3462 else: 3463 if not value_stacked: 3464 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3465 # TODO(agarwal): Note that if flow is unstacked and value is stacked, then 3466 # this may or may not be a safe situation. flow is unstacked both for a 3467 # freshly created TensorArray, as well as after unstacked values are 3468 # written to it. If it is the latter, then we cannot write a stacked value 3469 # now since that may cause runtime errors due to different shapes in the 3470 # array. At the moment we are not able to handle this gracefully and 3471 # distinguish between the two cases. That would require some heuristic 3472 # traversal of the graph to figure out whether all the writes are 3473 # unstacked or not. 3474 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 3475 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3476 else: 3477 if not index_stacked: 3478 raise ValueError("Need indices for %s to be not loop invariant" % handle) 3479 # Note that even when index_stacked is true, actual values in index may 3480 # still not be unique. However that will cause runtime error when executing 3481 # the scatter operation below. 3482 if not value_stacked: 3483 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3484 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow) 3485 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3486 3487 3488def _transpose_first_two_dims(value): 3489 # TODO(agarwal): optimize if one of the dims == 1. 
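  # Swaps the first two dimensions and leaves the rest untouched, e.g. a value
  # of shape [2, 3, 4, 5] comes out with shape [3, 2, 4, 5]. A rough sketch of
  # what the reshape/transpose below computes (shapes chosen arbitrarily):
  #
  #   x = array_ops.ones([2, 3, 4, 5])
  #   y = array_ops.transpose(x, [1, 0, 2, 3])   # shape [3, 2, 4, 5]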
3490 value_shape = array_ops.shape(value) 3491 v0 = value_shape[0] 3492 v1 = value_shape[1] 3493 value = array_ops.reshape(value, [v0, v1, -1]) 3494 value = array_ops.transpose(value, [1, 0, 2]) 3495 new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0) 3496 return array_ops.reshape(value, new_shape) 3497 3498 3499@RegisterPFor("TensorArrayGatherV3") 3500def _convert_tensor_array_gather_v3(pfor_input): 3501 handle = pfor_input.unstacked_input(0) 3502 indices, indices_stacked, _ = pfor_input.input(1) 3503 indices = array_ops.reshape(indices, [-1]) 3504 flow, flow_stacked, _ = pfor_input.input(2) 3505 if flow_stacked: 3506 flow = _unstack_flow(flow) 3507 dtype = pfor_input.get_attr("dtype") 3508 # TODO(agarwal): support element_shape attr? 3509 3510 n = pfor_input.pfor.loop_len_vector 3511 value = data_flow_ops.tensor_array_gather_v3( 3512 handle, indices, flow, dtype=dtype) 3513 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3514 if is_inside: 3515 # flow_stacked indicates if values in the TensorArray are stacked or not. 3516 if indices_stacked: 3517 if flow_stacked: 3518 raise ValueError( 3519 "It looks like TensorArrayGatherV3 was called on a TensorArray " 3520 "whose values are not loop-invariant, and the indices were also " 3521 "not loop invariant. This is currently unsupported.") 3522 else: 3523 value = _unflatten_first_dim(value, n) 3524 return wrap(value, True) 3525 else: 3526 if flow_stacked: 3527 # Since elements in this array are stacked and `value` was produced by 3528 # gather, its first two dims are "gathered elements" and "stack 3529 # dimension". Our semantics require these two to be flipped. 3530 value = _transpose_first_two_dims(value) 3531 return wrap(value, flow_stacked) 3532 else: 3533 # Values in the TensorArray should be unstacked (since different iterations 3534 # couldn't write to the same location). So whether output is stacked or not 3535 # depends on indices_stacked. 3536 if indices_stacked: 3537 value = _unflatten_first_dim(value, n) 3538 return wrap(value, indices_stacked) 3539 3540 3541@RegisterPFor("TensorArrayScatterV3") 3542def _convert_tensor_array_scatter_v3(pfor_input): 3543 handle = pfor_input.unstacked_input(0) 3544 indices, indices_stacked, _ = pfor_input.input(1) 3545 indices = array_ops.reshape(indices, [-1]) 3546 value, value_stacked, _ = pfor_input.input(2) 3547 flow, flow_stacked, _ = pfor_input.input(3) 3548 3549 if flow_stacked: 3550 flow = _unstack_flow(flow) 3551 3552 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3553 if is_inside: 3554 if indices_stacked: 3555 raise ValueError("Need indices for %s to be loop invariant" % handle) 3556 # Note that flow_stacked indicates if existing values in the array are 3557 # stacked or not. 3558 if not flow_stacked and not value_stacked: 3559 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 3560 flow) 3561 return wrap(flow_out, False) 3562 if not value_stacked: 3563 # TODO(agarwal): tile in the second dimension directly instead of 3564 # transposing below. 3565 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3566 3567 value = _transpose_first_two_dims(value) 3568 # TODO(agarwal): Note that if a previous write was unstacked, flow will be 3569 # unstacked, and a stacked value may be written here which may cause 3570 # runtime error due to different elements having different shape. We do 3571 # not try to prevent that. 
3572 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 3573 flow) 3574 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3575 if not indices_stacked: 3576 raise ValueError("Need indices for %s to be not loop invariant" % handle) 3577 if not value_stacked: 3578 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3579 value = _flatten_first_two_dims(value) 3580 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow) 3581 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3582 3583 3584@RegisterPFor("TensorArrayGradV3") 3585def _convert_tensor_array_grad_v3(pfor_input): 3586 handle = pfor_input.unstacked_input(0) 3587 flow, flow_stacked, _ = pfor_input.input(1) 3588 if flow_stacked: 3589 flow = _unstack_flow(flow) 3590 source = pfor_input.get_attr("source") 3591 # TODO(agarwal): For now, we assume that gradients are stacked if the 3592 # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong 3593 # will give runtime error due to incorrect shape being written to the 3594 # accumulator. It is difficult to know in advance if gradients written will be 3595 # stacked or not. Note that flow being stacked is not indicative of the 3596 # gradient being stacked or not. Revisit this later. 3597 shape_to_prepend = pfor_input.pfor.loop_len_vector 3598 grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape( 3599 handle=handle, 3600 flow_in=flow, 3601 shape_to_prepend=shape_to_prepend, 3602 source=source) 3603 flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t 3604 return [wrap(grad_handle, False), wrap(flow_out, True)] 3605 3606 3607def _stack_tensor_list_shape(shape, first_dim): 3608 shape_value = tensor_util.constant_value(shape) 3609 # Note that negative values in the shape are used to signify unknown shapes 3610 # and are handled in a special way. 3611 if shape_value is not None: 3612 shape_value = np.asarray(shape_value) 3613 if -1 in shape_value: 3614 return constant_op.constant(-1) 3615 elif not shape_value.size: 3616 return first_dim 3617 else: 3618 shape = array_ops.reshape(shape, [-1]) 3619 return control_flow_ops.cond( 3620 math_ops.reduce_any(shape < 0), 3621 lambda: constant_op.constant(-1), 3622 lambda: array_ops.concat([first_dim, shape], axis=0)) 3623 3624 3625def _tile_variant_with_length(t, length): 3626 """stacks `t` `length` times.""" 3627 if _is_variant_with_internal_stacking(t): 3628 # The content of TensorLists is vectorized, not the variant itself. 3629 return t 3630 original_tensor = t 3631 t.set_shape([]) 3632 t = array_ops.reshape(t, [-1]) 3633 with ops.device("CPU:0"): 3634 result = array_ops.tile(t, length) 3635 # TODO(b/169968286): Should regular shape functions do handle data 3636 # propagation here? 3637 custom_gradient.copy_handle_data(original_tensor, result) 3638 return result 3639 3640 3641def _tile_variant(t, pfor_input): 3642 """stacks `t` according to its loop context.""" 3643 return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector) 3644 3645 3646def _untile_variant(t): 3647 if _is_variant_with_internal_stacking(t): 3648 # The content of TensorLists is vectorized, not the variant itself. 3649 if not t.shape.is_compatible_with([]): 3650 raise AssertionError( 3651 ("Unexpectedly saw a vectorized variant (e.g. 
TensorList) with " 3652 "non-scalar shape: {!r}").format(t)) 3653 return t 3654 return array_ops.gather(t, 0) 3655 3656 3657@RegisterPFor("OptionalFromValue") 3658def _convert_optional_from_value(pfor_input): 3659 pfor_input.stack_inputs() 3660 return wrap( 3661 gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]), 3662 True) 3663 3664 3665@RegisterPFor("OptionalGetValue") 3666def _convert_optional_get_value(pfor_input): 3667 handle = pfor_input.stacked_input(0) 3668 output_types = pfor_input.get_attr("output_types") 3669 original_output_shapes = pfor_input.get_attr("output_shapes") 3670 output_shapes = [] 3671 for shape in original_output_shapes: 3672 shape = tensor_shape.TensorShape(shape) 3673 loop_len_shape = tensor_shape.TensorShape( 3674 [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)]) 3675 shape = loop_len_shape.concatenate(shape) 3676 output_shapes.append(shape.as_proto()) 3677 results = gen_dataset_ops.optional_get_value(handle, output_types, 3678 output_shapes) 3679 return [wrap(t, True) for t in results] 3680 3681 3682@RegisterPFor("TensorListReserve") 3683def _convert_tensor_list_reserve(pfor_input): 3684 element_shape = pfor_input.unstacked_input(0) 3685 num_elements = pfor_input.unstacked_input(1) 3686 element_dtype = pfor_input.get_attr("element_dtype") 3687 3688 # Prepend a dimension to element_shape. 3689 element_shape = _stack_tensor_list_shape(element_shape, 3690 pfor_input.pfor.loop_len_vector) 3691 handle = list_ops.tensor_list_reserve( 3692 element_shape, num_elements, element_dtype=element_dtype) 3693 3694 return wrap(_tile_variant(handle, pfor_input), True) 3695 3696 3697@RegisterPFor("TensorListElementShape") 3698def _convert_tensor_list_element_shape(pfor_input): 3699 handle = _untile_variant(pfor_input.stacked_input(0)) 3700 shape_type = pfor_input.get_attr("shape_type") 3701 shape = list_ops.tensor_list_element_shape(handle, shape_type) 3702 shape = array_ops.reshape(shape, [-1]) 3703 shape = shape[1:] 3704 return wrap(shape, False) 3705 3706 3707@RegisterPFor("TensorListLength") 3708def _convert_tensor_list_length(pfor_input): 3709 handle = _untile_variant(pfor_input.stacked_input(0)) 3710 return wrap(list_ops.tensor_list_length(handle), False) 3711 3712 3713def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None): 3714 if element_shape is None: 3715 element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32) 3716 length = list_ops.tensor_list_length(handle) 3717 new_handle = list_ops.tensor_list_reserve( 3718 _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype) 3719 3720 def _body_fn(i, h): 3721 elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape) 3722 elem = _stack(elem, loop_len_vector).t 3723 return i + 1, list_ops.tensor_list_set_item(h, i, elem) 3724 3725 return control_flow_ops.while_loop(lambda i, _: i < length, _body_fn, 3726 [0, new_handle])[1] 3727 3728 3729@RegisterPFor("TensorListGetItem") 3730def _convert_tensor_list_get_item(pfor_input): 3731 handle, handle_stacked, _ = pfor_input.input(0) 3732 index, index_stacked, _ = pfor_input.input(1) 3733 element_shape = pfor_input.unstacked_input(2) 3734 element_dtype = pfor_input.get_attr("element_dtype") 3735 3736 if handle_stacked: 3737 handle = _untile_variant(handle) 3738 element_shape = _stack_tensor_list_shape(element_shape, 3739 pfor_input.pfor.loop_len_vector) 3740 if index_stacked: 3741 # We use a sequential loop since that may be more efficient than first 3742 # gathering and concatenating all the 
element corresponding to `index`, 3743 # and then doing a gather on it. 3744 def _map_fn(i): 3745 item_i = list_ops.tensor_list_get_item( 3746 handle, 3747 index[i], 3748 element_dtype=element_dtype) 3749 return array_ops.gather(item_i, i) 3750 3751 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) 3752 return wrap(output, True) 3753 else: 3754 output = list_ops.tensor_list_get_item( 3755 handle, 3756 index, 3757 element_shape=element_shape, 3758 element_dtype=element_dtype) 3759 return wrap(output, True) 3760 else: 3761 assert index_stacked 3762 return wrap( 3763 list_ops.tensor_list_gather( 3764 handle, 3765 index, 3766 element_shape=element_shape, 3767 element_dtype=element_dtype), True) 3768 3769 3770@RegisterPFor("TensorListSetItem") 3771def _convert_tensor_array_set_item(pfor_input): 3772 handle, handle_stacked, _ = pfor_input.input(0) 3773 index, index_stacked, _ = pfor_input.input(1) 3774 item, item_stacked, _ = pfor_input.input(2) 3775 3776 if not handle_stacked: 3777 # Special case where we can statically guarantee that the indices are 3778 # disjoint. 3779 if index is pfor_input.pfor.all_indices: 3780 if not item_stacked: 3781 item = _stack(item, pfor_input.pfor.loop_len_vector).t 3782 return wrap( 3783 list_ops.tensor_list_scatter(item, index, input_handle=handle), False) 3784 else: 3785 handle = _stack_tensor_list(handle, item.dtype, 3786 pfor_input.pfor.loop_len_vector) 3787 else: 3788 handle = _untile_variant(handle) 3789 3790 if index_stacked: 3791 # TODO(agarwal): handle this. 3792 raise ValueError("Vectorizing writes to a TensorList with loop " 3793 "variant indices is currently unsupported.") 3794 3795 else: 3796 if not item_stacked: 3797 item = _stack(item, pfor_input.pfor.loop_len_vector).t 3798 handle = list_ops.tensor_list_set_item(handle, index, item) 3799 return wrap(_tile_variant(handle, pfor_input), True) 3800 3801 3802@RegisterPFor("TensorListPushBack") 3803def _convert_tensor_list_push_back(pfor_input): 3804 handle, handle_stacked, _ = pfor_input.input(0) 3805 tensor, tensor_stacked, _ = pfor_input.input(1) 3806 if handle_stacked: 3807 handle = _untile_variant(handle) 3808 else: 3809 handle = _stack_tensor_list(handle, tensor.dtype, 3810 pfor_input.pfor.loop_len_vector) 3811 if not tensor_stacked: 3812 tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t 3813 handle = list_ops.tensor_list_push_back(handle, tensor) 3814 return wrap(_tile_variant(handle, pfor_input), True) 3815 3816 3817@RegisterPFor("TensorListPopBack") 3818def _convert_tensor_array_push_back(pfor_input): 3819 handle = pfor_input.stacked_input(0) 3820 element_shape = pfor_input.unstacked_input(1) 3821 handle = _untile_variant(handle) 3822 3823 if element_shape.shape.ndims == 0: 3824 # Default / unspecified 3825 vectorized_shape = -1 3826 else: 3827 # PopBack has an element shape set when it's the gradient of PushBack, only 3828 # used when the list is uninitialized. 
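    # In that case we simply prepend the loop length. Illustrative example
    # (shapes chosen arbitrarily): an element_shape of [2, 3] with a loop
    # length of 4 yields a vectorized_shape of [4, 2, 3].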
3829 vectorized_shape = array_ops.concat( 3830 [pfor_input.pfor.loop_len_vector, element_shape], axis=0) 3831 3832 output_handle, tensor = gen_list_ops.tensor_list_pop_back( 3833 input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"), 3834 element_shape=vectorized_shape) 3835 return wrap(output_handle, True), wrap(tensor, True) 3836 3837 3838@RegisterPFor("TensorListConcatV2") 3839def _convert_tensor_list_concat_v2(pfor_input): 3840 input_handle = pfor_input.stacked_input(0) 3841 element_shape = pfor_input.unstacked_input(1) 3842 leading_dims = pfor_input.unstacked_input(2) 3843 element_dtype = pfor_input.get_attr("element_dtype") 3844 3845 handle = _untile_variant(input_handle) 3846 length = list_ops.tensor_list_length(handle) 3847 # Note that element_shape attribute can have incomplete shapes. This doesn't 3848 # seem to work well when creating another list and then doing a concat on it. 3849 # Hence we try to find the dynamic shape here. 3850 element_shape = control_flow_ops.cond( 3851 length > 0, lambda: array_ops.shape( 3852 list_ops.tensor_list_get_item(handle, 0, element_dtype, None)), 3853 lambda: constant_op.constant([0, 0], dtype=dtypes.int32)) 3854 # The code below creates a copy of the list with each elements' first two 3855 # dimensions transposed. 3856 new_element_shape = array_ops.concat( 3857 [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0) 3858 3859 # Create a new TensorList with elements transposed. 3860 def _transpose_elem(i, h): 3861 elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None) 3862 elem = _transpose_first_two_dims(elem) 3863 return i + 1, list_ops.tensor_list_set_item(h, i, elem) 3864 3865 new_handle = list_ops.tensor_list_reserve(new_element_shape, length, 3866 element_dtype) 3867 new_handle = control_flow_ops.while_loop(lambda i, _: i < length, 3868 _transpose_elem, [0, new_handle])[1] 3869 output, lengths = gen_list_ops.tensor_list_concat_v2( 3870 input_handle=new_handle, 3871 element_dtype=element_dtype, 3872 element_shape=new_element_shape, 3873 leading_dims=leading_dims) 3874 output = _transpose_first_two_dims(output) 3875 return wrap(output, True), wrap(lengths, False) 3876 3877 3878@RegisterPFor("TensorListStack") 3879def _convert_tensor_list_stack(pfor_input): 3880 handle = pfor_input.stacked_input(0) 3881 input_shape = pfor_input.unstacked_input(1) 3882 element_dtype = pfor_input.get_attr("element_dtype") 3883 num_elements = pfor_input.get_attr("num_elements") 3884 3885 handle = _untile_variant(handle) 3886 input_shape = _stack_tensor_list_shape(input_shape, 3887 pfor_input.pfor.loop_len_vector) 3888 output = list_ops.tensor_list_stack( 3889 handle, 3890 element_dtype, 3891 element_shape=input_shape, 3892 num_elements=num_elements) 3893 output = _transpose_first_two_dims(output) 3894 return wrap(output, True) 3895 3896 3897@RegisterPFor("TensorListGather") 3898def _convert_tensor_list_gather(pfor_input): 3899 handle, handle_stacked, _ = pfor_input.input(0) 3900 index, index_stacked, _ = pfor_input.input(1) 3901 element_shape = pfor_input.unstacked_input(2) 3902 element_dtype = pfor_input.get_attr("element_dtype") 3903 3904 if handle_stacked: 3905 handle = _untile_variant(handle) 3906 element_shape = _stack_tensor_list_shape(element_shape, 3907 pfor_input.pfor.loop_len_vector) 3908 if index_stacked: 3909 # We use a sequential loop since that may be more efficient than first 3910 # gathering and concatenating all the element corresponding to `index`, 3911 # and then doing a gather on it. 
3912 def _map_fn(i): 3913 item_i = list_ops.tensor_list_gather( 3914 handle, 3915 index[i], 3916 element_dtype=element_dtype) 3917 axis = array_ops.rank(index) - 1 3918 return array_ops.gather(item_i, i, axis=axis) 3919 3920 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) 3921 return wrap(output, True) 3922 else: 3923 output = list_ops.tensor_list_gather( 3924 handle, 3925 index, 3926 element_shape=element_shape, 3927 element_dtype=element_dtype) 3928 return wrap(output, True) 3929 else: 3930 assert index_stacked 3931 index_shape = array_ops.shape(index) 3932 index = array_ops.reshape(index, [-1]) 3933 values = list_ops.tensor_list_gather( 3934 handle, index, element_shape=element_shape, element_dtype=element_dtype) 3935 final_shape = array_ops.concat( 3936 [index_shape, array_ops.shape(values)[1:]], axis=0) 3937 return wrap(array_ops.reshape(values, final_shape), True) 3938 3939 3940@RegisterPFor("TensorListScatterIntoExistingList") 3941def _convert_tensor_list_scatter(pfor_input): 3942 pfor_input.stack_inputs([1]) 3943 handle, handle_stacked, _ = pfor_input.input(0) 3944 item = pfor_input.stacked_input(1) 3945 # TODO(agarwal): handle stacked indices. 3946 indices = pfor_input.unstacked_input(2) 3947 if handle_stacked: 3948 handle = _untile_variant(handle) 3949 else: 3950 handle = _stack_tensor_list(handle, item.dtype, 3951 pfor_input.pfor.loop_len_vector) 3952 3953 item = _transpose_first_two_dims(item) 3954 handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle) 3955 return wrap(_tile_variant(handle, pfor_input), True) 3956 3957 3958@RegisterPFor("TensorListFromTensor") 3959def _convert_tensor_list_from_tensor(pfor_input): 3960 tensor = pfor_input.stacked_input(0) 3961 element_shape = pfor_input.unstacked_input(1) 3962 tensor = _transpose_first_two_dims(tensor) 3963 element_shape = _stack_tensor_list_shape(element_shape, 3964 pfor_input.pfor.loop_len_vector) 3965 handle = list_ops.tensor_list_from_tensor(tensor, element_shape) 3966 return wrap(_tile_variant(handle, pfor_input), True) 3967 3968 3969# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar 3970# to TensorArrays, we convert them by changing the dimension of the elements 3971# inside the stack. 3972# 3973# We consider two cases: 3974# 3975# 1. StackV2 is constructed and used entirely inside the pfor loop. 3976# We keep a single Stack and perform the push/pop operations of all the 3977# iterations in lock-step. We also assume that all the iterations perform these 3978# operations. In case of dynamic control flow, if only some of the iterations 3979# try to perform a push/pop, then the conversion may not work correctly and may 3980# cause undefined behavior. 3981# TODO(agarwal): test StackV2 with dynamic control flow. 3982# 3983# 2. StackV2 is constructed outside the pfor loop. 3984# Performing stack push/pop in a parallel fashion is ill-defined. However given 3985# that reading stacks created externally is a common operation when computing 3986# jacobians, we provide some special semantics here as follows. 3987# - disallow push operations to the stack 3988# - pop operations are performed in lock step by all iterations, similar to the 3989# case when the stack is created inside. A single value is popped during the 3990# lock-step operation and broadcast to all the iterations. Values in the stack 3991# are assumed to be loop-invariant. 
#
# Some other implementation details:
# We use a somewhat crude heuristic to determine whether values in the Stack
# data structure are loop invariant or not. When converting push/pop
# operations, we keep track of whether the last conversion used a stacked value
# or not (see _stack_cache below). As a result, if an unstacked value is
# written first, subsequent stacked writes are disallowed when they could have
# been allowed in theory.

# Map from a cache key based on the StackV2 handle to a bool indicating whether
# values are stacked or not.
# TODO(agarwal): move _stack_cache inside pfor?
_stack_cache = {}


def _stack_cache_key(pfor_input):
  """Create cache key corresponding to a stack handle."""
  op_type = pfor_input.op_type
  assert op_type in ["StackPushV2", "StackPopV2"], op_type
  orig_handle = pfor_input.op.inputs[0]
  while orig_handle.op.type in ["Identity", "Enter"]:
    orig_handle = orig_handle.op.inputs[0]
  assert orig_handle.op.type == "StackV2", orig_handle.op
  return ops.get_default_graph(), pfor_input.pfor, orig_handle


def _stack_handle_inside_pfor(handle, pfor_input):
  while handle.op.type in ["Identity", "Enter"]:
    handle = handle.op.inputs[0]
  assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" %
                                       handle.op)
  return pfor_input.pfor.op_is_inside_loop(handle.op)


@RegisterPFor("StackPushV2")
def _convert_stack_push_v2(pfor_input):
  handle = pfor_input.unstacked_input(0)
  elem, elem_stacked, _ = pfor_input.input(1)
  swap_memory = pfor_input.get_attr("swap_memory")

  if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
    raise ValueError("StackPushV2 not allowed on stacks created outside pfor")
  stack_cache_key = _stack_cache_key(pfor_input)
  stacked = _stack_cache.get(stack_cache_key, None)
  if stacked is None:
    stacked = elem_stacked
    _stack_cache[stack_cache_key] = stacked
  else:
    # If we previously made it unstacked then we can't revert to being stacked.
    if not stacked and elem_stacked:
      raise ValueError(
          "It looks like the stack was previously determined to be loop"
          " invariant, but we are now trying to push a loop dependent value"
          " to it. This is currently unsupported.")
    if stacked and not elem_stacked:
      elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
  out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
  return wrap(out, stacked)


# Note that inputs to this converter will be unstacked. However, it should
# still get called since it is a stateful op.
@RegisterPFor("StackPopV2")
def _convert_stack_pop_v2(pfor_input):
  handle = pfor_input.unstacked_input(0)
  stack_cache_key = _stack_cache_key(pfor_input)
  stacked = _stack_cache.get(stack_cache_key, None)
  # If a StackPushV2 has not been converted yet, we default to unstacked since
  # the push could be outside of pfor, or the converter may not be called if
  # the inputs are unconverted.
4061 if stacked is None: 4062 stacked = False 4063 _stack_cache[stack_cache_key] = False 4064 elem_type = pfor_input.get_attr("elem_type") 4065 out = data_flow_ops.stack_pop_v2(handle, elem_type) 4066 return wrap(out, stacked) 4067 4068 4069# parsing_ops 4070 4071 4072@RegisterPFor("DecodeCSV") 4073def _convert_decode_csv(pfor_input): 4074 lines = pfor_input.stacked_input(0) 4075 record_defaults = [ 4076 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 4077 ] 4078 field_delim = pfor_input.get_attr("field_delim") 4079 use_quote_delim = pfor_input.get_attr("use_quote_delim") 4080 select_cols = pfor_input.get_attr("select_cols") 4081 if not select_cols: 4082 select_cols = None 4083 return [ 4084 wrap(t, True) for t in parsing_ops.decode_csv( 4085 lines, 4086 record_defaults, 4087 field_delim=field_delim, 4088 use_quote_delim=use_quote_delim, 4089 select_cols=select_cols) 4090 ] 4091 4092 4093@RegisterPFor("ParseSingleExample") 4094def _convert_parse_single_example(pfor_input): 4095 serialized = pfor_input.stacked_input(0) 4096 dense_defaults = [ 4097 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 4098 ] 4099 sparse_keys = pfor_input.get_attr("sparse_keys") 4100 dense_keys = pfor_input.get_attr("dense_keys") 4101 sparse_types = pfor_input.get_attr("sparse_types") 4102 dense_shapes = pfor_input.get_attr("dense_shapes") 4103 output = gen_parsing_ops.parse_example( 4104 serialized=serialized, 4105 names=[], 4106 dense_defaults=dense_defaults, 4107 sparse_keys=sparse_keys, 4108 dense_keys=dense_keys, 4109 sparse_types=sparse_types, 4110 dense_shapes=dense_shapes) 4111 return [wrap(t, True, True) for t in nest.flatten(output)] 4112 4113 4114@RegisterPFor("ParseExampleV2") 4115def _convert_parse_example_v2(pfor_input): 4116 serialized = pfor_input.stacked_input(0) 4117 sparse_keys = pfor_input.unstacked_input(2) 4118 dense_keys = pfor_input.unstacked_input(3) 4119 ragged_keys = pfor_input.unstacked_input(4) 4120 dense_defaults = [ 4121 pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs) 4122 ] 4123 num_sparse = pfor_input.get_attr("num_sparse") 4124 sparse_types = pfor_input.get_attr("sparse_types") 4125 ragged_value_types = pfor_input.get_attr("ragged_value_types") 4126 ragged_split_types = pfor_input.get_attr("ragged_split_types") 4127 dense_shapes = pfor_input.get_attr("dense_shapes") 4128 if serialized.shape.ndims not in (None, 1): 4129 raise ValueError("ParseExampleV2 can only be converted if `serialized` " 4130 "is scalar.") 4131 output = gen_parsing_ops.parse_example_v2( 4132 serialized=serialized, 4133 names=[], 4134 sparse_keys=sparse_keys, 4135 dense_keys=dense_keys, 4136 ragged_keys=ragged_keys, 4137 dense_defaults=dense_defaults, 4138 num_sparse=num_sparse, 4139 sparse_types=sparse_types, 4140 ragged_value_types=ragged_value_types, 4141 ragged_split_types=ragged_split_types, 4142 dense_shapes=dense_shapes) 4143 return [wrap(t, True, True) for t in nest.flatten(output)] 4144 4145 4146# functional_ops 4147 4148 4149def _convert_function_call(func, converter, inputs): 4150 assert isinstance(func.graph, func_graph.FuncGraph), func 4151 assert isinstance(converter, PFor) 4152 4153 # TODO(agarwal): consider caching this function definition. 4154 @def_function.function 4155 def f(*args): 4156 assert all(isinstance(arg, WrappedTensor) for arg in args), args 4157 assert len(args) == len(func.graph.inputs), (args, func.graph.inputs) 4158 # Map inputs to function arguments. 
4159 for inp, arg in zip(func.graph.inputs, args): 4160 converter._add_conversion(inp, arg) 4161 # Convert output tensors. 4162 return tuple( 4163 [converter._convert_helper(x).t for x in func._func_graph_outputs]) 4164 4165 call_outputs = f(*inputs) 4166 assert len(call_outputs) == len(func._func_graph_outputs) 4167 outputs = [] 4168 for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs): 4169 func_output = converter._convert_helper(output_tensor) 4170 outputs.append( 4171 wrap(call_output, func_output.is_stacked, 4172 func_output.is_sparse_stacked)) 4173 return outputs 4174 4175 4176@RegisterPFor("StatefulPartitionedCall") 4177@RegisterPFor("PartitionedCall") 4178def _convert_partitioned_call(pfor_input): 4179 func_name = pfor_input.get_attr("f").name 4180 func = pfor_input.op.graph._get_function(compat.as_bytes(func_name)) 4181 assert isinstance(func.graph, func_graph.FuncGraph), ( 4182 "Could not find FuncGraph object for %s. Got func %s" % (func_name, func)) 4183 pfor = pfor_input.pfor 4184 converter = PFor( 4185 loop_var=pfor.loop_var, 4186 loop_len=pfor.loop_len_vector[0], 4187 pfor_ops=func.graph.get_operations(), 4188 fallback_to_while_loop=pfor.fallback_to_while_loop, 4189 all_indices=pfor.all_indices, 4190 all_indices_partitioned=pfor.all_indices_partitioned, 4191 pfor_config=pfor.pfor_config) 4192 return _convert_function_call(func, converter, pfor_input.inputs) 4193 4194 4195def _partition_inputs_for_indices(inputs, indices): 4196 new_inputs = [] 4197 for inp in inputs: 4198 if inp.is_stacked: 4199 new_inputs.append(wrap(array_ops.gather(inp.t, indices), True)) 4200 else: 4201 new_inputs.append(inp) 4202 return new_inputs 4203 4204 4205def _outputs_for_branch(func_name, indices, pfor_input, inputs): 4206 if indices is None: 4207 indices = pfor_input.pfor.all_indices 4208 partitioned = pfor_input.pfor.all_indices_partitioned 4209 else: 4210 partitioned = True 4211 func = pfor_input.op.graph._get_function(func_name) 4212 converter = PFor( 4213 loop_var=pfor_input.pfor.loop_var, 4214 loop_len=array_ops.size(indices), 4215 pfor_ops=func.graph.get_operations(), 4216 fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop, 4217 all_indices=indices, 4218 all_indices_partitioned=partitioned, 4219 pfor_config=pfor_input.pfor.pfor_config) 4220 outputs = _convert_function_call(func, converter, inputs) 4221 stacked_outputs = [] 4222 for out in outputs: 4223 if not out.is_stacked: 4224 stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t) 4225 else: 4226 stacked_outputs.append(out.t) 4227 return stacked_outputs 4228 4229 4230# TODO(agarwal): Currently the converted code aggressively tiles loop variant 4231# outputs from the then/else branches. Instead, it could do so only if at least 4232# one of the branch outputs is loop variant. 4233@RegisterPFor("StatelessIf") 4234@RegisterPFor("If") 4235def _convert_if(pfor_input): 4236 cond, cond_stacked, _ = pfor_input.input(0) 4237 inputs = pfor_input.inputs[1:] 4238 then_branch = pfor_input.get_attr("then_branch") 4239 else_branch = pfor_input.get_attr("else_branch") 4240 4241 if cond_stacked: 4242 cond_int = math_ops.cast(cond, dtypes.int32) 4243 # Compute loop indices for the different branches 4244 false_indices, true_indices = data_flow_ops.dynamic_partition( 4245 pfor_input.pfor.all_indices, cond_int, 2) 4246 # Compute indices for cond being True or False. 
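    # Illustrative example (values chosen arbitrarily): if the stacked cond is
    # [True, False, True], then true_indices/then_indices come out as [0, 2]
    # and false_indices/else_indices as [1]. Each branch is then converted only
    # for its own subset of iterations, and dynamic_stitch below merges the
    # branch outputs back into loop-iteration order.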
4247 if pfor_input.pfor.all_indices_partitioned: 4248 else_indices, then_indices = data_flow_ops.dynamic_partition( 4249 math_ops.range(pfor_input.pfor.loop_len_vector[0]), 4250 cond_int, 2) 4251 else: 4252 else_indices, then_indices = false_indices, true_indices 4253 # Partition inputs 4254 then_inputs = _partition_inputs_for_indices(inputs, then_indices) 4255 else_inputs = _partition_inputs_for_indices(inputs, else_indices) 4256 4257 # Convert "then" branch. 4258 then_outputs = _outputs_for_branch(then_branch.name, true_indices, 4259 pfor_input, then_inputs) 4260 4261 # Convert "else" branch. 4262 else_outputs = _outputs_for_branch(else_branch.name, false_indices, 4263 pfor_input, else_inputs) 4264 4265 assert len(then_outputs) == len(else_outputs) 4266 # Note that if the "then" and "else" branches are updating the same state, 4267 # and possibly reading them as well, it could lead to undefined behavior 4268 # since the ordering of those operations is not well defined. 4269 # One possibility is to order all the "then" branches to execute before all 4270 # the "else" branches so that the side-effects in the former are visible to 4271 # the latter. For now, we leave that as undefined behavior. 4272 outputs = [] 4273 # Merge outputs 4274 for then_output, else_output in zip(then_outputs, else_outputs): 4275 out = data_flow_ops.dynamic_stitch([then_indices, else_indices], 4276 [then_output, else_output]) 4277 outputs.append(wrap(out, True)) 4278 return outputs 4279 else: 4280 outputs = control_flow_ops.cond( 4281 cond, 4282 lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs), 4283 lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs)) 4284 return [wrap(t, True) for t in outputs] 4285 4286 4287class WhileV2(object): 4288 """Object for vectorizing V2 while_loop op.""" 4289 4290 def __init__(self, pfor_input): 4291 self._pfor_input = pfor_input 4292 self._pfor = pfor_input.pfor 4293 cond_func_name = pfor_input.get_attr("cond").name 4294 self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes( 4295 cond_func_name)) 4296 body_func_name = pfor_input.get_attr("body").name 4297 self._body_func = pfor_input.op.graph._get_function(compat.as_bytes( 4298 body_func_name)) 4299 if self._cond_func is None or self._body_func is None: 4300 raise ValueError("Error extracting cond and body functions for op %s." % ( 4301 self._pfor_input.op)) 4302 # Indices of inputs that are passed unchanged through the while loop body. 4303 # Typically these are tensors captured from outside the body context. 4304 self._body_pass_through_indices = set() 4305 for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs, 4306 self._body_func.graph.outputs)): 4307 if id(inp) == id(out): 4308 self._body_pass_through_indices.add(i) 4309 self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations") 4310 4311 def _output_shapes(self): 4312 # Calculate output shape for vectorized loop. This will be used as 4313 # shape_invariant. Merges shape inference outputs with the `output_shapes` 4314 # attribute of the op. 
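    # Illustrative example (shapes chosen arbitrarily): a loop variant input
    # whose per-iteration shape is [3] gets the shape invariant [None, 3],
    # since the leading loop dimension can shrink as iterations finish.
    # Variants with internal stacking (e.g. TensorLists) do not get the extra
    # leading dimension.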
4315 output_shapes = [out.shape for out in self._pfor_input.op.outputs] 4316 shapes = self._pfor_input.get_attr("output_shapes") 4317 if not shapes: 4318 shapes = [tensor_shape.TensorShape(None) for _ in output_shapes] 4319 else: 4320 shapes = [tensor_shape.TensorShape(shape) for shape in shapes] 4321 for i, shape in enumerate(shapes): 4322 shape = shape.merge_with(output_shapes[i]) 4323 pfor_input = self._pfor_input.input(i) 4324 if pfor_input.is_stacked: 4325 if _is_variant_with_internal_stacking(pfor_input.t): 4326 shape = tensor_shape.TensorShape([]).concatenate(shape) 4327 else: 4328 shape = tensor_shape.TensorShape([None]).concatenate(shape) 4329 output_shapes[i] = shape 4330 assert len(output_shapes) == self._pfor_input.num_inputs 4331 return output_shapes 4332 4333 def _init_values(self): 4334 """Create arguments passed to converted while_loop.""" 4335 loop_len = self._pfor.loop_len_vector[0] 4336 inputs = [] 4337 # TensorArrays for outputs of converted while loop 4338 output_tas = [] 4339 4340 with ops.name_scope("while_init"): 4341 for inp in self._pfor_input.inputs: 4342 inputs.append(inp.t) 4343 output_tas.append(tensor_array_ops.TensorArray( 4344 inp.t.dtype, 4345 size=loop_len, 4346 dynamic_size=False, 4347 infer_shape=True)) 4348 # See documentation for __call__ for the structure of init_values. 4349 indices = ( 4350 math_ops.range(self._pfor.loop_len_vector[0]) 4351 if self._pfor.all_indices_partitioned else self._pfor.all_indices) 4352 return [True, indices] + inputs + output_tas 4353 4354 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): 4355 """Handles case when condition is pfor loop invariant.""" 4356 # Note that all iterations end together. So we don't need to partition the 4357 # inputs. 4358 not_all_done = array_ops.reshape(conditions, []) 4359 return not_all_done, indices, inputs, output_tas 4360 4361 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, 4362 output_tas): 4363 """Handles case when condition is pfor loop dependent.""" 4364 # Compute if all iterations are done. 4365 not_all_done = math_ops.reduce_any(conditions) 4366 conditions_int = math_ops.cast(conditions, dtypes.int32) 4367 # Partition the indices. 4368 done_indices, new_indices = data_flow_ops.dynamic_partition( 4369 indices, conditions_int, 2) 4370 4371 new_inputs = [] 4372 new_output_tas = [] 4373 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): 4374 pass_through = i in self._body_pass_through_indices 4375 # Partition the inputs. 4376 if stacked: 4377 done_inp, new_inp = data_flow_ops.dynamic_partition( 4378 inp, conditions_int, 2) 4379 else: 4380 if not pass_through: 4381 done_inp = _stack(inp, [array_ops.size(done_indices)]).t 4382 new_inp = inp 4383 4384 new_inputs.append(new_inp) 4385 out_ta = output_tas[i] 4386 if not pass_through: 4387 # Note that done_indices can be empty. done_inp should also be empty 4388 # in that case. 4389 out_ta = out_ta.scatter(done_indices, done_inp) 4390 new_output_tas.append(out_ta) 4391 4392 assert len(new_output_tas) == len(output_tas) 4393 assert len(new_inputs) == len(inputs) 4394 return not_all_done, new_indices, new_inputs, new_output_tas 4395 4396 def _process_body(self, inputs_stacked, new_indices, cond_stacked, 4397 new_inputs, not_all_done): 4398 """Convert the body function.""" 4399 # This is used to store the indices of inputs to the while op that need to 4400 # be stacked. 
    # This stacking may be needed in cases where the input to the while_loop
    # is loop invariant, but the corresponding output is not.
    mismatching_stacked_indices = []

    def true_fn():
      """Converts the body function for all but the last iteration."""
      wrapped_inputs = [wrap(inp, stacked) for inp, stacked in
                        zip(new_inputs, inputs_stacked)]
      # Note the iterative process below to figure out loop invariance.
      # Here we iterate on the vectorization process till a fixed point. The
      # issue is that the while body can take pfor loop invariant inputs but
      # return loop variant outputs. For any loop variant output, the
      # corresponding input then has to be made loop variant (since subsequent
      # while iterations will need to see loop variant values).
      # However, once we make a new input loop variant, we might make other
      # outputs loop variant. Hence we need to iterate till we reach a fixed
      # point.
      while True:
        if self._pfor.all_indices_partitioned:
          indices = array_ops.gather(self._pfor.all_indices, new_indices)
        else:
          indices = new_indices
        body_pfor = PFor(
            loop_var=self._pfor.loop_var,
            loop_len=array_ops.size(new_indices),
            pfor_ops=self._body_func.graph.get_operations(),
            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
            all_indices=indices,
            all_indices_partitioned=(self._pfor.all_indices_partitioned or
                                     cond_stacked),
            pfor_config=self._pfor.pfor_config)
        stacking_mismatch = False
        outputs = _convert_function_call(self._body_func,
                                         body_pfor,
                                         wrapped_inputs)
        for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)):
          if out.is_stacked != inp.is_stacked:
            stacking_mismatch = True
            mismatching_stacked_indices.append(i)
            stacked = _stack(inp.t, [array_ops.size(new_indices)])
            if inp.t.dtype == dtypes.variant:
              stacked = wrap(
                  _tile_variant_with_length(stacked.t,
                                            [array_ops.size(new_indices)]))
            wrapped_inputs[i] = stacked
        if not stacking_mismatch:
          if mismatching_stacked_indices:
            # We needed to stack some inputs. This code will be abandoned and
            # should not get executed. Hence we simply return `new_inputs` to
            # make sure the graph construction code completes.
            with ops.control_dependencies([
                control_flow_ops.Assert(
                    False, ["pfor ERROR: this branch should never execute"])]):
              return [array_ops.identity(x) for x in new_inputs]
          else:
            return [out.t for out in outputs]

    # If all are done, we simply return `new_inputs`. Else we need to run the
    # body function.
    return control_flow_ops.cond(
        not_all_done,
        true_fn,
        lambda: list(new_inputs)), mismatching_stacked_indices

  def __call__(self):
    """Converter for the V2 while_loop.

    The conversion of a while_loop is another while_loop.

    The arguments to this converted while_loop are as follows:
    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
      are done.
    indices: int32 1-D Tensor storing the ids of the pfor iterations that are
      not done.
    args: Remaining arguments. These can be divided into 2 categories:
      - The first set of arguments correspond one-to-one to the inputs to the
        unvectorized while_loop.
      - The second set are TensorArrays, corresponding one-to-one to each
        output of the unvectorized while_loop. Each TensorArray has
        `PFor.loop_len` elements, i.e. the number of pfor iterations.
        At the end, the i'th element of each TensorArray will contain the
        output computed by the i'th iteration of pfor. Note that elements can
        be written into these tensor arrays in any order, depending on when
        the corresponding pfor iteration is done.
    In each iteration, the while_loop body recomputes the condition for all
    active pfor iterations to see which of them are now done. It then
    partitions all the inputs and passes them along to the converted body.
    Values for all the iterations that are done are written to TensorArrays
    indexed by the pfor iteration number. When all iterations are done, the
    TensorArrays are stacked to get the final value.

    Returns:
      List of converted outputs.
    """
    output_shapes = self._output_shapes()
    # Note that we use these lists as a hack since we need the `body` to
    # compute these values during construction of the while_loop graph.
    cond_is_stacked = [None]
    indices_to_stack = []

    def cond(not_all_done, *_):
      return not_all_done

    def body(not_all_done, indices, *args):
      # See documentation for __call__ for the structure of *args.
      num_inputs = self._pfor_input.num_inputs
      inputs = args[:num_inputs]
      output_tas = args[num_inputs:]
      inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs]
      assert len(inputs) >= len(output_tas)
      assert len(inputs) == len(inputs_stacked)
      # Convert condition
      with ops.name_scope("while_cond"):
        # Note that we set all_indices_partitioned to True here. At this point
        # we don't know if indices will be partitioned. Hence we use the
        # conservative value.
        cond_pfor = PFor(
            loop_var=self._pfor.loop_var,
            loop_len=array_ops.size(indices),
            pfor_ops=self._cond_func.graph.get_operations(),
            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
            all_indices=indices,
            all_indices_partitioned=True,
            pfor_config=self._pfor.pfor_config)

        wrapped_inputs = [wrap(inp, stacked) for inp, stacked
                          in zip(inputs, inputs_stacked)]
        conditions, cond_stacked, _ = _convert_function_call(
            self._cond_func,
            cond_pfor,
            wrapped_inputs)[0]
        cond_is_stacked[0] = cond_stacked

      # Recompute the new condition, write outputs of done iterations, and
      # partition the inputs if needed.
      if not cond_stacked:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_unstacked(conditions, indices,
                                                        inputs, output_tas)
      else:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_stacked(conditions, indices,
                                                      inputs, inputs_stacked,
                                                      output_tas)
      # Convert body
      with ops.name_scope("while_body"):
        # Compute the outputs from the body.
        new_outputs, mismatching_stacked_indices = self._process_body(
            inputs_stacked, new_indices, cond_stacked, new_inputs,
            not_all_done)

      indices_to_stack[:] = mismatching_stacked_indices
      for i, new_output in enumerate(new_outputs):
        new_output.set_shape(output_shapes[i])
      new_args = ([not_all_done, new_indices] + new_outputs +
                  list(new_output_tas))
      return tuple(new_args)

    # Note that we run the code below in a function since we might abandon the
    # generated code in cases where the conversion dictates that some inputs
    # be further stacked.
    # Hence we run the graph construction using `get_concrete_function` and
    # avoid calling the constructed function if not needed.
    @def_function.function
    def while_fn():
      # Create init_values that will be passed to the while_loop.
      init_values = self._init_values()
      ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in
                             self._pfor_input.outputs]
      shape_invariants = (
          [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])]
          + output_shapes + ta_shape_invariants)

      while_outputs = control_flow_ops.while_loop(
          cond, body, init_values,
          shape_invariants=shape_invariants,
          parallel_iterations=self._parallel_iterations)
      if indices_to_stack:
        # This function will be abandoned.
        return while_outputs
      else:
        num_inputs = self._pfor_input.num_inputs
        new_inputs = while_outputs[2:num_inputs+2]
        output_tas = while_outputs[num_inputs+2:]
        assert cond_is_stacked[0] is not None
        outputs = []
        for i, inp in enumerate(new_inputs):
          if cond_is_stacked[0]:
            if i in self._body_pass_through_indices:
              outputs.append(init_values[i + 2])
            else:
              ta = output_tas[i]
              outputs.append(ta.stack())
          else:
            outputs.append(inp)
        return outputs

    _ = while_fn.get_concrete_function()
    if indices_to_stack:
      # Need to abandon the current conversion, stack some inputs and restart.
      self._pfor_input.stack_inputs(
          stack_indices=indices_to_stack, tile_variants=True)
      # Note that this call will recurse at most one time. The first call will
      # do the required stacking, based on the iterative procedure in
      # _process_body, and the next invocation to __call__ should not need to
      # do any more stacking.
      # We invoke `self()` here as a way to discard any corrupted state.
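      # Illustrative example (an assumption about a typical trigger, not taken
      # from the original comments): a body like
      #   lambda i, acc: (i + 1, acc + stacked_x)
      # where `acc` starts out loop invariant produces a loop variant output
      # for `acc`; its index gets recorded in `indices_to_stack` above, and the
      # conversion is rerun once with that input stacked.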
      return self()
    else:
      outputs = while_fn()
      wrapped_outputs = []
      for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)):
        if i not in self._body_pass_through_indices and cond_is_stacked[0]:
          wrapped_outputs.append(wrap(out, True))
        else:
          wrapped_outputs.append(wrap(out, inp.is_stacked))
      return wrapped_outputs


@RegisterPFor("StatelessWhile")
@RegisterPFor("While")
def _convert_while(pfor_input):
  converter = WhileV2(pfor_input)
  return converter()


# spectral_ops


@RegisterPForWithArgs("FFT", gen_spectral_ops.fft)
@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d)
@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d)
@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft)
@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d)
@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d)
def _convert_fft(pfor_input, _, op_func):
  return wrap(op_func(pfor_input.stacked_input(0)), True)


@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex")
@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex")
@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex")
@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal")
@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal")
@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal")
def _convert_rfft(pfor_input, _, op_func, attr_name):
  inp = pfor_input.stacked_input(0)
  fft_length = pfor_input.unstacked_input(1)
  attr = pfor_input.get_attr(attr_name)
  return wrap(op_func(inp, fft_length, attr), True)
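

# Illustrative usage sketch (not part of the library): the converters defined
# above are what allow pfor-based APIs such as `tf.vectorized_map` to handle
# graphs containing While/StatelessWhile and (I)FFT/(I)RFFT ops. Roughly:
#
#   import tensorflow as tf
#
#   def per_example(x):
#     # The while_loop below is vectorized by WhileV2; tf.signal.fft is
#     # vectorized by the "FFT" converter registered above.
#     _, y = tf.while_loop(lambda i, v: i < 3,
#                          lambda i, v: (i + 1, v * 2.0),
#                          (tf.constant(0), x))
#     return tf.signal.fft(tf.cast(y, tf.complex64))
#
#   xs = tf.random.uniform([8, 16])
#   ys = tf.vectorized_map(per_example, xs)  # shape [8, 16], dtype complex64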