1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Compiled parallel-for loop.""" 16# pylint: disable=missing-docstring,g-direct-tensorflow-import 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23 24from tensorflow.python.eager import context 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import tensor_util 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import bitwise_ops 33from tensorflow.python.ops import check_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import data_flow_ops 36from tensorflow.python.ops import gen_parsing_ops 37from tensorflow.python.ops import gen_sparse_ops 38from tensorflow.python.ops import map_fn 39from tensorflow.python.ops import math_ops 40from tensorflow.python.ops import nn_ops 41from tensorflow.python.ops import parsing_ops 42from tensorflow.python.ops import sparse_ops 43from tensorflow.python.ops import tensor_array_ops 44from tensorflow.python.platform import flags 45from tensorflow.python.platform import tf_logging as logging 46from tensorflow.python.util import compat 47from tensorflow.python.util import nest 48 49flags.DEFINE_bool( 50 "op_conversion_fallback_to_while_loop", False, 51 "If true, falls back to using a while loop for ops for " 52 "which a converter is not defined.") 53 54 55def _stack(t, length): 56 """stacks `t` `length` times.""" 57 ones = array_ops.ones_like(array_ops.shape(t)) 58 multiples = array_ops.concat([length, ones], 0) 59 t = array_ops.tile(array_ops.expand_dims(t, 0), multiples) 60 return wrap(t, True) 61 62 63# The following stateful ops can be safely called once, and with the same 64# signature as the unconverted version, if their inputs are loop invariant. 65# TODO(agarwal): implement a strategy for converting Variable reads/writes. The 66# plan is to map each read/write in the loop_fn to a corresponding merged 67# read/write in the converted graph. Writes need to be mergeable (e.g. 68# AssignAdd) to be used in `pfor`. Given a certain read/write order in the 69# loop_fn, doing a one-to-one conversion will simulate executing such 70# instructions in lock-step across all iterations. 71passthrough_stateful_ops = set([ 72 "VariableV2", 73 "VarHandleOp", 74 "ReadVariableOp", 75 "StackV2", 76 "TensorArrayWriteV3", 77 "TensorArrayReadV3", 78 "TensorArraySizeV3", 79]) 80 81 82def _is_stateful_pfor_op(op): 83 if isinstance(op, WhileOp): 84 return op.is_stateful 85 if op.type == "Const": 86 # Const didn't have an op_def. 87 return False 88 if op.type in passthrough_stateful_ops: 89 return False 90 assert hasattr(op, "op_def") and op.op_def is not None, op 91 return op.op_def.is_stateful 92 93 94# pylint: disable=protected-access 95class WhileOp(object): 96 """Object for storing state for converting the outputs of a while_loop.""" 97 98 def __init__(self, exit_node, pfor_ops, pfor_config): 99 """Initializer. 100 101 Args: 102 exit_node: A tensor output from the while_loop. 103 pfor_ops: list of ops inside the current pfor loop. 104 pfor_config: PForConfig object used while constructing loop body. 105 """ 106 self._pfor_config = pfor_config 107 self._pfor_ops = set(pfor_ops) 108 self._pfor_op_ids = set([x._id for x in pfor_ops]) 109 assert isinstance(exit_node, ops.Tensor) 110 self._while_context = exit_node.op._get_control_flow_context() 111 assert isinstance(self._while_context, control_flow_ops.WhileContext) 112 self._context_name = self._while_context.name 113 self._condition = self._while_context.pivot.op.inputs[0] 114 # Parts of an external while_loop could be created inside a pfor loop. 115 # However for the purpose here, we declare such loops to be external. Also 116 # note that we check if the condition was created inside or outside to 117 # determine if the while_loop was first created inside or outside. 118 # TODO(agarwal): check that the Enter and Exit of this loop are unstacked. 119 self._is_inside_loop = self.op_is_inside_loop(self._condition.op) 120 if self._is_inside_loop: 121 for e in self._while_context.loop_exits: 122 assert self.op_is_inside_loop(e.op) 123 124 # Note the code below tries to reverse engineer an existing while_loop graph 125 # by assuming the following pattern of nodes. 126 # 127 # NextIteration <---- Body <--- Enter 128 # | ^ 129 # V ___| Y 130 # Enter -> Merge -> Switch___ 131 # ^ | N 132 # | V 133 # LoopCond Exit 134 135 # Node that elements in the list below correspond one-to-one with each 136 # other. i.e. these lists are the same size, and the i_th entry corresponds 137 # to different Operations/Tensors of a single cycle as illustrated above. 138 # List of Switch ops (ops.Operation) that feed into an Exit Node. 139 self._exit_switches = [] 140 # List of inputs (ops.Tensor) to NextIteration. 141 self._body_outputs = [] 142 # List of list of control inputs of the NextIteration nodes. 143 self._next_iter_control_inputs = [] 144 # List of Merge ops (ops.Operation). 145 self._enter_merges = [] 146 # List of output (ops.Tensor) of Exit nodes. 147 self._outputs = [] 148 149 # List of Enter Tensors. 150 # There are two types of Enter nodes: 151 # - The Enter nodes that are used in the `loop_vars` argument to 152 # `while_loop` (see 153 # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect 154 # these Enter nodes immediately below by tracing backwards from the Exit 155 # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the 156 # diagram above. This allows us to have a 1:1 correspondence between the 157 # self._outputs and the first elements in self._enters. 158 # - The Enter nodes that are used only by the body. They don't appear in the 159 # `loop_vars` and are not returned from the `while_loop`. In Python code, 160 # they are usually captured by the body lambda. We collect them below by 161 # iterating over all the ops in the graph. They are appended to the end of 162 # self._enters or self._direct_enters, and don't correspond to any outputs 163 # in self._outputs. Note that we keep the resource/variant Enter nodes in 164 # self._direct_enters and the constructed while_loop's body uses them 165 # directly as opposed to passing them as loop variables. This is done 166 # because the while_body cannot partition the resource/variant Tensors, so 167 # it has to leave them unchanged. 168 self._enters = [] 169 self._direct_enters = [] 170 171 for e in self._while_context.loop_exits: 172 self._outputs.append(e.op.outputs[0]) 173 switch = e.op.inputs[0].op 174 assert switch.type == "Switch", switch 175 self._exit_switches.append(switch) 176 merge = switch.inputs[0].op 177 assert merge.type == "Merge", merge 178 self._enter_merges.append(merge) 179 enter = merge.inputs[0].op 180 assert enter.type == "Enter", enter 181 self._enters.append(enter.outputs[0]) 182 next_iter = merge.inputs[1].op 183 assert next_iter.type == "NextIteration", next_iter 184 self._body_outputs.append(next_iter.inputs[0]) 185 self._next_iter_control_inputs.append(next_iter.control_inputs) 186 187 # Collect all the Enter nodes that are not part of `loop_vars`, the second 188 # category described above. 189 # Also track whether the loop body has any stateful ops. 190 self._is_stateful = False 191 for op in ops.get_default_graph().get_operations(): 192 # TODO(agarwal): make sure this works with nested case. 193 control_flow_context = op._get_control_flow_context() 194 if control_flow_context is None: 195 continue 196 if control_flow_context.name == self._context_name: 197 self._is_stateful |= _is_stateful_pfor_op(op) 198 if op.type == "Enter": 199 output = op.outputs[0] 200 if output not in self._enters: 201 if output.dtype in (dtypes.resource, dtypes.variant): 202 if output not in self._direct_enters: 203 self._direct_enters.append(output) 204 else: 205 self._enters.append(output) 206 207 def __str__(self): 208 """String representation.""" 209 return "while_loop(%s)" % self.name 210 211 @property 212 def inputs(self): 213 """Input to all the Enter nodes.""" 214 return [x.op.inputs[0] for x in self._enters + self._direct_enters] 215 216 @property 217 def control_inputs(self): 218 """Control input to all the Enter nodes.""" 219 control_inputs = [] 220 for x in self._enters + self._direct_enters: 221 control_inputs.extend(x.op.control_inputs) 222 return control_inputs 223 224 @property 225 def outputs(self): 226 """Outputs of all the Exit nodes.""" 227 return self._outputs 228 229 @property 230 def name(self): 231 """Context name for the while loop.""" 232 return self._context_name 233 234 @property 235 def is_inside_loop(self): 236 """Returns true if the while_loop was created inside the pfor.""" 237 return self._is_inside_loop 238 239 def op_is_inside_loop(self, op): 240 """True if op was created inside the pfor loop body.""" 241 assert isinstance(op, ops.Operation) 242 # Note that we use self._pfor_op_ids for the check and not self._pfor_ops 243 # since it appears there tensorflow API could return different python 244 # objects representing the same Operation node. 245 return op._id in self._pfor_op_ids 246 247 @property 248 def is_stateful(self): 249 return self._is_stateful 250 251 @property 252 def pfor_converter(self): 253 """Return a converter for the while loop.""" 254 return self 255 256 def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs, 257 inputs_stacked): 258 """Create a PFor object for converting parts of the while_loop. 259 260 Args: 261 parent_pfor: PFor object being used for converting the while_loop. 262 indices: int32 Tensor of ids for the iterations that are still active 263 (i.e. did not exit the while_loop). 264 cond_stacked: True if the while_loop condition is stacked. 265 inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note 266 that these Tensors are a subset of the loop variables for the generated 267 while_loop. 268 inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`, 269 indicating if the value is stacked or not. 270 271 Returns: 272 A PFor instance. The instance is initialized by adding conversion mappings 273 of nodes that will be external to the conversion that the returned 274 instance will be used for. e.g. Enter nodes as well as Merge and Switch 275 outputs are mapped to converted values. 276 """ 277 num_outputs = len(self._outputs) 278 assert len(inputs) == len(self._enters) 279 assert len(inputs_stacked) == len(self._enters) 280 loop_var = parent_pfor.loop_var 281 loop_len = array_ops.size(indices) 282 pfor = PFor( 283 loop_var, 284 loop_len, 285 pfor_ops=self._pfor_ops, 286 all_indices=indices, 287 all_indices_partitioned=cond_stacked, 288 pfor_config=self._pfor_config) 289 # Map all inputs of Enter nodes in self._direct_enters to their converted 290 # values. 291 for enter in self._direct_enters: 292 enter_input = enter.op.inputs[0] 293 converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper( 294 enter_input) 295 # Since these are resources / variants, they should be unstacked. 296 assert not stacked and not is_sparse_stacked, (enter, converted_enter) 297 pfor._add_conversion(enter, wrap(converted_enter, False)) 298 299 # Map all Enter nodes to the inputs. 300 for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked): 301 pfor._add_conversion(enter, wrap(inp, stacked)) 302 # Map outputs of Switch and Merge. 303 for i in range(num_outputs): 304 wrapped_inp = wrap(inputs[i], inputs_stacked[i]) 305 merge = self._enter_merges[i] 306 pfor._add_conversion(merge.outputs[0], wrapped_inp) 307 # Note that second output of Merge is typically not used, except possibly 308 # as a control dependency. To avoid trying to output the correct value, we 309 # employ a hack here. We output a dummy invalid value with an incorrect 310 # dtype. This will allow control dependency to work but if using it as an 311 # input, it should typically lead to errors during graph construction due 312 # to dtype mismatch. 313 # TODO(agarwal): Check in the original graph to see if there are any 314 # consumers of this Tensor that use it as an input. 315 pfor._add_conversion(merge.outputs[1], 316 wrap(constant_op.constant(-1.0), False)) 317 switch = self._exit_switches[i] 318 # Don't need to worry about switch.output[0] which will feed to Exit node. 319 pfor._add_conversion(switch.outputs[1], wrapped_inp) 320 return pfor 321 322 def _convert_enter(self, parent_pfor, enter): 323 """Converts an Enter node.""" 324 inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0]) 325 control_inputs = [ 326 parent_pfor._convert_helper(x).t for x in enter.op.control_inputs 327 ] 328 if control_inputs: 329 with ops.control_dependencies(control_inputs): 330 inp = array_ops.identity(inp) 331 return inp, stacked 332 333 def _maybe_stacked(self, cache, inp): 334 """Heuristic to figue out if the coverting inp leads to a stacked value. 335 336 337 Args: 338 cache: map from Tensor to boolean indicating stacked/unstacked. 339 inp: input Tensor. 340 341 Returns: 342 True if `inp` could get stacked. If the function returns False, the 343 converted value should be guaranteed to be unstacked. If returning True, 344 it may or may not be stacked. 345 """ 346 if inp in cache: 347 return cache[inp] 348 if not self.op_is_inside_loop(inp.op): 349 return False 350 op = inp.op 351 output = False 352 if op.type in [ 353 "Shape", 354 "Rank" 355 "ShapeN", 356 "ZerosLike", 357 "TensorArrayV3", 358 "TensorArraySizeV3", 359 ]: 360 output = False 361 elif _is_stateful_pfor_op(op): 362 # This may be fairly aggressive. 363 output = True 364 elif op.type == "Exit": 365 # This may be fairly aggressive. 366 output = True 367 else: 368 for t in op.inputs: 369 if self._maybe_stacked(cache, t): 370 output = True 371 break 372 cache[inp] = output 373 return output 374 375 def _create_init_values(self, pfor_input): 376 """Create arguments passed to converted while_loop.""" 377 with ops.name_scope("while_init"): 378 loop_len_vector = pfor_input.pfor.loop_len_vector 379 loop_len = loop_len_vector[0] 380 num_outputs = len(self._outputs) 381 382 inputs = [] 383 maybe_stacked_cache = {} 384 # Convert all the Enters. Need to do this before checking for stacking 385 # below. 386 for i, enter in enumerate(self._enters): 387 inp, stacked = self._convert_enter(pfor_input.pfor, enter) 388 inputs.append(inp) 389 maybe_stacked_cache[enter] = stacked 390 # Since this enter node is part of the `loop_vars`, it corresponds to an 391 # output and its preceding switch. We mark this switch's output the same 392 # stackness, to act at the base case for the logic below. Below, we will 393 # be going through the body figuring out which inputs might need to be 394 # stacked and which inputs can safely remain unstacked. 395 if i < num_outputs: 396 maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked 397 398 # Shape invariants for init_values corresponding to self._enters. 399 input_shape_invariants = [] 400 # TensorArrays for outputs of converted while loop 401 output_tas = [] 402 # Shape invariants for output TensorArrays. 403 ta_shape_invariants = [] 404 # List of booleans indicating stackness of inputs, i.e. tensors 405 # corresponding to self._enters. 406 inputs_stacked = [] 407 for i, inp in enumerate(inputs): 408 enter = self._enters[i] 409 inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter) 410 # Note that even when an input is unstacked, the body could make it 411 # stacked. we use a heuristic below to figure out if body may be making 412 # it stacked. 413 if i < num_outputs: 414 body_output = self._body_outputs[i] 415 if enter.op in self._pfor_ops: 416 body_output_stacked = self._maybe_stacked(maybe_stacked_cache, 417 body_output) 418 else: 419 # If constructed outside of pfor loop, then the output would not be 420 # stacked. 421 body_output_stacked = False 422 if body_output_stacked and not inp_stacked: 423 inp = _stack(inp, loop_len_vector).t 424 inputs[i] = inp 425 inp_stacked = True 426 # TODO(agarwal): other attributes for the TensorArray ? 427 output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len)) 428 ta_shape_invariants.append(tensor_shape.TensorShape(None)) 429 430 inputs_stacked.append(inp_stacked) 431 input_shape_invariants.append(tensor_shape.TensorShape(None)) 432 433 # See documentation for __call__ for the structure of init_values. 434 init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas 435 # TODO(agarwal): try stricter shape invariants 436 shape_invariants = ( 437 [tensor_shape.TensorShape(None), 438 tensor_shape.TensorShape(None) 439 ] + input_shape_invariants + ta_shape_invariants) 440 441 return init_values, inputs_stacked, shape_invariants 442 443 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): 444 """Handles case when condition is unstacked. 445 446 Note that all iterations end together. So we don't need to partition the 447 inputs. When all iterations are done, we write the inputs to the 448 TensorArrays. Note that we only write to index 0 of output_tas. Since all 449 iterations end together, they can all be output together. 450 """ 451 not_all_done = array_ops.reshape(conditions, []) 452 new_output_tas = [] 453 # pylint: disable=cell-var-from-loop 454 for i, out_ta in enumerate(output_tas): 455 inp = inputs[i] 456 new_output_tas.append( 457 control_flow_ops.cond(not_all_done, 458 lambda: out_ta, 459 lambda: out_ta.write(0, inp))) 460 # pylint: enable=cell-var-from-loop 461 return not_all_done, indices, inputs, new_output_tas 462 463 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, 464 output_tas): 465 num_outputs = len(self._outputs) 466 # Compute if all iterations are done. 467 not_all_done = math_ops.reduce_any(conditions) 468 conditions_int = math_ops.cast(conditions, dtypes.int32) 469 # Partition the indices. 470 done_indices, new_indices = data_flow_ops.dynamic_partition( 471 indices, conditions_int, 2) 472 473 new_inputs = [] 474 new_output_tas = [] 475 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): 476 # Partition the inputs. 477 if stacked: 478 done_inp, new_inp = data_flow_ops.dynamic_partition( 479 inp, conditions_int, 2) 480 else: 481 # TODO(agarwal): avoid this stacking. See TODO earlier in 482 # _process_cond_unstacked. 483 done_inp = _stack(inp, [array_ops.size(done_indices)]).t 484 new_inp = inp 485 new_inputs.append(new_inp) 486 # For iterations that are done, write them to TensorArrays. 487 if i < num_outputs: 488 out_ta = output_tas[i] 489 # Note that done_indices can be empty. done_inp should also be empty in 490 # that case. 491 new_output_tas.append(out_ta.scatter(done_indices, done_inp)) 492 return not_all_done, new_indices, new_inputs, new_output_tas 493 494 def _process_body(self, pfor_input, inputs_stacked, 495 new_indices, cond_stacked, new_inputs, 496 not_all_done): 497 """Convert the body function.""" 498 499 def true_fn(control_inputs, body_pfor, body_output, stacked): 500 """Converts the body function for all but last iteration. 501 502 This essentially converts body_output. Additionally, it needs to handle 503 any control dependencies on the NextIteration node. So it creates another 504 Identity node with the converted dependencies. 505 """ 506 converted_control_inp = [] 507 for x in control_inputs: 508 for t in x.outputs: 509 converted_control_inp.append(body_pfor._convert_helper(t).t) 510 if stacked: 511 # Note convert always does the stacking. 512 output = body_pfor.convert(body_output) 513 else: 514 output, convert_stacked, _ = body_pfor._convert_helper(body_output) 515 assert convert_stacked == stacked, body_output 516 with ops.control_dependencies(converted_control_inp): 517 return array_ops.identity(output) 518 519 body_pfor = self._init_pfor(pfor_input.pfor, new_indices, 520 cond_stacked, new_inputs, 521 inputs_stacked) 522 new_outputs = [] 523 524 for i, (body_output, stacked) in enumerate( 525 zip(self._body_outputs, inputs_stacked)): 526 control_inp = self._next_iter_control_inputs[i] 527 out_dtype = body_output.dtype 528 # Note that we want to run the body only if not all pfor iterations are 529 # done. If all are done, we return empty tensors since these values will 530 # not be used. Notice that the value returned by the loop is based on 531 # TensorArrays and not directly on these returned values. 532 # pylint: disable=cell-var-from-loop 533 new_output = control_flow_ops.cond( 534 not_all_done, 535 lambda: true_fn(control_inp, body_pfor, body_output, stacked), 536 lambda: constant_op.constant([], dtype=out_dtype)) 537 # pylint: enable=cell-var-from-loop 538 new_outputs.append(new_output) 539 return new_outputs 540 541 def __call__(self, pfor_input): 542 """Converter for the while_loop. 543 544 The conversion of a while_loop is another while_loop. 545 546 The arguments to this converted while_loop are as follows: 547 not_all_done: Boolean scalar Tensor indicating if all the pfor iterations 548 are done. 549 indices: int32 1-D Tensor storing the id of the iterations that are not 550 done. 551 args: Remaining arguments. These can be divided into 3 categories: 552 - First set of arguments are the tensors that correspond to the initial 553 elements of self._enters. The elements that appear in original while 554 loop's `loop_vars`. 555 - The second set of arguments are the tensors that correspond to the 556 remaining elements of self._enters. These are the tensors that directly 557 enter the original while loop body. 558 - Finally, the last set of arguments are TensorArrays. These TensorArrays 559 correspond to the outputs of the original while_loop, i.e. to the 560 elements in self._outputs. Each TensorArray has `PFor.loop_len` 561 elements, i.e. the number of pfor iterations. At the end, the i'th 562 element of each TensorArray will contain the output computed by the 563 i'th iteration of pfor. Note that elements can be written into these 564 tensors arrays in any order, depending on when the corresponding pfor 565 iteration is done. 566 If the original while_loop had `k` tensors in its `loop_vars` and its body 567 directly captured `m` tensors, the `args` will contain `2 * k + m` values. 568 569 In each iteration, the while_loop body recomputes the condition for all 570 active pfor iterations to see which of them are now done. It then partitions 571 all the inputs and passes them along to the converted body. Values for all 572 the iterations that are done are written to TensorArrays indexed by the pfor 573 iteration number. When all iterations are done, the TensorArrays are stacked 574 to get the final value. 575 576 Args: 577 pfor_input: A PForInput object corresponding to the output of any Exit 578 node from this while loop. 579 580 Returns: 581 List of converted outputs. 582 """ 583 # Create init_values that will be passed to the while_loop. 584 init_values, inputs_stacked, shape_invariants = self._create_init_values( 585 pfor_input) 586 # Note that we use a list as a hack since we need the nested function body 587 # to set the value of cond_is_stacked. python2.x doesn't support nonlocal 588 # variables. 589 cond_is_stacked = [None] 590 591 def cond(not_all_done, *_): 592 return not_all_done 593 594 def body(not_all_done, indices, *args): 595 # See documentatin for __call__ for the structure of *args. 596 num_enters = len(self._enters) 597 inputs = args[:num_enters] 598 output_tas = args[num_enters:] 599 # TODO(agarwal): see which outputs have consumers and only populate the 600 # TensorArrays corresponding to those. Or do those paths get trimmed out 601 # from inside the while_loop body? 602 assert len(inputs) >= len(output_tas) 603 assert len(inputs) == len(inputs_stacked) 604 605 # Convert condition 606 with ops.name_scope("while_cond"): 607 # Note that we set cond_stacked to True here. At this point we don't 608 # know if it could be loop invariant, hence the conservative value is 609 # to assume stacked. 610 cond_pfor = self._init_pfor(pfor_input.pfor, indices, 611 cond_stacked=True, 612 inputs=inputs, 613 inputs_stacked=inputs_stacked) 614 conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition) 615 cond_is_stacked[0] = cond_stacked 616 617 # Recompute the new condition, write outputs of done iterations, and 618 # partition the inputs if needed. 619 if not cond_stacked: 620 (not_all_done, new_indices, 621 new_inputs, new_output_tas) = self._process_cond_unstacked( 622 conditions, indices, inputs, output_tas) 623 else: 624 (not_all_done, new_indices, 625 new_inputs, new_output_tas) = self._process_cond_stacked( 626 conditions, indices, inputs, inputs_stacked, output_tas) 627 628 # Convert body 629 with ops.name_scope("while_body"): 630 # Compute the outputs from the body. 631 new_outputs = self._process_body(pfor_input, inputs_stacked, 632 new_indices, cond_stacked, new_inputs, 633 not_all_done) 634 635 # Note that the first num_outputs new values of inputs are computed using 636 # the body. Rest of them were direct Enters into the condition/body and 637 # the partitioning done earlier is sufficient to give the new value. 638 num_outputs = len(self._outputs) 639 new_args = ([not_all_done, new_indices] + new_outputs + list( 640 new_inputs[num_outputs:]) + new_output_tas) 641 return tuple(new_args) 642 643 while_outputs = control_flow_ops.while_loop( 644 cond, body, init_values, shape_invariants=shape_invariants) 645 output_tas = while_outputs[-len(self._outputs):] 646 outputs = [] 647 assert cond_is_stacked[0] is not None 648 for inp_stacked, ta in zip(inputs_stacked, output_tas): 649 if cond_is_stacked[0]: 650 outputs.append(wrap(ta.stack(), True)) 651 else: 652 # Note that if while_loop condition is unstacked, all iterations exit at 653 # the same time and we wrote those outputs in index 0 of the tensor 654 # array. 655 outputs.append(wrap(ta.read(0), inp_stacked)) 656 return outputs 657 658 659class _PforInput(object): 660 """Input object passed to registered pfor converters.""" 661 662 def __init__(self, pfor, op, inputs): 663 """Creates a _PforInput object. 664 665 Args: 666 pfor: PFor converter object. 667 op: the Operation object that is being converted. 668 inputs: list of WrappedTensor objects representing converted values of the 669 inputs of `op`. 670 """ 671 self.pfor = pfor 672 self._op = op 673 self._inputs = inputs 674 675 def stack_inputs(self, stack_indices=None): 676 """Stacks unstacked inputs at `stack_indices`. 677 678 Args: 679 stack_indices: indices of inputs at which stacking is done. If None, 680 stacking is done at all indices. 681 """ 682 if stack_indices is None: 683 stack_indices = range(len(self._inputs)) 684 length = self.pfor.loop_len_vector 685 for i in stack_indices: 686 inp = self._inputs[i] 687 if not inp.is_stacked: 688 self._inputs[i] = _stack(inp.t, length) 689 690 def expanddim_inputs_for_broadcast(self): 691 """Reshapes stacked inputs to prepare them for broadcast. 692 693 Since stacked inputs have an extra leading dimension, automatic broadcasting 694 rules could incorrectly try to expand dimensions before that leading 695 dimension. To avoid that, we reshape these stacked inputs to the maximum 696 rank they will need to be broadcasted to. 697 """ 698 if not self._inputs: 699 return 700 701 # Find max rank 702 def _get_rank(x): 703 rank = array_ops.rank(x.t) 704 if not x.is_stacked: 705 rank += 1 706 return rank 707 708 ranks = [_get_rank(x) for x in self._inputs] 709 max_rank = ranks[0] 710 for rank in ranks[1:]: 711 max_rank = math_ops.maximum(rank, max_rank) 712 713 for i, inp in enumerate(self._inputs): 714 if inp.is_stacked: 715 shape = array_ops.shape(inp.t) 716 rank_diff = array_ops.reshape(max_rank - ranks[i], [1]) 717 ones = array_ops.tile([1], rank_diff) 718 new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0) 719 self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True) 720 721 @property 722 def inputs(self): 723 return self._inputs 724 725 @property 726 def num_inputs(self): 727 return len(self._inputs) 728 729 def input(self, index): 730 assert len(self._inputs) > index, (index, self._inputs) 731 return self._inputs[index] 732 733 def stacked_input(self, index): 734 t, is_stacked, _ = self.input(index) 735 if not is_stacked: 736 op_type = self.op_type 737 op_def = getattr(self._op, "op_def", None) 738 if op_def is None: 739 input_name = "at index %d" % index 740 else: 741 input_name = "\"%s\"" % op_def.input_arg[index].name 742 raise ValueError("Input %s of op \"%s\" expected to be not loop invariant" 743 ".\nError while converting op %s" 744 "with converted inputs\n%s" % (input_name, op_type, 745 self._op, self.inputs)) 746 return t 747 748 def unstacked_input(self, index): 749 t, is_stacked, _ = self.input(index) 750 if is_stacked: 751 op_type = self.op_type 752 op_def = getattr(self._op, "op_def", None) 753 if op_def is None: 754 input_name = "at index %d" % index 755 else: 756 input_name = "\"%s\"" % op_def.input_arg[index].name 757 raise ValueError("Input %s of op \"%s\" expected to be loop invariant" 758 ".\nError while converting op %s" 759 "with converted inputs\n%s" % (input_name, op_type, 760 self._op, self.inputs)) 761 return t 762 763 @property 764 def op(self): 765 return self._op 766 767 @property 768 def op_type(self): 769 return self._op.type 770 771 def get_attr(self, attr): 772 return self._op.get_attr(attr) 773 774 @property 775 def outputs(self): 776 return self._op.outputs 777 778 def output(self, index): 779 assert index < len(self._op.outputs) 780 return self._op.outputs[index] 781 782 783_pfor_converter_registry = {} 784 785 786class RegisterPFor(object): 787 """Utility to register converters for pfor. 788 789 Usage: 790 @RegisterPFor(foo_op_type) 791 def _foo_converter(pfor_input): 792 ... 793 794 The above will register conversion function `_foo_converter` for handling 795 conversion of `foo_op_type`. During conversion, the registered functin will be 796 called with a single argument of type `PForInput` which will contain state 797 needed for the conversion. This registered function should output a list of 798 WrappedTensor object with the same length as the number of outputs of op being 799 converted. If the op had zero outputs, then it should return a ops.Operation 800 object. 801 """ 802 803 def __init__(self, op_type): 804 """Creates an object to register a converter for op with type `op_type`.""" 805 self.op_type = op_type 806 807 def __call__(self, converter): 808 name = self.op_type 809 assert name not in _pfor_converter_registry, "Re-registering %s " % name 810 _pfor_converter_registry[name] = converter 811 return converter 812 813 814class RegisterPForWithArgs(RegisterPFor): 815 """Utility to register converters for pfor. 816 817 Usage: 818 @RegisteRPFor(foo_op_type, foo=value, ....) 819 def _foo_converter(pfor_input, foo=None, ....): 820 ... 821 822 See RegisterPFor for details on the conversion function. 823 `RegisterPForWithArgs` allows binding extra arguments to the 824 conversion function at registration time. 825 """ 826 827 def __init__(self, op_type, *args, **kw_args): 828 super(RegisterPForWithArgs, self).__init__(op_type) 829 self._args = args 830 self._kw_args = kw_args 831 832 def __call__(self, converter): 833 834 def _f(pfor_input): 835 return converter(pfor_input, self.op_type, *self._args, **self._kw_args) 836 837 super(RegisterPForWithArgs, self).__call__(_f) 838 return converter 839 840 841def _create_op(op_type, inputs, op_dtypes, attrs=None): 842 """Utility to create an op.""" 843 return ops.get_default_graph().create_op( 844 op_type, inputs, op_dtypes, attrs=attrs, compute_device=True) 845 846 847WrappedTensor = collections.namedtuple("WrappedTensor", 848 ["t", "is_stacked", "is_sparse_stacked"]) 849"""Wrapper around the result of a Tensor conversion. 850 851The additional fields are useful for keeping track of the conversion state as 852data flows through the ops in the loop body. For every op whose output is a 853Tensor, its converter should return either a WrappedTensor or a list of 854WrappedTensors. 855 856Args: 857 t: The converted tensor 858 is_stacked: True if the tensor is stacked, i.e. represents the results of all 859 the iterations of the loop, where each row i of the tensor corresponds to 860 that op's output on iteration i of the loop. False if the tensor is not 861 stacked, i.e. represents the result of the op on of a single iteration of 862 the loop, where the result does not vary between iterations. 863 is_sparse_stacked: True if the tensor corresponds to a component tensor 864 (indices, values, or dense_shape) of a sparse tensor, and has been logically 865 stacked via a sparse conversion. 866""" 867 868 869def wrap(tensor, is_stacked=True, is_sparse_stacked=False): 870 """Helper to create a WrappedTensor object.""" 871 assert isinstance(is_stacked, bool) 872 assert isinstance(is_sparse_stacked, bool) 873 assert isinstance(tensor, ops.Tensor) 874 assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is " 875 "stacked via a sparse " 876 "conversion, it must also be " 877 "stacked.") 878 return WrappedTensor(tensor, is_stacked, is_sparse_stacked) 879 880 881def _fallback_converter(pfor_input): 882 logging.warn("Using a while_loop for converting %s", pfor_input.op_type) 883 output_dtypes = [x.dtype for x in pfor_input.outputs] 884 iters = pfor_input.pfor.loop_len_vector[0] 885 886 def while_body(i, *ta_list): 887 """Body of while loop.""" 888 inputs = [ 889 x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs 890 ] 891 op_outputs = _create_op( 892 pfor_input.op_type, 893 inputs, 894 output_dtypes, 895 attrs=pfor_input.op.node_def.attr).outputs 896 897 outputs = [] 898 for out, ta in zip(op_outputs, ta_list): 899 assert isinstance(out, ops.Tensor) 900 outputs.append(ta.write(i, array_ops.expand_dims(out, 0))) 901 return tuple([i + 1] + outputs) 902 903 ta_list = control_flow_ops.while_loop( 904 lambda i, *ta: i < iters, while_body, [0] + [ 905 tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes 906 ])[1:] 907 return tuple([wrap(ta.concat(), True) for ta in ta_list]) 908 909 910class PForConfig(object): 911 """A configuration object used to communicate with loop body function.""" 912 913 def __init__(self): 914 # This may be set to the number of iterations. 915 self._maybe_iters = None 916 # Map from output placeholder to the unvectorized tensor. 917 self._reduce_concat_map = {} 918 # Reverse map of `self._reduce_concat_map`. 919 self._reverse_reduce_concat_map = {} 920 921 def _has_reductions(self): 922 """True if some reductions where performed by loop body.""" 923 return len(self._reduce_concat_map) 924 925 def _set_iters(self, iters): 926 """Set number of pfor iterations.""" 927 self._maybe_iters = iters 928 929 # TODO(agarwal): handle reductions inside control flow constructs. 930 def reduce_concat(self, x): 931 """Performs a concat reduction on `x` across pfor iterations. 932 933 Note that this currently may not work inside a control flow construct. 934 Args: 935 x: an unvectorized Tensor. 936 937 Returns: 938 A Tensor that has rank one higher than `x`. The value is the vectorized 939 version of `x`, i.e. stacking the value of `x` across different pfor 940 iterations. 941 """ 942 assert not context.executing_eagerly() 943 assert isinstance(x, ops.Tensor) 944 if x not in self._reduce_concat_map: 945 out_shape = tensor_shape.TensorShape([self._maybe_iters]).concatenate( 946 x.shape) 947 with ops.control_dependencies([x]): 948 # Control dependency to make sure out is converted after x. 949 out = array_ops.placeholder(x.dtype, out_shape) 950 self._reduce_concat_map[out] = x 951 self._reverse_reduce_concat_map[x] = out 952 return out 953 else: 954 return self._reverse_reduce_concat_map[x] 955 956 def reduce_mean(self, x): 957 """Performs a mean reduction on `x` across pfor iterations. 958 959 Note that this currently may not work inside a control flow construct. 960 Args: 961 x: an unvectorized Tensor. 962 963 Returns: 964 A Tensor that has same rank as `x`. The value is the mean of the values 965 of `x` across the pfor iterations. 966 """ 967 y = self.reduce_concat(x) 968 return math_ops.reduce_mean(y, axis=0) 969 970 def reduce_sum(self, x): 971 """Performs a sum reduction on `x` across pfor iterations. 972 973 Note that this currently may not work inside a control flow construct. 974 Args: 975 x: an unvectorized Tensor. 976 977 Returns: 978 A Tensor that has same rank as `x`. The value is the sum of the values 979 of `x` across the pfor iterations. 980 """ 981 y = self.reduce_concat(x) 982 return math_ops.reduce_sum(y, axis=0) 983 984 def _lookup_reduction(self, pl): 985 """Lookups Placeholder `pl` in the reduction map.""" 986 assert isinstance(pl, ops.Tensor) 987 return self._reduce_concat_map.get(pl, None) 988 989 990class PFor(object): 991 """Implementation of rewrite of parallel-for loops. 992 993 This class takes a DAG or a set of DAGs representing the body of a 994 parallel-for loop, and adds new operations to the graph that implements 995 functionality equivalent to running that loop body for a specified number of 996 iterations. This new set of nodes may or may not use a tensorflow loop 997 construct. 998 999 The process of conversion does not delete or change any existing operations. 1000 It only adds operations that efficiently implement the equivalent 1001 functionality. We refer to the added ops as "converted ops". 1002 1003 The conversion process uses a simple greedy heuristic. It walks the loop body 1004 and tries to express the functionality of running each node in a loop with a 1005 new set of nodes. When converting an op several cases are possible: 1006 - The op is not inside the loop body. Hence it can be used as is. 1007 - The op does not depend on the iteration number and is stateless. In this 1008 case, it can be used as is. 1009 - The op is not stateful, and depends on iteration number only through control 1010 dependencies. In this case, we can create a single op with same inputs and 1011 attributes, but with "converted" control dependencies. 1012 - The op is not stateful, and all its inputs are loop invariant. In this 1013 case, similar to above, we can create a single op with same inputs and 1014 attributes, but with "converted" control dependencies. 1015 - The op is stateful or at least one of the inputs is not loop invariant. In 1016 this case, we run the registered converter for that op to create a set of 1017 converted ops. All nodes in the set will have converted control dependencies 1018 corresponding to control dependencies of the original op. If the op returned 1019 multiple outputs, "converted outputs" could be produced by different ops in 1020 this set. 1021 """ 1022 1023 def __init__(self, 1024 loop_var, 1025 loop_len, 1026 pfor_ops, 1027 all_indices=None, 1028 all_indices_partitioned=False, 1029 pfor_config=None): 1030 """Creates an object to rewrite a parallel-for loop. 1031 1032 Args: 1033 loop_var: ops.Tensor output of a Placeholder operation. The value should 1034 be an int32 scalar representing the loop iteration number. 1035 loop_len: A scalar or scalar Tensor representing the number of iterations 1036 the loop is run for. 1037 pfor_ops: List of all ops inside the loop body. 1038 all_indices: If not None, an int32 vector with size `loop_len` 1039 representing the iteration ids that are still active. These values 1040 should be unique and sorted. However they may not be contiguous. This is 1041 typically the case when inside a control flow construct which has 1042 partitioned the indices of the iterations that are being converted. 1043 all_indices_partitioned: If True, this object is being constructed from a 1044 control flow construct where not all the pfor iterations are guaranteed 1045 to be active. 1046 pfor_config: PForConfig object used while constructing the loop body. 1047 """ 1048 assert isinstance(loop_var, ops.Tensor) 1049 assert loop_var.op.type == "Placeholder" 1050 self._loop_var = loop_var 1051 loop_len_value = tensor_util.constant_value(loop_len) 1052 if loop_len_value is not None: 1053 loop_len = loop_len_value 1054 self._loop_len_vector = array_ops.reshape(loop_len, [1]) 1055 self._all_indices_partitioned = all_indices_partitioned 1056 if all_indices_partitioned: 1057 assert all_indices is not None 1058 self.all_indices = ( 1059 math_ops.range(loop_len) if all_indices is None else all_indices) 1060 1061 self._conversion_map = {} 1062 self._conversion_map[loop_var] = wrap(self.all_indices, True) 1063 self._pfor_ops = set(pfor_ops) 1064 self._pfor_op_ids = set([x._id for x in pfor_ops]) 1065 self._pfor_config = pfor_config 1066 1067 def op_is_inside_loop(self, op): 1068 """True if op was created inside the pfor loop body.""" 1069 assert isinstance(op, ops.Operation) 1070 # Note that we use self._pfor_op_ids for the check and not self._pfor_ops 1071 # since it appears there tensorflow API could return different python 1072 # objects representing the same Operation node. 1073 return op._id in self._pfor_op_ids 1074 1075 def _convert_sparse(self, y): 1076 """Returns the converted value corresponding to SparseTensor y. 1077 1078 For SparseTensors, instead of stacking the component tensors separately, 1079 resulting in component tensors with shapes (N, m, rank), (N, m), and (N, 1080 rank) respectively for indices, values, and dense_shape (where N is the loop 1081 length and m is the number of sparse tensor values per loop iter), we want 1082 to logically stack the SparseTensors, to create a SparseTensor whose 1083 components are size (N * m, rank + 1), (N * m, ), and (rank + 1,) 1084 respectively. 1085 1086 Here, we try to get the conversion of each component tensor. 1087 If the tensors are stacked via a sparse conversion, return the resulting 1088 SparseTensor composed of the converted components. Otherwise, the component 1089 tensors are either unstacked or stacked naively. In the latter case, we 1090 unstack the component tensors to reform loop_len SparseTensor elements, 1091 then correctly batch them. 1092 1093 The unstacked tensors must have the same rank. Each dimension of each 1094 SparseTensor will expand to be the largest among all SparseTensor elements 1095 for that dimension. For example, if there are N SparseTensors of rank 3 1096 being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i), 1097 the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)). 1098 1099 Args: 1100 y: A tf.SparseTensor. 1101 1102 Returns: 1103 A tf.SparseTensor that is the converted value corresponding to y. 1104 """ 1105 outputs = [ 1106 self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape) 1107 ] 1108 assert all(isinstance(o, WrappedTensor) for o in outputs) 1109 1110 if all(w.is_sparse_stacked for w in outputs): 1111 return sparse_tensor.SparseTensor(*[w.t for w in outputs]) 1112 1113 assert not any(w.is_sparse_stacked for w in outputs), ( 1114 "Error converting SparseTensor. All components should be logically " 1115 "stacked, or none.") 1116 1117 # If component tensors were not sparsely stacked, they are either unstacked 1118 # or stacked without knowledge that they are components of sparse tensors. 1119 # In this case, we have to restack them. 1120 return self._restack_sparse_tensor_logically( 1121 *[self._unwrap_or_tile(w) for w in outputs]) 1122 1123 def _restack_sparse_tensor_logically(self, indices, values, shape): 1124 sparse_tensor_rank = indices.get_shape().dims[-1].value 1125 if sparse_tensor_rank is not None: 1126 sparse_tensor_rank += 1 1127 1128 def fn(args): 1129 res = gen_sparse_ops.serialize_sparse( 1130 args[0], args[1], args[2], out_type=dtypes.variant) 1131 return res 1132 1133 # Applies a map function to the component tensors to serialize each 1134 # sparse tensor element and batch them all, then deserializes the batch. 1135 # TODO(rachelim): Try to do this without map_fn -- add the right offsets 1136 # to shape and indices tensors instead. 1137 result = map_fn.map_fn( 1138 fn, [indices, values, shape], dtype=dtypes.variant) 1139 return sparse_ops.deserialize_sparse( 1140 result, dtype=values.dtype, rank=sparse_tensor_rank) 1141 1142 def _unwrap_or_tile(self, wrapped_tensor): 1143 """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it.""" 1144 output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked 1145 if is_stacked: 1146 return output 1147 else: 1148 return _stack(output, self._loop_len_vector).t 1149 1150 def convert(self, y): 1151 """Returns the converted value corresponding to y. 1152 1153 Args: 1154 y: A ops.Tensor or a ops.Operation object. If latter, y should not have 1155 any outputs. 1156 1157 Returns: 1158 If y does not need to be converted, it returns y as is. Else it returns 1159 the "converted value" corresponding to y. 1160 """ 1161 if y is None: 1162 return None 1163 if isinstance(y, sparse_tensor.SparseTensor): 1164 return self._convert_sparse(y) 1165 output = self._convert_helper(y) 1166 if isinstance(output, WrappedTensor): 1167 assert isinstance(y, ops.Tensor) 1168 return self._unwrap_or_tile(output) 1169 else: 1170 assert isinstance(y, ops.Operation) 1171 assert not y.outputs 1172 assert isinstance(output, ops.Operation) 1173 return output 1174 1175 def _was_converted(self, t): 1176 """True if t is not a conversion of itself.""" 1177 converted_t = self._conversion_map[t] 1178 return converted_t.t is not t 1179 1180 def _add_conversion(self, old_output, new_output): 1181 self._conversion_map[old_output] = new_output 1182 1183 def _convert_helper(self, op_or_tensor): 1184 stack = [op_or_tensor] 1185 while stack: 1186 y = stack[0] 1187 if y in self._conversion_map: 1188 assert isinstance(self._conversion_map[y], 1189 (WrappedTensor, ops.Operation)) 1190 stack.pop(0) 1191 continue 1192 if isinstance(y, ops.Operation): 1193 assert not y.outputs, ( 1194 "We only support converting Operation objects with no outputs. " 1195 "Got %s", y) 1196 y_op = y 1197 else: 1198 assert isinstance(y, ops.Tensor), y 1199 y_op = y.op 1200 1201 is_while_loop = y_op.type == "Exit" 1202 if is_while_loop: 1203 while_op = WhileOp( 1204 y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config) 1205 is_inside_loop = while_op.is_inside_loop 1206 # If all nodes in the while_loop graph were created inside the pfor, we 1207 # treat the whole loop subgraph as a single op (y_op) and try to convert 1208 # it. For while_loops that are created completely or partially outside, 1209 # we treat them as external and should be able to simply return the Exit 1210 # node output as is without needing any conversion. Note that for 1211 # while_loops that are partially constructed inside, we assume they will 1212 # be loop invariant. If that is not the case, it will create runtime 1213 # errors since the converted graph would depend on the self._loop_var 1214 # placeholder. 1215 if is_inside_loop: 1216 y_op = while_op 1217 else: 1218 is_inside_loop = self.op_is_inside_loop(y_op) 1219 1220 # If this op was not created inside the loop body, we will return as is. 1221 # 1. Convert inputs and control inputs. 1222 1223 def _add_to_stack(x): 1224 if x not in self._conversion_map: 1225 stack.insert(0, x) 1226 return True 1227 else: 1228 return False 1229 1230 if is_inside_loop: 1231 added_to_stack = False 1232 for inp in y_op.inputs: 1233 added_to_stack |= _add_to_stack(inp) 1234 for cinp in y_op.control_inputs: 1235 if cinp.outputs: 1236 for t in cinp.outputs: 1237 added_to_stack |= _add_to_stack(t) 1238 else: 1239 added_to_stack |= _add_to_stack(cinp) 1240 if added_to_stack: 1241 continue 1242 1243 converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs] 1244 some_input_converted = any(self._was_converted(x) for x in y_op.inputs) 1245 some_input_stacked = any(x.is_stacked for x in converted_inputs) 1246 1247 converted_control_ops = set() 1248 some_control_input_converted = False 1249 for cinp in y_op.control_inputs: 1250 if cinp.outputs: 1251 for t in cinp.outputs: 1252 converted_t = self._conversion_map[t] 1253 if self._was_converted(t): 1254 some_control_input_converted = True 1255 converted_control_ops.add(converted_t.t.op) 1256 else: 1257 converted_cinp = self._conversion_map[cinp] 1258 assert isinstance(converted_cinp, ops.Operation) 1259 if converted_cinp != cinp: 1260 some_control_input_converted = True 1261 converted_control_ops.add(converted_cinp) 1262 converted_control_ops = list(converted_control_ops) 1263 is_stateful = _is_stateful_pfor_op(y_op) 1264 else: 1265 converted_inputs = [] 1266 converted_control_ops = [] 1267 logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op, 1268 converted_inputs, converted_control_ops) 1269 1270 # 2. Convert y_op 1271 # If converting a while_loop, we let the while_loop convertor deal with 1272 # putting the control dependencies appropriately. 1273 control_dependencies = [] if is_while_loop else converted_control_ops 1274 with ops.control_dependencies(control_dependencies), ops.name_scope( 1275 y_op.name + "/pfor/"): 1276 # Op is a placeholder for a reduction. 1277 if (self._pfor_config is not None and 1278 self._pfor_config._lookup_reduction(y) is not None): 1279 # Handle reductions. Map the placeholder to the unvectorized input 1280 # that is being reduced. 1281 reduction_input = self._pfor_config._lookup_reduction(y) 1282 assert isinstance(reduction_input, ops.Tensor), reduction_input 1283 # Tensor being reduced should already be converted due to a control 1284 # dependency on the created placeholder. 1285 # Note that in cases where reduction_input is in an outer context, one 1286 # needs to locate the corresponding Enter node and use that to lookup 1287 # the conversion. 1288 # TODO(agarwal): handle reductions inside control flow constructs. 1289 assert reduction_input in self._conversion_map, ( 1290 "Unable to handle reduction of %s, possibly as it was used " 1291 "inside a control flow construct. Note that reductions across " 1292 "pfor iterations are currently not supported inside control flow " 1293 "constructs." % reduction_input) 1294 output = self._conversion_map[reduction_input] 1295 # If original input is not stacked, we tile it. Also we always mark 1296 # output as unstacked. 1297 new_outputs = [wrap(self._unwrap_or_tile(output), False)] 1298 # None of the inputs and control inputs were converted. 1299 elif (not is_inside_loop or 1300 (not is_stateful and not some_input_converted and 1301 not some_control_input_converted)): 1302 if y == y_op: 1303 assert not isinstance(y_op, WhileOp) 1304 new_outputs = y_op 1305 else: 1306 new_outputs = [wrap(x, False) for x in y_op.outputs] 1307 elif not (is_stateful or is_while_loop or some_input_stacked): 1308 # All inputs are unstacked or uncoverted but some control inputs are 1309 # converted. 1310 # TODO(rachelim): Handle the case where some inputs are sparsely 1311 # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs)) 1312 new_op = _create_op(y_op.type, [x.t for x in converted_inputs], 1313 [x.dtype for x in y_op.outputs], 1314 y_op.node_def.attr) 1315 if y == y_op: 1316 new_outputs = new_op 1317 else: 1318 new_outputs = [wrap(x, False) for x in new_op.outputs] 1319 else: 1320 # Either some inputs are not loop invariant or op is stateful. 1321 if hasattr(y_op, "pfor_converter"): 1322 converter = y_op.pfor_converter 1323 else: 1324 converter = _pfor_converter_registry.get(y_op.type, None) 1325 if converter is None: 1326 if flags.FLAGS.op_conversion_fallback_to_while_loop: 1327 converter = _fallback_converter 1328 else: 1329 raise ValueError( 1330 "No converter defined for %s\n%s\ninputs: %s. " 1331 "\nEither add a converter or set " 1332 "--op_conversion_fallback_to_while_loop=True, " 1333 "which may run slower" % (y_op.type, y_op, converted_inputs)) 1334 # TODO(rachelim): Handle the case where some inputs are sparsely 1335 # stacked. We should only call the converter if it supports handling 1336 # those inputs. 1337 new_outputs = converter(_PforInput(self, y_op, converted_inputs)) 1338 if isinstance(new_outputs, WrappedTensor): 1339 new_outputs = [new_outputs] 1340 assert isinstance(new_outputs, 1341 (list, tuple, ops.Operation)), new_outputs 1342 logging.vlog(2, "converted %s %s", y_op, new_outputs) 1343 1344 # Insert into self._conversion_map 1345 if y == y_op: 1346 assert isinstance(new_outputs, ops.Operation) 1347 self._add_conversion(y_op, new_outputs) 1348 else: 1349 for old_output, new_output in zip(y_op.outputs, new_outputs): 1350 assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) 1351 self._add_conversion(old_output, new_output) 1352 stack.pop(0) 1353 1354 return self._conversion_map[op_or_tensor] 1355 1356 @property 1357 def loop_len_vector(self): 1358 """Returns a single element vector whose value is number of iterations.""" 1359 return self._loop_len_vector 1360 1361 @property 1362 def loop_var(self): 1363 """Returns placeholder loop variable.""" 1364 return self._loop_var 1365 1366 @property 1367 def pfor_ops(self): 1368 return self._pfor_ops 1369 1370 @property 1371 def all_indices_partitioned(self): 1372 """all_indices_partitioned property. 1373 1374 Returns: 1375 True if we are inside a control flow construct and not all pfor iterations 1376 may be active. 1377 """ 1378 return self._all_indices_partitioned 1379 1380# nn_ops 1381 1382 1383def _flatten_first_two_dims(x): 1384 """Merges first two dimensions.""" 1385 old_shape = array_ops.shape(x) 1386 new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0) 1387 return array_ops.reshape(x, new_shape) 1388 1389 1390def _unflatten_first_dim(x, first_dim): 1391 """Splits first dimension into [first_dim, -1].""" 1392 old_shape = array_ops.shape(x) 1393 new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0) 1394 return array_ops.reshape(x, new_shape) 1395 1396 1397def _inputs_with_flattening(pfor_input, input_indices): 1398 """Stacks and flattens first dim of inputs at indices `input_indices`.""" 1399 if input_indices is None: 1400 input_indices = [] 1401 pfor_input.stack_inputs(stack_indices=input_indices) 1402 inputs = [] 1403 for i in range(pfor_input.num_inputs): 1404 if i in input_indices: 1405 inp = pfor_input.stacked_input(i) 1406 inp = _flatten_first_two_dims(inp) 1407 else: 1408 inp = pfor_input.unstacked_input(i) 1409 inputs.append(inp) 1410 return inputs 1411 1412 1413@RegisterPForWithArgs("Conv2D", dims=[0]) 1414@RegisterPForWithArgs("AvgPool", dims=[0]) 1415@RegisterPForWithArgs("MaxPool", dims=[0]) 1416@RegisterPForWithArgs("MaxPool3D", dims=[0]) 1417@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2]) 1418@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) 1419@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2]) 1420@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2]) 1421@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) 1422def _convert_flatten_batch(pfor_input, op_type, dims): 1423 del op_type 1424 inputs = _inputs_with_flattening(pfor_input, dims) 1425 outputs = _create_op( 1426 pfor_input.op_type, 1427 inputs, [x.dtype for x in pfor_input.outputs], 1428 attrs=pfor_input.op.node_def.attr).outputs 1429 n = pfor_input.pfor.loop_len_vector 1430 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1431 return [wrap(x, True) for x in outputs] 1432 1433 1434_channel_flatten_input_cache = {} 1435 1436 1437def _channel_flatten_input(x, data_format): 1438 """Merge the stack dimension with the channel dimension. 1439 1440 If S is pfor's stacking dimension, then, 1441 - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose 1442 should be cheap. 1443 - for SNHWC, we transpose to NHWCS. 1444 We then merge the S and C dimension. 1445 1446 Args: 1447 x: ops.Tensor to transform. 1448 data_format: "NCHW" or "NHWC". 1449 1450 Returns: 1451 A 3-element tuple with the transformed value, along with the shape for 1452 reshape and order for transpose required to transform back. 1453 """ 1454 1455 graph = ops.get_default_graph() 1456 cache_key = (graph, x, data_format) 1457 if cache_key not in _channel_flatten_input_cache: 1458 x_shape = array_ops.shape(x) 1459 if data_format == b"NCHW": 1460 order = [1, 0, 2, 3, 4] 1461 shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0) 1462 reverse_order = order 1463 else: 1464 order = [1, 2, 3, 0, 4] 1465 shape = array_ops.concat([x_shape[1:4], [-1]], axis=0) 1466 reverse_order = [3, 0, 1, 2, 4] 1467 # Move S dimension next to C dimension. 1468 x = array_ops.transpose(x, order) 1469 reverse_shape = array_ops.shape(x) 1470 # Reshape to merge the S and C dimension. 1471 x = array_ops.reshape(x, shape) 1472 outputs = x, reverse_order, reverse_shape 1473 _channel_flatten_input_cache[cache_key] = outputs 1474 else: 1475 outputs = _channel_flatten_input_cache[cache_key] 1476 return outputs 1477 1478 1479# Note that with training=True, running FusedBatchNorm on individual examples 1480# is very different from running FusedBatchNorm on a batch of those examples. 1481# This is because, for the latter case, the operation can be considered as first 1482# computing the mean and variance over all the examples and then using these 1483# to scale all those examples. This creates a data dependency between these 1484# different "iterations" since the inputs to the scaling step depends on the 1485# statistics coming from all these inputs. 1486# As with other kernels, the conversion here effectively runs the kernel 1487# independently for each iteration, and returns outputs by stacking outputs from 1488# each of those iterations. 1489@RegisterPFor("FusedBatchNorm") 1490def _convert_fused_batch_norm(pfor_input): 1491 is_training = pfor_input.get_attr("is_training") 1492 # When BatchNorm is used with training=False, mean and variance are provided 1493 # externally and used as is by the op. Thus, we can merge the S and N 1494 # dimensions as we do for regular operations. 1495 # When BatchNorm is used with training=True, mean and variance are computed 1496 # for each channel across the batch dimension (first one). If we merge S and N 1497 # dimensions, mean and variances will be computed over a larger set. So, we 1498 # merge the S and C dimensions instead. 1499 if not is_training: 1500 # We return zeros for batch_mean and batch_variance output. Note that CPU 1501 # and GPU seem to have different behavior for those two outputs. CPU outputs 1502 # zero because these values are not used during inference. GPU outputs 1503 # something, probably real means and variances. 1504 inputs = _inputs_with_flattening(pfor_input, [0]) 1505 outputs = _create_op( 1506 pfor_input.op_type, 1507 inputs, [x.dtype for x in pfor_input.outputs], 1508 attrs=pfor_input.op.node_def.attr).outputs 1509 y = outputs[0] 1510 n = pfor_input.pfor.loop_len_vector 1511 y = _unflatten_first_dim(y, n) 1512 mean = pfor_input.unstacked_input(3) 1513 zeros = array_ops.zeros_like(mean) 1514 return [wrap(y, True), wrap(zeros, False), wrap(zeros, False)] 1515 1516 pfor_input.stack_inputs() 1517 data_format = pfor_input.get_attr("data_format") 1518 # We merge the first dimension with the "C" dimension, run FusedBatchNorm, and 1519 # then transpose back. 1520 x = pfor_input.stacked_input(0) 1521 x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format) 1522 # Note that we stack all the other inputs as well so that they are the same 1523 # size as the new size of the channel dimension. 1524 inputs = [x] + [ 1525 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1526 for i in range(1, pfor_input.num_inputs) 1527 ] 1528 outputs = _create_op( 1529 pfor_input.op_type, 1530 inputs, [x.dtype for x in pfor_input.outputs], 1531 attrs=pfor_input.op.node_def.attr).outputs 1532 y = outputs[0] 1533 y = array_ops.reshape(y, reverse_shape) 1534 y = array_ops.transpose(y, reverse_order) 1535 n = pfor_input.pfor.loop_len_vector 1536 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1537 outputs = [y] + outputs 1538 return [wrap(x, True) for x in outputs] 1539 1540 1541@RegisterPFor("FusedBatchNormGrad") 1542def _convert_fused_batch_norm_grad(pfor_input): 1543 pfor_input.stack_inputs() 1544 data_format = pfor_input.get_attr("data_format") 1545 y_backprop = pfor_input.stacked_input(0) 1546 y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) 1547 x = pfor_input.stacked_input(1) 1548 x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) 1549 inputs = [y_backprop, x] + [ 1550 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1551 for i in range(2, pfor_input.num_inputs) 1552 ] 1553 outputs = _create_op( 1554 pfor_input.op_type, 1555 inputs, [x.dtype for x in pfor_input.outputs], 1556 attrs=pfor_input.op.node_def.attr).outputs 1557 x_backprop = outputs[0] 1558 x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) 1559 x_backprop = array_ops.transpose(x_backprop, x_reverse_order) 1560 n = pfor_input.pfor.loop_len_vector 1561 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1562 outputs = [x_backprop] + outputs 1563 return [wrap(output, True) for output in outputs] 1564 1565 1566@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) 1567@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) 1568def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, 1569 shape_dim): 1570 del op_type 1571 inputs = _inputs_with_flattening(pfor_input, flatten_dims) 1572 n = pfor_input.pfor.loop_len_vector 1573 # Adjust the `input_sizes` input. 1574 ones = array_ops.ones( 1575 [array_ops.shape(inputs[shape_dim])[0] - 1], dtype=n.dtype) 1576 inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) 1577 outputs = _create_op( 1578 pfor_input.op_type, 1579 inputs, [x.dtype for x in pfor_input.outputs], 1580 attrs=pfor_input.op.node_def.attr).outputs 1581 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1582 return [wrap(x, True) for x in outputs] 1583 1584 1585@RegisterPFor("Conv2DBackpropFilter") 1586def _convert_conv2d_backprop_filter(pfor_input): 1587 pfor_input.stack_inputs(stack_indices=[2]) 1588 inputs, inputs_stacked, _ = pfor_input.input(0) 1589 filter_sizes = pfor_input.unstacked_input(1) 1590 grads = pfor_input.stacked_input(2) 1591 strides = pfor_input.get_attr("strides") 1592 padding = pfor_input.get_attr("padding") 1593 use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") 1594 data_format = pfor_input.get_attr("data_format") 1595 dilations = pfor_input.get_attr("dilations") 1596 if inputs_stacked: 1597 # TODO(agarwal): Implement this efficiently. 1598 logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!") 1599 1600 def while_body(i, ta): 1601 inp_i = inputs[i, ...] 1602 grad_i = grads[i, ...] 1603 output = nn_ops.conv2d_backprop_filter( 1604 inp_i, 1605 filter_sizes, 1606 grad_i, 1607 strides=strides, 1608 padding=padding, 1609 use_cudnn_on_gpu=use_cudnn_on_gpu, 1610 data_format=data_format, 1611 dilations=dilations) 1612 return i + 1, ta.write(i, array_ops.expand_dims(output, 0)) 1613 1614 n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) 1615 _, ta = control_flow_ops.while_loop( 1616 lambda i, ta: i < n, while_body, 1617 (0, tensor_array_ops.TensorArray(inputs.dtype, n))) 1618 output = ta.concat() 1619 return wrap(output, True) 1620 else: 1621 # We merge the stack dimension with the channel dimension of the gradients 1622 # and pretend we had a larger filter (see change to filter_sizes below). 1623 # Once the filter backprop is computed, we reshape and transpose back 1624 # appropriately. 1625 grads, _, _ = _channel_flatten_input(grads, data_format) 1626 n = pfor_input.pfor.loop_len_vector 1627 old_filter_sizes = filter_sizes 1628 filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) 1629 output = nn_ops.conv2d_backprop_filter( 1630 inputs, 1631 filter_sizes, 1632 grads, 1633 strides=strides, 1634 padding=padding, 1635 use_cudnn_on_gpu=use_cudnn_on_gpu, 1636 data_format=data_format, 1637 dilations=dilations) 1638 new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) 1639 output = array_ops.reshape(output, new_filter_shape) 1640 output = array_ops.transpose(output, [3, 0, 1, 2, 4]) 1641 return wrap(output, True) 1642 1643 1644# array_ops 1645 1646 1647@RegisterPForWithArgs("Identity", array_ops.identity) 1648@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) 1649@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part) 1650def _convert_identity(pfor_input, op_type, op_func): 1651 del op_type 1652 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 1653 1654 1655@RegisterPFor("IdentityN") 1656def _convert_identity_n(pfor_input): 1657 outputs = array_ops.identity_n([x.t for x in pfor_input.inputs]) 1658 return [wrap(out, inp.is_stacked) for out, inp in 1659 zip(outputs, pfor_input.inputs)] 1660 1661 1662@RegisterPFor("Reshape") 1663def _convert_reshape(pfor_input): 1664 t = pfor_input.stacked_input(0) 1665 shape = pfor_input.unstacked_input(1) 1666 new_dim = array_ops.shape(t)[:1] 1667 new_shape = array_ops.concat([new_dim, shape], axis=0) 1668 return wrap(array_ops.reshape(t, new_shape), True) 1669 1670 1671@RegisterPFor("ExpandDims") 1672def _convert_expanddims(pfor_input): 1673 t = pfor_input.stacked_input(0) 1674 dim = pfor_input.unstacked_input(1) 1675 dim += math_ops.cast(dim >= 0, dtypes.int32) 1676 return wrap(array_ops.expand_dims(t, axis=dim), True) 1677 1678 1679@RegisterPFor("Slice") 1680def _convert_slice(pfor_input): 1681 t = pfor_input.stacked_input(0) 1682 begin = pfor_input.unstacked_input(1) 1683 size = pfor_input.unstacked_input(2) 1684 begin = array_ops.concat([[0], begin], axis=0) 1685 size = array_ops.concat([[-1], size], axis=0) 1686 return wrap(array_ops.slice(t, begin, size), True) 1687 1688 1689@RegisterPFor("Tile") 1690def _convert_tile(pfor_input): 1691 t = pfor_input.stacked_input(0) 1692 multiples = pfor_input.unstacked_input(1) 1693 multiples = array_ops.concat([[1], multiples], 0) 1694 return wrap(array_ops.tile(t, multiples), True) 1695 1696 1697@RegisterPFor("Pack") 1698def _convert_pack(pfor_input): 1699 pfor_input.stack_inputs() 1700 axis = pfor_input.get_attr("axis") 1701 if axis >= 0: 1702 axis += 1 1703 return wrap( 1704 array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True) 1705 1706 1707@RegisterPFor("Unpack") 1708def _convert_unpack(pfor_input): 1709 value = pfor_input.stacked_input(0) 1710 axis = pfor_input.get_attr("axis") 1711 if axis >= 0: 1712 axis += 1 1713 num = pfor_input.get_attr("num") 1714 return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)] 1715 1716 1717@RegisterPFor("Pad") 1718def _convert_pad(pfor_input): 1719 t = pfor_input.stacked_input(0) 1720 paddings = pfor_input.unstacked_input(1) 1721 paddings = array_ops.concat([[[0, 0]], paddings], 0) 1722 return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) 1723 1724 1725@RegisterPFor("Split") 1726def _convert_split(pfor_input): 1727 split_dim = pfor_input.unstacked_input(0) 1728 t = pfor_input.stacked_input(1) 1729 num_split = pfor_input.get_attr("num_split") 1730 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 1731 return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] 1732 1733 1734@RegisterPFor("SplitV") 1735def _convert_split_v(pfor_input): 1736 t = pfor_input.stacked_input(0) 1737 splits = pfor_input.unstacked_input(1) 1738 split_dim = pfor_input.unstacked_input(2) 1739 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 1740 return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)] 1741 1742 1743@RegisterPFor("Transpose") 1744def _convert_transpose(pfor_input): 1745 t = pfor_input.stacked_input(0) 1746 perm = pfor_input.unstacked_input(1) 1747 new_perm = array_ops.concat([[0], perm + 1], axis=0) 1748 return wrap(array_ops.transpose(t, new_perm), True) 1749 1750 1751@RegisterPFor("ZerosLike") 1752def _convert_zeroslike(pfor_input): 1753 t = pfor_input.stacked_input(0) 1754 shape = array_ops.shape(t)[1:] 1755 return wrap(array_ops.zeros(shape, dtype=t.dtype), False) 1756 1757 1758@RegisterPFor("Gather") 1759@RegisterPFor("GatherV2") 1760def _convert_gather(pfor_input): 1761 param, param_stacked, _ = pfor_input.input(0) 1762 indices, indices_stacked, _ = pfor_input.input(1) 1763 op_type = pfor_input.op_type 1764 if op_type == "Gather": 1765 validate_indices = pfor_input.get_attr("validate_indices") 1766 axis = 0 1767 else: 1768 validate_indices = None 1769 axis = pfor_input.unstacked_input(2) 1770 axis_value = tensor_util.constant_value(axis) 1771 if axis_value is not None: 1772 axis = axis_value 1773 if indices_stacked and not param_stacked: 1774 if indices == pfor_input.pfor.all_indices and axis == 0: 1775 param_shape0 = param.shape.dims[0].value 1776 indices_shape0 = indices.shape.dims[0].value 1777 if param_shape0 is not None and indices_shape0 == param_shape0: 1778 # Note that with loops and conditionals, indices may not be contiguous. 1779 # However they will be sorted and unique. So if the shape matches, then 1780 # it must be picking up all the rows of param. 1781 return wrap(param, True) 1782 # TODO(agarwal): use array_ops.slice here. 1783 output = array_ops.gather( 1784 param, indices, validate_indices=validate_indices, axis=axis) 1785 if axis != 0: 1786 axis = control_flow_ops.cond( 1787 axis < 0, lambda: axis + array_ops.rank(param), lambda: axis) 1788 order = array_ops.concat( 1789 [[axis], 1790 math_ops.range(axis), 1791 math_ops.range(axis + 1, array_ops.rank(output))], 1792 axis=0) 1793 output = control_flow_ops.cond( 1794 math_ops.equal(axis, 0), lambda: output, 1795 lambda: array_ops.transpose(output, order)) 1796 return wrap(output, True) 1797 if param_stacked: 1798 loop_len_vector = pfor_input.pfor.loop_len_vector 1799 pfor_input.stack_inputs(stack_indices=[1]) 1800 indices = pfor_input.stacked_input(1) 1801 param_flat = _flatten_first_two_dims(param) 1802 1803 # Recompute indices to handle stacked param. 1804 indices_offset = math_ops.range( 1805 loop_len_vector[0]) * array_ops.shape(param)[1] 1806 # Reshape indices_offset to allow broadcast addition 1807 ones = array_ops.ones([array_ops.rank(indices) - 1], dtype=dtypes.int32) 1808 new_shape = array_ops.concat([loop_len_vector, ones], axis=0) 1809 indices_offset = array_ops.reshape(indices_offset, new_shape) 1810 indices += indices_offset 1811 1812 # TODO(agarwal): handle axis != 0. May need to transpose param or 1813 # array_ops.gather_nd. 1814 if isinstance(axis, ops.Tensor): 1815 axis_value = tensor_util.constant_value(axis) 1816 else: 1817 try: 1818 axis_value = int(axis) 1819 except TypeError: 1820 axis_value = None 1821 msg = ("Gather, where indices and param are both loop dependent, currently " 1822 "requires axis=0") 1823 if axis_value is not None and axis_value != 0: 1824 raise ValueError("Error while converting %s. %s. Got axis=%d" % 1825 (pfor_input.op, msg, axis)) 1826 with ops.control_dependencies( 1827 [check_ops.assert_equal(axis, 0, message=msg)]): 1828 output = array_ops.gather(param_flat, indices) 1829 return wrap(output, True) 1830 1831 1832@RegisterPFor("ConcatV2") 1833def _convert_concatv2(pfor_input): 1834 n = pfor_input.num_inputs 1835 pfor_input.stack_inputs(stack_indices=range(n - 1)) 1836 axis = pfor_input.unstacked_input(n - 1) 1837 axis += math_ops.cast(axis >= 0, axis.dtype) 1838 return wrap( 1839 array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), 1840 True) 1841 1842 1843@RegisterPFor("StridedSlice") 1844def _convert_strided_slice(pfor_input): 1845 inp = pfor_input.stacked_input(0) 1846 begin = pfor_input.unstacked_input(1) 1847 end = pfor_input.unstacked_input(2) 1848 strides = pfor_input.unstacked_input(3) 1849 begin_mask = pfor_input.get_attr("begin_mask") 1850 end_mask = pfor_input.get_attr("end_mask") 1851 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 1852 new_axis_mask = pfor_input.get_attr("new_axis_mask") 1853 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 1854 1855 begin = array_ops.concat([[0], begin], axis=0) 1856 end = array_ops.concat([[0], end], axis=0) 1857 strides = array_ops.concat([[1], strides], axis=0) 1858 begin_mask = begin_mask << 1 | 1 1859 end_mask = end_mask << 1 | 1 1860 ellipsis_mask <<= 1 1861 new_axis_mask <<= 1 1862 shrink_axis_mask <<= 1 1863 return wrap( 1864 array_ops.strided_slice( 1865 inp, 1866 begin, 1867 end, 1868 strides, 1869 begin_mask=begin_mask, 1870 end_mask=end_mask, 1871 ellipsis_mask=ellipsis_mask, 1872 new_axis_mask=new_axis_mask, 1873 shrink_axis_mask=shrink_axis_mask), True) 1874 1875 1876@RegisterPFor("StridedSliceGrad") 1877def _convert_strided_slice_grad(pfor_input): 1878 shape = pfor_input.unstacked_input(0) 1879 begin = pfor_input.unstacked_input(1) 1880 end = pfor_input.unstacked_input(2) 1881 strides = pfor_input.unstacked_input(3) 1882 dy = pfor_input.stacked_input(4) 1883 begin_mask = pfor_input.get_attr("begin_mask") 1884 end_mask = pfor_input.get_attr("end_mask") 1885 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 1886 new_axis_mask = pfor_input.get_attr("new_axis_mask") 1887 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 1888 1889 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 1890 begin = array_ops.concat([[0], begin], axis=0) 1891 end = array_ops.concat([[0], end], axis=0) 1892 strides = array_ops.concat([[1], strides], axis=0) 1893 begin_mask = begin_mask << 1 | 1 1894 end_mask = end_mask << 1 | 1 1895 ellipsis_mask <<= 1 1896 new_axis_mask <<= 1 1897 shrink_axis_mask <<= 1 1898 return wrap( 1899 array_ops.strided_slice_grad( 1900 shape, 1901 begin, 1902 end, 1903 strides, 1904 dy, 1905 begin_mask=begin_mask, 1906 end_mask=end_mask, 1907 ellipsis_mask=ellipsis_mask, 1908 new_axis_mask=new_axis_mask, 1909 shrink_axis_mask=shrink_axis_mask), True) 1910 1911 1912# math_ops 1913 1914 1915@RegisterPFor("MatMul") 1916def _convert_matmul(pfor_input): 1917 # TODO(agarwal): Check if tiling is faster than two transposes. 1918 a, a_stacked, _ = pfor_input.input(0) 1919 b, b_stacked, _ = pfor_input.input(1) 1920 tr_a = pfor_input.get_attr("transpose_a") 1921 tr_b = pfor_input.get_attr("transpose_b") 1922 if a_stacked and b_stacked: 1923 output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) 1924 return output 1925 elif a_stacked: 1926 if tr_a: 1927 a = array_ops.transpose(a, [0, 2, 1]) 1928 if a.shape.is_fully_defined(): 1929 x, y, z = a.shape 1930 else: 1931 x, y, z = [ 1932 array_ops.reshape(i, []) 1933 for i in array_ops.split(array_ops.shape(a), 3) 1934 ] 1935 a = array_ops.reshape(a, [x * y, z]) 1936 prod = math_ops.matmul(a, b, transpose_b=tr_b) 1937 return wrap(array_ops.reshape(prod, [x, y, -1]), True) 1938 else: 1939 assert b_stacked 1940 if tr_b: 1941 perm = [2, 0, 1] 1942 b = array_ops.transpose(b, perm) 1943 else: 1944 # As an optimization, if one of the first two dimensions is 1, then we can 1945 # reshape instead of transpose. 1946 # TODO(agarwal): This check can be done inside Transpose kernel. 1947 b_shape = array_ops.shape(b) 1948 min_dim = math_ops.minimum(b_shape[0], b_shape[1]) 1949 perm = control_flow_ops.cond( 1950 math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2]) 1951 new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]]) 1952 b = array_ops.transpose(b, perm) 1953 b = array_ops.reshape(b, new_shape) 1954 1955 if b.shape.is_fully_defined(): 1956 x, y, z = b.shape 1957 else: 1958 x, y, z = [ 1959 array_ops.reshape(i, []) 1960 for i in array_ops.split(array_ops.shape(b), 3) 1961 ] 1962 b = array_ops.reshape(b, [x, y * z]) 1963 prod = math_ops.matmul(a, b, transpose_a=tr_a) 1964 prod = array_ops.reshape(prod, [-1, y, z]) 1965 prod = array_ops.transpose(prod, [1, 0, 2]) 1966 return wrap(prod, True) 1967 1968 1969@RegisterPFor("BatchMatMul") 1970def _convert_batch_mat_mul(pfor_input): 1971 # TODO(agarwal): There may be a more efficient way to do this instead of 1972 # stacking the inputs. 1973 pfor_input.stack_inputs() 1974 x = pfor_input.stacked_input(0) 1975 y = pfor_input.stacked_input(1) 1976 adj_x = pfor_input.get_attr("adj_x") 1977 adj_y = pfor_input.get_attr("adj_y") 1978 1979 x = _flatten_first_two_dims(x) 1980 y = _flatten_first_two_dims(y) 1981 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 1982 output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) 1983 return wrap(output, True) 1984 1985 1986@RegisterPForWithArgs("Sum", math_ops.reduce_sum) 1987@RegisterPForWithArgs("Prod", math_ops.reduce_prod) 1988@RegisterPForWithArgs("Max", math_ops.reduce_max) 1989@RegisterPForWithArgs("Min", math_ops.reduce_min) 1990@RegisterPForWithArgs("Mean", math_ops.reduce_mean) 1991def _convert_reduction(pfor_input, _, op_func): 1992 t = pfor_input.stacked_input(0) 1993 indices = pfor_input.unstacked_input(1) 1994 # Shift positive indices by one to account for the extra dimension. 1995 indices += math_ops.cast(indices >= 0, dtypes.int32) 1996 keep_dims = pfor_input.get_attr("keep_dims") 1997 return wrap(op_func(t, indices, keepdims=keep_dims), True) 1998 1999 2000@RegisterPForWithArgs("Cumsum", math_ops.cumsum) 2001@RegisterPForWithArgs("Cumprod", math_ops.cumprod) 2002def _convert_cumfoo(pfor_input, _, op_func): 2003 t = pfor_input.stacked_input(0) 2004 axis = pfor_input.unstacked_input(1) 2005 # Shift positive indices by one to account for the extra dimension. 2006 axis += math_ops.cast(axis >= 0, dtypes.int32) 2007 exclusive = pfor_input.get_attr("exclusive") 2008 reverse = pfor_input.get_attr("reverse") 2009 return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) 2010 2011 2012@RegisterPFor("BiasAdd") 2013def _convert_biasadd(pfor_input): 2014 t, t_stacked, _ = pfor_input.input(0) 2015 bias, bias_stacked, _ = pfor_input.input(1) 2016 data_format = pfor_input.get_attr("data_format").decode() 2017 if bias_stacked: 2018 # BiasAdd only supports 1-D biases, so cast bias to match value and use Add. 2019 pfor_input.expanddim_inputs_for_broadcast() 2020 t, _, _ = pfor_input.input(0) 2021 bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype) 2022 if compat.as_bytes(data_format) == b"NCHW": 2023 b_shape = array_ops.shape(bias) 2024 new_b_shape = array_ops.concat( 2025 [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0) 2026 bias = array_ops.reshape(bias, new_b_shape) 2027 return wrap(math_ops.add(t, bias), True) 2028 else: 2029 assert t_stacked, "At least one input to BiasAdd should be loop variant." 2030 if compat.as_bytes(data_format) == b"NCHW": 2031 shape = array_ops.shape(t) 2032 flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) 2033 t = array_ops.reshape(t, flattened_shape) 2034 t = nn_ops.bias_add(t, bias, data_format="NCHW") 2035 t = array_ops.reshape(t, shape) 2036 return wrap(t, True) 2037 return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) 2038 2039 2040@RegisterPFor("UnsortedSegmentSum") 2041def _convert_unsortedsegmentsum(pfor_input): 2042 data, data_stacked, _ = pfor_input.input(0) 2043 # TODO(agarwal): handle unstacked? 2044 segment_ids = pfor_input.stacked_input(1) 2045 # TODO(agarwal): handle stacked? 2046 num_segments = pfor_input.unstacked_input(2) 2047 if not data_stacked: 2048 data = _stack(data, pfor_input.pfor.loop_len_vector).t 2049 segment_shape = array_ops.shape(segment_ids) 2050 n = segment_shape[0] 2051 ones = array_ops.ones_like(segment_shape)[1:] 2052 segment_offset = num_segments * math_ops.range(n) 2053 segment_offset = array_ops.reshape(segment_offset, 2054 array_ops.concat([[n], ones], axis=0)) 2055 segment_ids += segment_offset 2056 num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast( 2057 n, dtypes.int64) 2058 output = math_ops.unsorted_segment_sum(data, segment_ids, num_segments) 2059 new_output_shape = array_ops.concat( 2060 [[n, -1], array_ops.shape(output)[1:]], axis=0) 2061 output = array_ops.reshape(output, new_output_shape) 2062 return wrap(output, True) 2063 2064 2065@RegisterPFor("Cast") 2066def _convert_cast(pfor_input): 2067 inp = pfor_input.stacked_input(0) 2068 dtype = pfor_input.get_attr("DstT") 2069 return wrap(math_ops.cast(inp, dtype), True) 2070 2071 2072@RegisterPForWithArgs("Abs", math_ops.abs) 2073@RegisterPForWithArgs("Acos", math_ops.acos) 2074@RegisterPForWithArgs("Acosh", math_ops.acosh) 2075@RegisterPForWithArgs("Add", math_ops.add) 2076@RegisterPForWithArgs("AddV2", math_ops.add_v2) 2077@RegisterPForWithArgs("Angle", math_ops.angle) 2078@RegisterPForWithArgs("Asin", math_ops.asin) 2079@RegisterPForWithArgs("Asinh", math_ops.asinh) 2080@RegisterPForWithArgs("Atan", math_ops.atan) 2081@RegisterPForWithArgs("Atan2", math_ops.atan2) 2082@RegisterPForWithArgs("Atanh", math_ops.atanh) 2083@RegisterPForWithArgs("BesselI0e", math_ops.bessel_i0e) 2084@RegisterPForWithArgs("BesselI1e", math_ops.bessel_i1e) 2085@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and) 2086@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or) 2087@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor) 2088@RegisterPForWithArgs("Ceil", math_ops.ceil) 2089@RegisterPForWithArgs("Complex", math_ops.complex) 2090@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs) 2091@RegisterPForWithArgs("Conj", math_ops.conj) 2092@RegisterPForWithArgs("Cos", math_ops.cos) 2093@RegisterPForWithArgs("Cosh", math_ops.cosh) 2094@RegisterPForWithArgs("Digamma", math_ops.digamma) 2095@RegisterPForWithArgs("Div", math_ops.div) 2096@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan) 2097@RegisterPForWithArgs("Elu", nn_ops.elu) 2098@RegisterPForWithArgs("Equal", math_ops.equal) 2099@RegisterPForWithArgs("Erf", math_ops.erf) 2100@RegisterPForWithArgs("Erfc", math_ops.erfc) 2101@RegisterPForWithArgs("Exp", math_ops.exp) 2102@RegisterPForWithArgs("Expm1", math_ops.expm1) 2103@RegisterPForWithArgs("Floor", math_ops.floor) 2104@RegisterPForWithArgs("FloorDiv", math_ops.floor_div) 2105@RegisterPForWithArgs("FloorMod", math_ops.floor_mod) 2106@RegisterPForWithArgs("Greater", math_ops.greater) 2107@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal) 2108@RegisterPForWithArgs("Igamma", math_ops.igamma) 2109@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a) 2110@RegisterPForWithArgs("Igammac", math_ops.igammac) 2111@RegisterPForWithArgs("Imag", math_ops.imag) 2112@RegisterPForWithArgs("Inv", math_ops.inv) 2113@RegisterPForWithArgs("Invert", bitwise_ops.invert) 2114@RegisterPForWithArgs("IsFinite", math_ops.is_finite) 2115@RegisterPForWithArgs("IsInf", math_ops.is_inf) 2116@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift) 2117@RegisterPForWithArgs("Less", math_ops.less) 2118@RegisterPForWithArgs("LessEqual", math_ops.less_equal) 2119@RegisterPForWithArgs("Lgamma", math_ops.lgamma) 2120@RegisterPForWithArgs("Log", math_ops.log) 2121@RegisterPForWithArgs("Log1p", math_ops.log1p) 2122@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and) 2123@RegisterPForWithArgs("LogicalNot", math_ops.logical_not) 2124@RegisterPForWithArgs("LogicalOr", math_ops.logical_or) 2125@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor) 2126@RegisterPForWithArgs("Maximum", math_ops.maximum) 2127@RegisterPForWithArgs("Minimum", math_ops.minimum) 2128@RegisterPForWithArgs("Mod", math_ops.mod) 2129@RegisterPForWithArgs("Mul", math_ops.multiply) 2130@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan) 2131@RegisterPForWithArgs("Neg", math_ops.negative) 2132@RegisterPForWithArgs("NotEqual", math_ops.not_equal) 2133@RegisterPForWithArgs("Polygamma", math_ops.polygamma) 2134@RegisterPForWithArgs("Pow", math_ops.pow) 2135@RegisterPForWithArgs("Real", math_ops.real) 2136@RegisterPForWithArgs("RealDiv", math_ops.divide) 2137@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal) 2138@RegisterPForWithArgs("Relu", nn_ops.relu) 2139@RegisterPForWithArgs("Relu6", nn_ops.relu6) 2140@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift) 2141@RegisterPForWithArgs("Rint", math_ops.rint) 2142@RegisterPForWithArgs("Round", math_ops.round) 2143@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt) 2144@RegisterPForWithArgs("Selu", nn_ops.selu) 2145@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid) 2146@RegisterPForWithArgs("Sign", math_ops.sign) 2147@RegisterPForWithArgs("Sin", math_ops.sin) 2148@RegisterPForWithArgs("Sinh", math_ops.sinh) 2149@RegisterPForWithArgs("Softplus", nn_ops.softplus) 2150@RegisterPForWithArgs("Softsign", nn_ops.softsign) 2151@RegisterPForWithArgs("Sqrt", math_ops.sqrt) 2152@RegisterPForWithArgs("Square", math_ops.square) 2153@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference) 2154@RegisterPForWithArgs("Sub", math_ops.subtract) 2155@RegisterPForWithArgs("Tan", math_ops.tan) 2156@RegisterPForWithArgs("Tanh", math_ops.tanh) 2157@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div) 2158@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod) 2159@RegisterPForWithArgs("Xdivy", math_ops.xdivy) 2160@RegisterPForWithArgs("Xlogy", math_ops.xlogy) 2161@RegisterPForWithArgs("Zeta", math_ops.zeta) 2162def _convert_cwise(pfor_input, op_type, op_func): 2163 # Note that ops handled here do not have attributes except "T" and "Tout", and 2164 # hence don't need extra arguments passed to the cwise_op call below. 2165 for attr in pfor_input.op.node_def.attr.keys(): 2166 assert attr in [u"T", u"Tout"], (op_type, attr) 2167 pfor_input.expanddim_inputs_for_broadcast() 2168 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 2169 2170 2171@RegisterPFor("ApproximateEqual") 2172def _convert_approximate_equal(pfor_input): 2173 pfor_input.expanddim_inputs_for_broadcast() 2174 x = pfor_input.input(0)[0] 2175 y = pfor_input.input(1)[0] 2176 tolerance = pfor_input.get_attr("tolerance") 2177 return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True) 2178 2179 2180@RegisterPFor("Shape") 2181def _convert_shape(pfor_input): 2182 out_type = pfor_input.get_attr("out_type") 2183 return wrap( 2184 array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], 2185 False) 2186 2187 2188@RegisterPFor("ShapeN") 2189def _convert_shape_n(pfor_input): 2190 out_type = pfor_input.get_attr("out_type") 2191 shapes = [ 2192 array_ops.shape(x, out_type=out_type)[1:] 2193 if stacked else array_ops.shape(x) for x, stacked, _ in pfor_input.inputs 2194 ] 2195 return [wrap(x, False) for x in shapes] 2196 2197 2198@RegisterPFor("Size") 2199def _convert_size(pfor_input): 2200 out_type = pfor_input.get_attr("out_type") 2201 n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) 2202 return wrap( 2203 array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, 2204 False) 2205 2206 2207@RegisterPFor("Rank") 2208def _convert_rank(pfor_input): 2209 return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) 2210 2211 2212@RegisterPFor("AddN") 2213def _convert_addn(pfor_input): 2214 # AddN does not support broadcasting. 2215 pfor_input.stack_inputs() 2216 return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True) 2217 2218 2219@RegisterPFor("BiasAddGrad") 2220def _convert_biasaddgrad(pfor_input): 2221 grad = pfor_input.stacked_input(0) 2222 fmt = pfor_input.get_attr("data_format") 2223 if fmt == b"NCHW": 2224 output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False) 2225 else: 2226 grad_shape = array_ops.shape(grad) 2227 last_dim_shape = grad_shape[-1] 2228 first_dim_shape = grad_shape[0] 2229 output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape]) 2230 output = math_ops.reduce_sum(output, axis=[1], keepdims=False) 2231 return wrap(output, True) 2232 2233 2234# Some required ops are not exposed under the tf namespace. Hence relying on 2235# _create_op to create them. 2236@RegisterPForWithArgs("EluGrad") 2237@RegisterPForWithArgs("Relu6Grad") 2238@RegisterPForWithArgs("ReluGrad") 2239@RegisterPForWithArgs("SeluGrad") 2240@RegisterPForWithArgs("SigmoidGrad") 2241@RegisterPForWithArgs("SoftplusGrad") 2242@RegisterPForWithArgs("SoftsignGrad") 2243@RegisterPForWithArgs("TanhGrad") 2244@RegisterPForWithArgs("SqrtGrad") 2245@RegisterPForWithArgs("RsqrtGrad") 2246@RegisterPForWithArgs("ReciprocalGrad") 2247def _convert_grads(pfor_input, op_type, *args, **kw_args): 2248 del args 2249 del kw_args 2250 # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we 2251 # have to use tiling here. 2252 pfor_input.stack_inputs() 2253 outputs = _create_op( 2254 op_type, [x.t for x in pfor_input.inputs], 2255 [x.dtype for x in pfor_input.outputs], 2256 attrs=pfor_input.op.node_def.attr).outputs 2257 return [wrap(x, True) for x in outputs] 2258 2259 2260@RegisterPFor("Select") 2261def _convert_select(pfor_input): 2262 pfor_input.stack_inputs() 2263 cond = pfor_input.stacked_input(0) 2264 t = pfor_input.stacked_input(1) 2265 e = pfor_input.stacked_input(2) 2266 cond_rank = array_ops.rank(cond) 2267 cond, t, e = control_flow_ops.cond( 2268 cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), 2269 lambda: [cond, t, e]) 2270 outputs = _create_op( 2271 pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], 2272 attrs=pfor_input.op.node_def.attr).outputs 2273 n = pfor_input.pfor.loop_len_vector 2274 out = control_flow_ops.cond(cond_rank > 1, 2275 lambda: _unflatten_first_dim(outputs[0], n), 2276 lambda: outputs[0]) 2277 return [wrap(out, True) for x in outputs] 2278 2279 2280# random_ops 2281 2282 2283@RegisterPForWithArgs("RandomUniform") 2284@RegisterPForWithArgs("RandomUniformInt") 2285@RegisterPForWithArgs("RandomStandardNormal") 2286@RegisterPForWithArgs("TruncatedNormal") 2287@RegisterPForWithArgs("RandomGamma") 2288@RegisterPForWithArgs("RandomPoissonV2") 2289def _convert_random(pfor_input, op_type, *args, **kw_args): 2290 del args 2291 del kw_args 2292 inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] 2293 # inputs[0] is "shape" 2294 inputs[0] = array_ops.concat( 2295 [pfor_input.pfor.loop_len_vector, inputs[0]], axis=0) 2296 logging.warning( 2297 "Note that %s inside pfor op may not give same output as " 2298 "inside a sequential loop.", op_type) 2299 outputs = _create_op( 2300 op_type, 2301 inputs, [x.dtype for x in pfor_input.outputs], 2302 attrs=pfor_input.op.node_def.attr).outputs 2303 return [wrap(x, True) for x in outputs] 2304 2305 2306# logging_ops 2307 2308 2309@RegisterPFor("Assert") 2310def _convert_assert(pfor_input): 2311 cond, cond_stacked, _ = pfor_input.input(0) 2312 if cond_stacked: 2313 cond = math_ops.reduce_all(cond) 2314 2315 data_list = [x.t for x in pfor_input.inputs][1:] 2316 return _create_op("Assert", [cond] + data_list, [], 2317 attrs=pfor_input.op.node_def.attr) 2318 2319 2320@RegisterPFor("Print") 2321def _convert_print(pfor_input): 2322 # Note that we don't stack all the inputs. Hence unstacked values are printed 2323 # once here vs multiple times in a while_loop. 2324 pfor_input.stack_inputs([0]) 2325 outputs = _create_op( 2326 "Print", [x.t for x in pfor_input.inputs], 2327 [x.dtype for x in pfor_input.outputs], 2328 attrs=pfor_input.op.node_def.attr).outputs 2329 return [wrap(x, True) for x in outputs] 2330 2331 2332# data_flow_ops 2333 2334# TensorArray conversion is tricky since we don't support arrays of 2335# TensorArrays. For converting them, we consider two distinct cases: 2336# 2337# 1. The array is constructed outside the pfor call, and read/written inside the 2338# loop. 2339# This is an easier case since we don't need to make an array of TensorArrays. 2340# A correctness requirement is that these parallel iterations shouldn't attempt 2341# to write to the same location. Hence at conversion time we disallow indices to 2342# be loop-invariant as that would guarantee a collision. Even if the indices are 2343# not loop-invariant, they could conflict and that shall trigger runtime errors. 2344# 2345# 2. The array is constructed and used entirely inside each pfor iteration. 2346# For simplicity, here we require that the indices used for write/scatter are 2347# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in 2348# different pfor iterations. We consider two sub_cases: 2349# 2350# 2a Elements written to the array are "stacked" 2351# To simulate multiple TensorArrays, we may increase the dimension of each 2352# element of the array. i.e. the i_th row of the j_th entry of the converted 2353# TensorArray corresponds to the j_th entry of the TensorArray in the i_th 2354# pfor iteration. 2355# 2356# 2b Elements written to the array are "unstacked" 2357# In this case we don't increase the dimensions to avoid redundant tiling. Each 2358# iteration is trying to write the same value. So we convert that to a single 2359# write. 2360# 2361# Here are some tricks used to implement the above: 2362# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of 2363# trying to trace whether future writes are stacked or unstacked in order to set 2364# this attr, we set it to correspond to unknown shape. 2365# - We use the "flow" output of the different ops to track whether the array 2366# elements are stacked or unstacked. If a stacked write/scatter is done, we make 2367# the flow stacked as well. 2368# - We use some heuristic traversal of the graph to track whether the 2369# TensorArray handle was created inside or outside the pfor loop. 2370 2371 2372@RegisterPFor("TensorArrayV3") 2373def _convert_tensor_array_v3(pfor_input): 2374 size = pfor_input.unstacked_input(0) 2375 dtype = pfor_input.get_attr("dtype") 2376 dynamic_size = pfor_input.get_attr("dynamic_size") 2377 clear_after_read = pfor_input.get_attr("clear_after_read") 2378 identical_element_shapes = pfor_input.get_attr("identical_element_shapes") 2379 tensor_array_name = pfor_input.get_attr("tensor_array_name") 2380 handle, flow = data_flow_ops.tensor_array_v3( 2381 size, 2382 dtype=dtype, 2383 # We don't set element shape since we don't know if writes are stacked or 2384 # not yet. 2385 element_shape=None, 2386 dynamic_size=dynamic_size, 2387 clear_after_read=clear_after_read, 2388 identical_element_shapes=identical_element_shapes, 2389 tensor_array_name=tensor_array_name) 2390 # Note we keep flow unstacked for now since we don't know if writes will be 2391 # stacked or not. 2392 return wrap(handle, False), wrap(flow, False) 2393 2394 2395@RegisterPFor("TensorArraySizeV3") 2396def _convert_tensor_array_size_v3(pfor_input): 2397 handle = pfor_input.unstacked_input(0) 2398 flow, flow_stacked, _ = pfor_input.input(1) 2399 if flow_stacked: 2400 flow = _unstack_flow(flow) 2401 size = data_flow_ops.tensor_array_size_v3(handle, flow) 2402 return wrap(size, False) 2403 2404 2405def _handle_inside_pfor(pfor_input, handle): 2406 """Returns True if handle was created inside the pfor loop.""" 2407 # We use some heuristic to find the original TensorArray creation op. 2408 # The logic should handle the common cases (except cond based subgraphs). 2409 # In theory the user could perform different operations on the handle (like 2410 # Reshape, stack multiple handles, etc) which could break this logic. 2411 # TODO(agarwal): handle Switch/Merge. 2412 while handle.op.type in ("Enter", "Identity"): 2413 handle = handle.op.inputs[0] 2414 if handle.op.type not in [ 2415 "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"]: 2416 raise ValueError("Unable to find source for handle %s" % handle) 2417 else: 2418 return pfor_input.pfor.op_is_inside_loop(handle.op) 2419 2420 2421def _unstack_flow(value): 2422 # TODO(agarwal): consider looking if this is a Tile op then get its input. 2423 # This may avoid running the Tile operations. 2424 return array_ops.gather(value, 0) 2425 2426 2427@RegisterPFor("TensorArrayReadV3") 2428def _convert_tensor_array_read_v3(pfor_input): 2429 handle = pfor_input.unstacked_input(0) 2430 index, index_stacked, _ = pfor_input.input(1) 2431 dtype = pfor_input.get_attr("dtype") 2432 flow, flow_stacked, _ = pfor_input.input(2) 2433 if flow_stacked: 2434 flow = _unstack_flow(flow) 2435 2436 is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 2437 if is_inside_pfor: 2438 # Note that if we are inside a control flow construct inside the pfor, and 2439 # only some of the iterations are doing the read (i.e. 2440 # `all_indices_partitioned` is True), then the read operation should only 2441 # return values for the currently active pfor iterations (`all_indices` 2442 # below). Hence, whenever the returned value is stacked (i.e. `flow` is 2443 # stacked), we may need to do an extra gather after reading the values. Also 2444 # note that if `is_inside` is false, then values in the tensor array are 2445 # unstacked. So the check is only needed in this branch. 2446 all_indices = pfor_input.pfor.all_indices 2447 all_indices_partitioned = pfor_input.pfor.all_indices_partitioned 2448 # Note: flow_stacked indicates if values in the TensorArray are stacked or 2449 # not. 2450 if index_stacked: 2451 if flow_stacked: 2452 raise ValueError( 2453 "It looks like TensorArrayReadV3 was called on a TensorArray whose" 2454 " values are not loop-invariant, and the read indices were also" 2455 " not loop invariant. This is currently unsupported.") 2456 value = data_flow_ops.tensor_array_gather_v3( 2457 handle, index, flow, dtype=dtype) 2458 return wrap(value, True) 2459 value = data_flow_ops.tensor_array_read_v3( 2460 handle, index, flow, dtype=dtype) 2461 if flow_stacked and all_indices_partitioned: 2462 value = array_ops.gather(value, all_indices) 2463 return wrap(value, flow_stacked) 2464 # Values in the TensorArray should be unstacked (since different iterations 2465 # couldn't write to the same location). So whether output is stacked or not 2466 # depends on index_stacked. 2467 if index_stacked: 2468 value = data_flow_ops.tensor_array_gather_v3( 2469 handle, index, flow, dtype=dtype) 2470 else: 2471 value = data_flow_ops.tensor_array_read_v3( 2472 handle, index, flow, dtype=dtype) 2473 return wrap(value, index_stacked) 2474 2475 2476@RegisterPFor("TensorArrayWriteV3") 2477def _convert_tensor_array_write_v3(pfor_input): 2478 handle = pfor_input.unstacked_input(0) 2479 index, index_stacked, _ = pfor_input.input(1) 2480 value, value_stacked, _ = pfor_input.input(2) 2481 flow, flow_stacked, _ = pfor_input.input(3) 2482 if value_stacked and pfor_input.pfor.all_indices_partitioned: 2483 # Looks like we are in a control flow in a pfor where not all iterations are 2484 # active now. We don't allow that since that could lead to different indices 2485 # having different shapes which will be hard to merge later. 2486 raise ValueError("Writing non loop invariant values to TensorArray from " 2487 "inside a while_loop/cond not supported.") 2488 if flow_stacked: 2489 flow = _unstack_flow(flow) 2490 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 2491 if is_inside: 2492 if index_stacked: 2493 raise ValueError("Need indices for %s to be loop invariant" % handle) 2494 if not flow_stacked and not value_stacked: 2495 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 2496 return wrap(flow_out, False) 2497 else: 2498 if not value_stacked: 2499 value = _stack(value, pfor_input.pfor.loop_len_vector).t 2500 # TODO(agarwal): Note that if flow is unstacked and value is stacked, then 2501 # this may or may not be a safe situation. flow is unstacked both for a 2502 # freshly created TensorArray, as well as after unstacked values are 2503 # written to it. If it is the latter, then we cannot write a stacked value 2504 # now since that may cause runtime errors due to different shapes in the 2505 # array. At the moment we are not able to handle this gracefully and 2506 # distinguish between the two cases. That would require some heuristic 2507 # traversal of the graph to figure out whether all the writes are 2508 # unstacked or not. 2509 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 2510 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 2511 else: 2512 if not index_stacked: 2513 raise ValueError("Need indices for %s to be not loop invariant" % handle) 2514 # Note that even when index_stacked is true, actual values in index may 2515 # still not be unique. However that will cause runtime error when executing 2516 # the scatter operation below. 2517 if not value_stacked: 2518 value = _stack(value, pfor_input.pfor.loop_len_vector).t 2519 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow) 2520 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 2521 2522 2523def _transpose_first_two_dims(value): 2524 # TODO(agarwal): optimize if one of the dims == 1. 2525 value_shape = array_ops.shape(value) 2526 v0 = value_shape[0] 2527 v1 = value_shape[1] 2528 value = array_ops.reshape(value, [v0, v1, -1]) 2529 value = array_ops.transpose(value, [1, 0, 2]) 2530 new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0) 2531 return array_ops.reshape(value, new_shape) 2532 2533 2534@RegisterPFor("TensorArrayGatherV3") 2535def _convert_tensor_array_gather_v3(pfor_input): 2536 handle = pfor_input.unstacked_input(0) 2537 indices, indices_stacked, _ = pfor_input.input(1) 2538 indices = array_ops.reshape(indices, [-1]) 2539 flow, flow_stacked, _ = pfor_input.input(2) 2540 if flow_stacked: 2541 flow = _unstack_flow(flow) 2542 dtype = pfor_input.get_attr("dtype") 2543 # TODO(agarwal): support element_shape attr? 2544 2545 n = pfor_input.pfor.loop_len_vector 2546 value = data_flow_ops.tensor_array_gather_v3( 2547 handle, indices, flow, dtype=dtype) 2548 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 2549 if is_inside: 2550 # flow_stacked indicates if values in the TensorArray are stacked or not. 2551 if indices_stacked: 2552 if flow_stacked: 2553 raise ValueError( 2554 "It looks like TensorArrayGatherV3 was called on a TensorArray " 2555 "whose values are not loop-invariant, and the indices were also " 2556 "not loop invariant. This is currently unsupported.") 2557 else: 2558 value = _unflatten_first_dim(value, n) 2559 return wrap(value, True) 2560 else: 2561 if flow_stacked: 2562 # Since elements in this array are stacked and `value` was produced by 2563 # gather, its first two dims are "gathered elements" and "stack 2564 # dimension". Our semantics require these two to be flipped. 2565 value = _transpose_first_two_dims(value) 2566 return wrap(value, flow_stacked) 2567 else: 2568 # Values in the TensorArray should be unstacked (since different iterations 2569 # couldn't write to the same location). So whether output is stacked or not 2570 # depends on indices_stacked. 2571 if indices_stacked: 2572 value = _unflatten_first_dim(value, n) 2573 return wrap(value, indices_stacked) 2574 2575 2576@RegisterPFor("TensorArrayScatterV3") 2577def _convert_tensor_array_scatter_v3(pfor_input): 2578 handle = pfor_input.unstacked_input(0) 2579 indices, indices_stacked, _ = pfor_input.input(1) 2580 indices = array_ops.reshape(indices, [-1]) 2581 value, value_stacked, _ = pfor_input.input(2) 2582 flow, flow_stacked, _ = pfor_input.input(3) 2583 2584 if flow_stacked: 2585 flow = _unstack_flow(flow) 2586 2587 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 2588 if is_inside: 2589 if indices_stacked: 2590 raise ValueError("Need indices for %s to be loop invariant" % handle) 2591 # Note that flow_stacked indicates if existing values in the array are 2592 # stacked or not. 2593 if not flow_stacked and not value_stacked: 2594 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 2595 flow) 2596 return wrap(flow_out, False) 2597 if not value_stacked: 2598 # TODO(agarwal): tile in the second dimension directly instead of 2599 # transposing below. 2600 value = _stack(value, pfor_input.pfor.loop_len_vector).t 2601 2602 value = _transpose_first_two_dims(value) 2603 # TODO(agarwal): Note that if a previous write was unstacked, flow will be 2604 # unstacked, and a stacked value may be written here which may cause 2605 # runtime error due to different elements having different shape. We do 2606 # not try to prevent that. 2607 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 2608 flow) 2609 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 2610 if not indices_stacked: 2611 raise ValueError("Need indices for %s to be not loop invariant" % handle) 2612 if not value_stacked: 2613 value = _stack(value, pfor_input.pfor.loop_len_vector).t 2614 value = _flatten_first_two_dims(value) 2615 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 2616 flow) 2617 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 2618 2619 2620@RegisterPFor("TensorArrayGradV3") 2621def _convert_tensor_array_grad_v3(pfor_input): 2622 handle = pfor_input.unstacked_input(0) 2623 flow, flow_stacked, _ = pfor_input.input(1) 2624 if flow_stacked: 2625 flow = _unstack_flow(flow) 2626 source = pfor_input.get_attr("source") 2627 # TODO(agarwal): For now, we assume that gradients are stacked if the 2628 # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong 2629 # will give runtime error due to incorrect shape being written to the 2630 # accumulator. It is difficult to know in advance if gradients written will be 2631 # stacked or not. Note that flow being stacked is not indicative of the 2632 # gradient being stacked or not. Revisit this later. 2633 shape_to_prepend = pfor_input.pfor.loop_len_vector 2634 grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape( 2635 handle=handle, 2636 flow_in=flow, 2637 shape_to_prepend=shape_to_prepend, 2638 source=source) 2639 flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t 2640 return [wrap(grad_handle, False), wrap(flow_out, True)] 2641 2642 2643# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar 2644# to TensorArrays, we convert them by changing the dimension of the elements 2645# inside the stack. 2646# 2647# We consider two cases: 2648# 2649# 1. StackV2 is constructed and used entirely inside the pfor loop. 2650# We keep a single Stack and perform the push/pop operations of all the 2651# iterations in lock-step. We also assume that all the iterations perform these 2652# operations. In case of dynamic control flow, if only some of the iterations 2653# try to perform a push/pop, then the conversion may not work correctly and may 2654# cause undefined behavior. 2655# TODO(agarwal): test StackV2 with dynamic control flow. 2656# 2657# 2. StackV2 is constructed outside the pfor loop. 2658# Performing stack push/pop in a parallel fashion is ill-defined. However given 2659# that reading stacks created externally is a common operation when computing 2660# jacobians, we provide some special semantics here as follows. 2661# - disallow push operations to the stack 2662# - pop operations are performed in lock step by all iterations, similar to the 2663# case when the stack is created inside. A single value is popped during the 2664# lock-step operation and broadcast to all the iterations. Values in the stack 2665# are assumed to be loop-invariant. 2666# 2667# Some other implementation details: 2668# We use an ugly logic to find whether values in Stack data structure are 2669# loop invariant or not. When converting push/pop operations, we keep track of 2670# whether the last conversion used a stacked value or not (see _stack_cache 2671# below). As a result if an unstacked value is written first, subsequent stacked 2672# writes are disallowed when they could have been allowed in theory. 2673 2674# Map from cache key based on StackV2 handle to a bool indicating whether values 2675# are stacked or not. 2676# TODO(agarwal): move _stack_cache inside pfor? 2677_stack_cache = {} 2678 2679 2680def _stack_cache_key(pfor_input): 2681 """Create cache key corresponding to a stack handle.""" 2682 op_type = pfor_input.op_type 2683 assert op_type in ["StackPushV2", "StackPopV2"], op_type 2684 orig_handle = pfor_input.op.inputs[0] 2685 while orig_handle.op.type in ["Identity", "Enter"]: 2686 orig_handle = orig_handle.op.inputs[0] 2687 assert orig_handle.op.type == "StackV2", orig_handle.op 2688 return ops.get_default_graph(), pfor_input.pfor, orig_handle 2689 2690 2691def _stack_handle_inside_pfor(handle, pfor_input): 2692 while handle.op.type in ["Identity", "Enter"]: 2693 handle = handle.op.inputs[0] 2694 assert handle.op.type == "StackV2", ( 2695 "Unable to find StackV2 op. Got %s" % handle.op) 2696 return pfor_input.pfor.op_is_inside_loop(handle.op) 2697 2698 2699@RegisterPFor("StackPushV2") 2700def _convert_stack_push_v2(pfor_input): 2701 handle = pfor_input.unstacked_input(0) 2702 elem, elem_stacked, _ = pfor_input.input(1) 2703 swap_memory = pfor_input.get_attr("swap_memory") 2704 2705 if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input): 2706 raise ValueError("StackPushV2 not allowed on stacks created outside pfor") 2707 stack_cache_key = _stack_cache_key(pfor_input) 2708 stacked = _stack_cache.get(stack_cache_key, None) 2709 if stacked is None: 2710 stacked = elem_stacked 2711 _stack_cache[stack_cache_key] = stacked 2712 else: 2713 # If we previously made it unstacked then we can't revert to being stacked. 2714 if not stacked and elem_stacked: 2715 raise ValueError( 2716 "It looks like the stack was previously determined to be loop" 2717 " invariant, but we are now trying to push a loop dependent value" 2718 " to it. This is currently unsupported.") 2719 if stacked and not elem_stacked: 2720 elem = _stack(elem, pfor_input.pfor.loop_len_vector).t 2721 out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory) 2722 return wrap(out, stacked) 2723 2724 2725# Note that inputs to this convertor will be unstacked. However it should get 2726# called since it is a stateful op. 2727@RegisterPFor("StackPopV2") 2728def _convert_stack_pop_v2(pfor_input): 2729 handle = pfor_input.unstacked_input(0) 2730 stack_cache_key = _stack_cache_key(pfor_input) 2731 stacked = _stack_cache.get(stack_cache_key, None) 2732 # If a StackPushV2 has not been converted yet, we default to unstacked since 2733 # the push could be outside of pfor, or the covertor may not be called if the 2734 # inputs are unconverted. 2735 if stacked is None: 2736 stacked = False 2737 _stack_cache[stack_cache_key] = False 2738 elem_type = pfor_input.get_attr("elem_type") 2739 out = data_flow_ops.stack_pop_v2(handle, elem_type) 2740 return wrap(out, stacked) 2741 2742 2743# parsing_ops 2744 2745 2746@RegisterPFor("DecodeCSV") 2747def _convert_decode_csv(pfor_input): 2748 lines = pfor_input.stacked_input(0) 2749 record_defaults = [ 2750 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 2751 ] 2752 field_delim = pfor_input.get_attr("field_delim") 2753 use_quote_delim = pfor_input.get_attr("use_quote_delim") 2754 select_cols = pfor_input.get_attr("select_cols") 2755 if not select_cols: 2756 select_cols = None 2757 return [ 2758 wrap(t, True) for t in parsing_ops.decode_csv( 2759 lines, 2760 record_defaults, 2761 field_delim=field_delim, 2762 use_quote_delim=use_quote_delim, 2763 select_cols=select_cols) 2764 ] 2765 2766 2767@RegisterPFor("ParseSingleExample") 2768def _convert_parse_single_example(pfor_input): 2769 serialized = pfor_input.stacked_input(0) 2770 dense_defaults = [ 2771 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 2772 ] 2773 sparse_keys = pfor_input.get_attr("sparse_keys") 2774 dense_keys = pfor_input.get_attr("dense_keys") 2775 sparse_types = pfor_input.get_attr("sparse_types") 2776 dense_shapes = pfor_input.get_attr("dense_shapes") 2777 output = gen_parsing_ops.parse_example( 2778 serialized=serialized, 2779 names=[], 2780 dense_defaults=dense_defaults, 2781 sparse_keys=sparse_keys, 2782 dense_keys=dense_keys, 2783 sparse_types=sparse_types, 2784 dense_shapes=dense_shapes) 2785 return [wrap(t, True, True) for t in nest.flatten(output)] 2786