1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Compiled parallel-for loop."""
16# pylint: disable=missing-docstring,g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23
24from tensorflow.python.eager import context
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import bitwise_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import data_flow_ops
36from tensorflow.python.ops import gen_parsing_ops
37from tensorflow.python.ops import gen_sparse_ops
38from tensorflow.python.ops import map_fn
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import nn_ops
41from tensorflow.python.ops import parsing_ops
42from tensorflow.python.ops import sparse_ops
43from tensorflow.python.ops import tensor_array_ops
44from tensorflow.python.platform import flags
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.python.util import compat
47from tensorflow.python.util import nest
48
49flags.DEFINE_bool(
50    "op_conversion_fallback_to_while_loop", False,
51    "If true, falls back to using a while loop for ops for "
52    "which a converter is not defined.")
53
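# A minimal sketch of the transformation this module implements, assuming the
# high-level `pfor` wrapper in parallel_for/control_flow_ops.py (illustrative
# only, not executed here):
#
#   from tensorflow.python.ops.parallel_for.control_flow_ops import pfor
#   x = constant_op.constant([[1., 2.], [3., 4.], [5., 6.]])
#   # Conceptually equivalent to stacking loop_fn(i) for i in range(3), but
#   # implemented by rewriting the loop body into vectorized ops.
#   y = pfor(lambda i: 2. * array_ops.gather(x, i), 3)
#   # y should equal 2. * x, with shape [3, 2].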
54
55def _stack(t, length):
56  """stacks `t` `length` times."""
57  ones = array_ops.ones_like(array_ops.shape(t))
58  multiples = array_ops.concat([length, ones], 0)
59  t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
60  return wrap(t, True)
61
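# For example (illustrative only), if `t` has shape [2, 3] and `length` is the
# length-1 vector [4], then `_stack(t, [4])` returns a WrappedTensor whose `t`
# has shape [4, 2, 3], i.e. `t` tiled along a new leading dimension:
#
#   t = array_ops.ones([2, 3])
#   stacked = _stack(t, constant_op.constant([4]))
#   # stacked.t has shape [4, 2, 3]; stacked.is_stacked is True.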
62
63# The following stateful ops can be safely called once, and with the same
64# signature as the unconverted version, if their inputs are loop invariant.
65# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
66# plan is to map each read/write in the loop_fn to a corresponding merged
67# read/write in the converted graph. Writes need to be mergeable (e.g.
68# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
69# loop_fn, doing a one-to-one conversion will simulate executing such
70# instructions in lock-step across all iterations.
71passthrough_stateful_ops = set([
72    "VariableV2",
73    "VarHandleOp",
74    "ReadVariableOp",
75    "StackV2",
76    "TensorArrayWriteV3",
77    "TensorArrayReadV3",
78    "TensorArraySizeV3",
79])
80
81
82def _is_stateful_pfor_op(op):
83  if isinstance(op, WhileOp):
84    return op.is_stateful
85  if op.type == "Const":
    # Const doesn't have an op_def.
87    return False
88  if op.type in passthrough_stateful_ops:
89    return False
90  assert hasattr(op, "op_def") and op.op_def is not None, op
91  return op.op_def.is_stateful
92
93
94# pylint: disable=protected-access
95class WhileOp(object):
96  """Object for storing state for converting the outputs of a while_loop."""
97
98  def __init__(self, exit_node, pfor_ops, pfor_config):
99    """Initializer.
100
101    Args:
102      exit_node: A tensor output from the while_loop.
103      pfor_ops: list of ops inside the current pfor loop.
104      pfor_config: PForConfig object used while constructing loop body.
105    """
106    self._pfor_config = pfor_config
107    self._pfor_ops = set(pfor_ops)
108    self._pfor_op_ids = set([x._id for x in pfor_ops])
109    assert isinstance(exit_node, ops.Tensor)
110    self._while_context = exit_node.op._get_control_flow_context()
111    assert isinstance(self._while_context, control_flow_ops.WhileContext)
112    self._context_name = self._while_context.name
113    self._condition = self._while_context.pivot.op.inputs[0]
114    # Parts of an external while_loop could be created inside a pfor loop.
115    # However for the purpose here, we declare such loops to be external. Also
116    # note that we check if the condition was created inside or outside to
117    # determine if the while_loop was first created inside or outside.
118    # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
119    self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
120    if self._is_inside_loop:
121      for e in self._while_context.loop_exits:
122        assert self.op_is_inside_loop(e.op)
123
124    # Note the code below tries to reverse engineer an existing while_loop graph
125    # by assuming the following pattern of nodes.
126    #
127    #          NextIteration <---- Body <--- Enter
128    #              |                ^
129    #              V             ___| Y
130    #    Enter -> Merge -> Switch___
131    #                       ^       | N
132    #                       |       V
133    #                  LoopCond    Exit
134
    # Note that elements in the lists below correspond one-to-one with each
    # other, i.e. these lists are the same size, and the i_th entries correspond
    # to different Operations/Tensors of a single cycle as illustrated above.
138    # List of Switch ops (ops.Operation) that feed into an Exit Node.
139    self._exit_switches = []
140    # List of inputs (ops.Tensor) to NextIteration.
141    self._body_outputs = []
142    # List of list of control inputs of the NextIteration nodes.
143    self._next_iter_control_inputs = []
144    # List of Merge ops (ops.Operation).
145    self._enter_merges = []
146    # List of output (ops.Tensor) of Exit nodes.
147    self._outputs = []
148
149    # List of Enter Tensors.
150    # There are two types of Enter nodes:
151    # - The Enter nodes that are used in the `loop_vars` argument to
152    # `while_loop` (see
153    # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
154    # these Enter nodes immediately below by tracing backwards from the Exit
155    # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
156    # diagram above. This allows us to have a 1:1 correspondence between the
157    # self._outputs and the first elements in self._enters.
158    # - The Enter nodes that are used only by the body. They don't appear in the
159    # `loop_vars` and are not returned from the `while_loop`. In Python code,
160    # they are usually captured by the body lambda. We collect them below by
161    # iterating over all the ops in the graph. They are appended to the end of
162    # self._enters or self._direct_enters, and don't correspond to any outputs
163    # in self._outputs. Note that we keep the resource/variant Enter nodes in
164    # self._direct_enters and the constructed while_loop's body uses them
165    # directly as opposed to passing them as loop variables. This is done
166    # because the while_body cannot partition the resource/variant Tensors, so
167    # it has to leave them unchanged.
168    self._enters = []
169    self._direct_enters = []
170
171    for e in self._while_context.loop_exits:
172      self._outputs.append(e.op.outputs[0])
173      switch = e.op.inputs[0].op
174      assert switch.type == "Switch", switch
175      self._exit_switches.append(switch)
176      merge = switch.inputs[0].op
177      assert merge.type == "Merge", merge
178      self._enter_merges.append(merge)
179      enter = merge.inputs[0].op
180      assert enter.type == "Enter", enter
181      self._enters.append(enter.outputs[0])
182      next_iter = merge.inputs[1].op
183      assert next_iter.type == "NextIteration", next_iter
184      self._body_outputs.append(next_iter.inputs[0])
185      self._next_iter_control_inputs.append(next_iter.control_inputs)
186
187    # Collect all the Enter nodes that are not part of `loop_vars`, the second
188    # category described above.
189    # Also track whether the loop body has any stateful ops.
190    self._is_stateful = False
191    for op in ops.get_default_graph().get_operations():
192      # TODO(agarwal): make sure this works with nested case.
193      control_flow_context = op._get_control_flow_context()
194      if control_flow_context is None:
195        continue
196      if control_flow_context.name == self._context_name:
197        self._is_stateful |= _is_stateful_pfor_op(op)
198        if op.type == "Enter":
199          output = op.outputs[0]
200          if output not in self._enters:
201            if output.dtype in (dtypes.resource, dtypes.variant):
202              if output not in self._direct_enters:
203                self._direct_enters.append(output)
204            else:
205              self._enters.append(output)
206
207  def __str__(self):
208    """String representation."""
209    return "while_loop(%s)" % self.name
210
211  @property
212  def inputs(self):
213    """Input to all the Enter nodes."""
214    return [x.op.inputs[0] for x in self._enters + self._direct_enters]
215
216  @property
217  def control_inputs(self):
218    """Control input to all the Enter nodes."""
219    control_inputs = []
220    for x in self._enters + self._direct_enters:
221      control_inputs.extend(x.op.control_inputs)
222    return control_inputs
223
224  @property
225  def outputs(self):
226    """Outputs of all the Exit nodes."""
227    return self._outputs
228
229  @property
230  def name(self):
231    """Context name for the while loop."""
232    return self._context_name
233
234  @property
235  def is_inside_loop(self):
236    """Returns true if the while_loop was created inside the pfor."""
237    return self._is_inside_loop
238
239  def op_is_inside_loop(self, op):
240    """True if op was created inside the pfor loop body."""
241    assert isinstance(op, ops.Operation)
242    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different Python
244    # objects representing the same Operation node.
245    return op._id in self._pfor_op_ids
246
247  @property
248  def is_stateful(self):
249    return self._is_stateful
250
251  @property
252  def pfor_converter(self):
253    """Return a converter for the while loop."""
254    return self
255
256  def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
257                 inputs_stacked):
258    """Create a PFor object for converting parts of the while_loop.
259
260    Args:
261      parent_pfor: PFor object being used for converting the while_loop.
262      indices: int32 Tensor of ids for the iterations that are still active
263        (i.e. did not exit the while_loop).
264      cond_stacked: True if the while_loop condition is stacked.
265      inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
266        that these Tensors are a subset of the loop variables for the generated
267        while_loop.
268      inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
269        indicating if the value is stacked or not.
270
271    Returns:
272      A PFor instance. The instance is initialized by adding conversion mappings
273        of nodes that will be external to the conversion that the returned
274        instance will be used for. e.g. Enter nodes as well as Merge and Switch
275        outputs are mapped to converted values.
276    """
277    num_outputs = len(self._outputs)
278    assert len(inputs) == len(self._enters)
279    assert len(inputs_stacked) == len(self._enters)
280    loop_var = parent_pfor.loop_var
281    loop_len = array_ops.size(indices)
282    pfor = PFor(
283        loop_var,
284        loop_len,
285        pfor_ops=self._pfor_ops,
286        all_indices=indices,
287        all_indices_partitioned=cond_stacked,
288        pfor_config=self._pfor_config)
289    # Map all inputs of Enter nodes in self._direct_enters to their converted
290    # values.
291    for enter in self._direct_enters:
292      enter_input = enter.op.inputs[0]
293      converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
294          enter_input)
295      # Since these are resources / variants, they should be unstacked.
296      assert not stacked and not is_sparse_stacked, (enter, converted_enter)
297      pfor._add_conversion(enter, wrap(converted_enter, False))
298
299    # Map all Enter nodes to the inputs.
300    for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
301      pfor._add_conversion(enter, wrap(inp, stacked))
302    # Map outputs of Switch and Merge.
303    for i in range(num_outputs):
304      wrapped_inp = wrap(inputs[i], inputs_stacked[i])
305      merge = self._enter_merges[i]
306      pfor._add_conversion(merge.outputs[0], wrapped_inp)
      # Note that the second output of Merge is typically not used, except
      # possibly as a control dependency. To avoid trying to output the correct
      # value, we employ a hack here: we output a dummy invalid value with an
      # incorrect dtype. This allows control dependencies to work, but using it
      # as an input should typically lead to errors during graph construction
      # due to a dtype mismatch.
313      # TODO(agarwal): Check in the original graph to see if there are any
314      # consumers of this Tensor that use it as an input.
315      pfor._add_conversion(merge.outputs[1],
316                           wrap(constant_op.constant(-1.0), False))
317      switch = self._exit_switches[i]
318      # Don't need to worry about switch.output[0] which will feed to Exit node.
319      pfor._add_conversion(switch.outputs[1], wrapped_inp)
320    return pfor
321
322  def _convert_enter(self, parent_pfor, enter):
323    """Converts an Enter node."""
324    inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
325    control_inputs = [
326        parent_pfor._convert_helper(x).t for x in enter.op.control_inputs
327    ]
328    if control_inputs:
329      with ops.control_dependencies(control_inputs):
330        inp = array_ops.identity(inp)
331    return inp, stacked
332
333  def _maybe_stacked(self, cache, inp):
334    """Heuristic to figue out if the coverting inp leads to a stacked value.
335
336
337    Args:
338      cache: map from Tensor to boolean indicating stacked/unstacked.
339      inp: input Tensor.
340
341    Returns:
342      True if `inp` could get stacked. If the function returns False, the
343      converted value should be guaranteed to be unstacked. If returning True,
344      it may or may not be stacked.
345    """
346    if inp in cache:
347      return cache[inp]
348    if not self.op_is_inside_loop(inp.op):
349      return False
350    op = inp.op
351    output = False
352    if op.type in [
353        "Shape",
354        "Rank"
355        "ShapeN",
356        "ZerosLike",
357        "TensorArrayV3",
358        "TensorArraySizeV3",
359    ]:
360      output = False
361    elif _is_stateful_pfor_op(op):
362      # This may be fairly aggressive.
363      output = True
364    elif op.type == "Exit":
365      # This may be fairly aggressive.
366      output = True
367    else:
368      for t in op.inputs:
369        if self._maybe_stacked(cache, t):
370          output = True
371          break
372    cache[inp] = output
373    return output
374
375  def _create_init_values(self, pfor_input):
376    """Create arguments passed to converted while_loop."""
377    with ops.name_scope("while_init"):
378      loop_len_vector = pfor_input.pfor.loop_len_vector
379      loop_len = loop_len_vector[0]
380      num_outputs = len(self._outputs)
381
382      inputs = []
383      maybe_stacked_cache = {}
384      # Convert all the Enters. Need to do this before checking for stacking
385      # below.
386      for i, enter in enumerate(self._enters):
387        inp, stacked = self._convert_enter(pfor_input.pfor, enter)
388        inputs.append(inp)
389        maybe_stacked_cache[enter] = stacked
390        # Since this enter node is part of the `loop_vars`, it corresponds to an
        # output and its preceding switch. We mark this switch's output with the
        # same stackness, to act as the base case for the logic below, where we
        # go through the body figuring out which inputs might need to be
394        # stacked and which inputs can safely remain unstacked.
395        if i < num_outputs:
396          maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked
397
398      # Shape invariants for init_values corresponding to self._enters.
399      input_shape_invariants = []
400      # TensorArrays for outputs of converted while loop
401      output_tas = []
402      # Shape invariants for output TensorArrays.
403      ta_shape_invariants = []
404      # List of booleans indicating stackness of inputs, i.e. tensors
405      # corresponding to self._enters.
406      inputs_stacked = []
407      for i, inp in enumerate(inputs):
408        enter = self._enters[i]
409        inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
410        # Note that even when an input is unstacked, the body could make it
        # stacked. We use a heuristic below to figure out if the body may be
        # making it stacked.
413        if i < num_outputs:
414          body_output = self._body_outputs[i]
415          if enter.op in self._pfor_ops:
416            body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
417                                                      body_output)
418          else:
419            # If constructed outside of pfor loop, then the output would not be
420            # stacked.
421            body_output_stacked = False
422          if body_output_stacked and not inp_stacked:
423            inp = _stack(inp, loop_len_vector).t
424            inputs[i] = inp
425            inp_stacked = True
426          # TODO(agarwal): other attributes for the TensorArray ?
427          output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
428          ta_shape_invariants.append(tensor_shape.TensorShape(None))
429
430        inputs_stacked.append(inp_stacked)
431        input_shape_invariants.append(tensor_shape.TensorShape(None))
432
433      # See documentation for __call__ for the structure of init_values.
434      init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
435      # TODO(agarwal): try stricter shape invariants
436      shape_invariants = (
437          [tensor_shape.TensorShape(None),
438           tensor_shape.TensorShape(None)
439          ] + input_shape_invariants + ta_shape_invariants)
440
441      return init_values, inputs_stacked, shape_invariants
442
443  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
444    """Handles case when condition is unstacked.
445
446    Note that all iterations end together. So we don't need to partition the
447    inputs. When all iterations are done, we write the inputs to the
448    TensorArrays. Note that we only write to index 0 of output_tas. Since all
449    iterations end together, they can all be output together.
450    """
451    not_all_done = array_ops.reshape(conditions, [])
452    new_output_tas = []
453    # pylint: disable=cell-var-from-loop
454    for i, out_ta in enumerate(output_tas):
455      inp = inputs[i]
456      new_output_tas.append(
457          control_flow_ops.cond(not_all_done,
458                                lambda: out_ta,
459                                lambda: out_ta.write(0, inp)))
460    # pylint: enable=cell-var-from-loop
461    return not_all_done, indices, inputs, new_output_tas
462
463  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
464                            output_tas):
465    num_outputs = len(self._outputs)
466    # Compute if all iterations are done.
467    not_all_done = math_ops.reduce_any(conditions)
468    conditions_int = math_ops.cast(conditions, dtypes.int32)
469    # Partition the indices.
470    done_indices, new_indices = data_flow_ops.dynamic_partition(
471        indices, conditions_int, 2)
472
473    new_inputs = []
474    new_output_tas = []
475    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
476      # Partition the inputs.
477      if stacked:
478        done_inp, new_inp = data_flow_ops.dynamic_partition(
479            inp, conditions_int, 2)
480      else:
481        # TODO(agarwal): avoid this stacking. See TODO earlier in
482        # _process_cond_unstacked.
483        done_inp = _stack(inp, [array_ops.size(done_indices)]).t
484        new_inp = inp
485      new_inputs.append(new_inp)
486      # For iterations that are done, write them to TensorArrays.
487      if i < num_outputs:
488        out_ta = output_tas[i]
489        # Note that done_indices can be empty. done_inp should also be empty in
490        # that case.
491        new_output_tas.append(out_ta.scatter(done_indices, done_inp))
492    return not_all_done, new_indices, new_inputs, new_output_tas
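  # Illustrative example of the partitioning above (not executed here): with
  # indices = [0, 1, 2] and conditions = [True, False, True], i.e. iteration 1
  # is done while iterations 0 and 2 are still active,
  #
  #   conditions_int = [1, 0, 1]
  #   done_indices, new_indices = data_flow_ops.dynamic_partition(
  #       indices, conditions_int, 2)
  #   # done_indices == [1], new_indices == [0, 2]
  #
  # Stacked inputs are partitioned the same way, so the body only sees rows
  # corresponding to the still-active iterations.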
493
494  def _process_body(self, pfor_input, inputs_stacked,
495                    new_indices, cond_stacked, new_inputs,
496                    not_all_done):
497    """Convert the body function."""
498
499    def true_fn(control_inputs, body_pfor, body_output, stacked):
500      """Converts the body function for all but last iteration.
501
502      This essentially converts body_output. Additionally, it needs to handle
503      any control dependencies on the NextIteration node. So it creates another
504      Identity node with the converted dependencies.
505      """
506      converted_control_inp = []
507      for x in control_inputs:
508        for t in x.outputs:
509          converted_control_inp.append(body_pfor._convert_helper(t).t)
510      if stacked:
511        # Note convert always does the stacking.
512        output = body_pfor.convert(body_output)
513      else:
514        output, convert_stacked, _ = body_pfor._convert_helper(body_output)
515        assert convert_stacked == stacked, body_output
516      with ops.control_dependencies(converted_control_inp):
517        return array_ops.identity(output)
518
519    body_pfor = self._init_pfor(pfor_input.pfor, new_indices,
520                                cond_stacked, new_inputs,
521                                inputs_stacked)
522    new_outputs = []
523
524    for i, (body_output, stacked) in enumerate(
525        zip(self._body_outputs, inputs_stacked)):
526      control_inp = self._next_iter_control_inputs[i]
527      out_dtype = body_output.dtype
528      # Note that we want to run the body only if not all pfor iterations are
529      # done. If all are done, we return empty tensors since these values will
530      # not be used. Notice that the value returned by the loop is based on
531      # TensorArrays and not directly on these returned values.
532      # pylint: disable=cell-var-from-loop
533      new_output = control_flow_ops.cond(
534          not_all_done,
535          lambda: true_fn(control_inp, body_pfor, body_output, stacked),
536          lambda: constant_op.constant([], dtype=out_dtype))
537      # pylint: enable=cell-var-from-loop
538      new_outputs.append(new_output)
539    return new_outputs
540
541  def __call__(self, pfor_input):
542    """Converter for the while_loop.
543
544    The conversion of a while_loop is another while_loop.
545
546    The arguments to this converted while_loop are as follows:
547    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
548      are done.
549    indices: int32 1-D Tensor storing the id of the iterations that are not
550      done.
551    args: Remaining arguments. These can be divided into 3 categories:
552      - First set of arguments are the tensors that correspond to the initial
553        elements of self._enters. The elements that appear in original while
554        loop's `loop_vars`.
555      - The second set of arguments are the tensors that correspond to the
556        remaining elements of self._enters. These are the tensors that directly
557        enter the original while loop body.
      - Finally, the last set of arguments are TensorArrays. These TensorArrays
        correspond to the outputs of the original while_loop, i.e. to the
        elements in self._outputs. Each TensorArray has `PFor.loop_len`
        elements, i.e. the number of pfor iterations. At the end, the i'th
        element of each TensorArray will contain the output computed by the
        i'th iteration of pfor. Note that elements can be written into these
        TensorArrays in any order, depending on when the corresponding pfor
        iteration is done.
566      If the original while_loop had `k` tensors in its `loop_vars` and its body
567      directly captured `m` tensors, the `args` will contain `2 * k + m` values.
568
569    In each iteration, the while_loop body recomputes the condition for all
570    active pfor iterations to see which of them are now done. It then partitions
571    all the inputs and passes them along to the converted body. Values for all
572    the iterations that are done are written to TensorArrays indexed by the pfor
573    iteration number. When all iterations are done, the TensorArrays are stacked
574    to get the final value.
575
576    Args:
      pfor_input: A _PforInput object corresponding to the output of any Exit
578        node from this while loop.
579
580    Returns:
581      List of converted outputs.
582    """
583    # Create init_values that will be passed to the while_loop.
584    init_values, inputs_stacked, shape_invariants = self._create_init_values(
585        pfor_input)
586    # Note that we use a list as a hack since we need the nested function body
    # to set the value of cond_is_stacked. Python 2.x doesn't support nonlocal
588    # variables.
589    cond_is_stacked = [None]
590
591    def cond(not_all_done, *_):
592      return not_all_done
593
594    def body(not_all_done, indices, *args):
      # See documentation for __call__ for the structure of *args.
596      num_enters = len(self._enters)
597      inputs = args[:num_enters]
598      output_tas = args[num_enters:]
599      # TODO(agarwal): see which outputs have consumers and only populate the
600      # TensorArrays corresponding to those. Or do those paths get trimmed out
601      # from inside the while_loop body?
602      assert len(inputs) >= len(output_tas)
603      assert len(inputs) == len(inputs_stacked)
604
605      # Convert condition
606      with ops.name_scope("while_cond"):
607        # Note that we set cond_stacked to True here. At this point we don't
608        # know if it could be loop invariant, hence the conservative value is
609        # to assume stacked.
610        cond_pfor = self._init_pfor(pfor_input.pfor, indices,
611                                    cond_stacked=True,
612                                    inputs=inputs,
613                                    inputs_stacked=inputs_stacked)
614        conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
615        cond_is_stacked[0] = cond_stacked
616
617      # Recompute the new condition, write outputs of done iterations, and
618      # partition the inputs if needed.
619      if not cond_stacked:
620        (not_all_done, new_indices,
621         new_inputs, new_output_tas) = self._process_cond_unstacked(
622             conditions, indices, inputs, output_tas)
623      else:
624        (not_all_done, new_indices,
625         new_inputs, new_output_tas) = self._process_cond_stacked(
626             conditions, indices, inputs, inputs_stacked, output_tas)
627
628      # Convert body
629      with ops.name_scope("while_body"):
630        #  Compute the outputs from the body.
631        new_outputs = self._process_body(pfor_input, inputs_stacked,
632                                         new_indices, cond_stacked, new_inputs,
633                                         not_all_done)
634
635      # Note that the first num_outputs new values of inputs are computed using
636      # the body. Rest of them were direct Enters into the condition/body and
637      # the partitioning done earlier is sufficient to give the new value.
638      num_outputs = len(self._outputs)
639      new_args = ([not_all_done, new_indices] + new_outputs + list(
640          new_inputs[num_outputs:]) + new_output_tas)
641      return tuple(new_args)
642
643    while_outputs = control_flow_ops.while_loop(
644        cond, body, init_values, shape_invariants=shape_invariants)
645    output_tas = while_outputs[-len(self._outputs):]
646    outputs = []
647    assert cond_is_stacked[0] is not None
648    for inp_stacked, ta in zip(inputs_stacked, output_tas):
649      if cond_is_stacked[0]:
650        outputs.append(wrap(ta.stack(), True))
651      else:
652        # Note that if while_loop condition is unstacked, all iterations exit at
        # the same time and we wrote those outputs at index 0 of the
        # TensorArrays.
655        outputs.append(wrap(ta.read(0), inp_stacked))
656    return outputs
657
658
659class _PforInput(object):
660  """Input object passed to registered pfor converters."""
661
662  def __init__(self, pfor, op, inputs):
663    """Creates a _PforInput object.
664
665    Args:
666      pfor: PFor converter object.
667      op: the Operation object that is being converted.
668      inputs: list of WrappedTensor objects representing converted values of the
669        inputs of `op`.
670    """
671    self.pfor = pfor
672    self._op = op
673    self._inputs = inputs
674
675  def stack_inputs(self, stack_indices=None):
676    """Stacks unstacked inputs at `stack_indices`.
677
678    Args:
679      stack_indices: indices of inputs at which stacking is done. If None,
680        stacking is done at all indices.
681    """
682    if stack_indices is None:
683      stack_indices = range(len(self._inputs))
684    length = self.pfor.loop_len_vector
685    for i in stack_indices:
686      inp = self._inputs[i]
687      if not inp.is_stacked:
688        self._inputs[i] = _stack(inp.t, length)
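  # For example (illustrative only), with a pfor loop of length N, a loop
  # invariant input of shape [3] at index 1 becomes, after
  # `pfor_input.stack_inputs([1])`, a stacked input whose tensor has shape
  # [N, 3] (the value tiled N times along a new leading dimension).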
689
690  def expanddim_inputs_for_broadcast(self):
691    """Reshapes stacked inputs to prepare them for broadcast.
692
693    Since stacked inputs have an extra leading dimension, automatic broadcasting
694    rules could incorrectly try to expand dimensions before that leading
695    dimension. To avoid that, we reshape these stacked inputs to the maximum
696    rank they will need to be broadcasted to.
697    """
698    if not self._inputs:
699      return
700
701    # Find max rank
702    def _get_rank(x):
703      rank = array_ops.rank(x.t)
704      if not x.is_stacked:
705        rank += 1
706      return rank
707
708    ranks = [_get_rank(x) for x in self._inputs]
709    max_rank = ranks[0]
710    for rank in ranks[1:]:
711      max_rank = math_ops.maximum(rank, max_rank)
712
713    for i, inp in enumerate(self._inputs):
714      if inp.is_stacked:
715        shape = array_ops.shape(inp.t)
716        rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
717        ones = array_ops.tile([1], rank_diff)
718        new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
719        self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)
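  # Illustrative example (not executed here): with N pfor iterations, suppose
  # input 0 is stacked with shape [N, 3] and input 1 is stacked with shape
  # [N, 2, 3]. The per-iteration shapes [3] and [2, 3] broadcast fine, but
  # [N, 3] and [N, 2, 3] do not. The reshape above turns input 0 into shape
  # [N, 1, 3], after which [N, 1, 3] and [N, 2, 3] broadcast to [N, 2, 3] as
  # intended.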
720
721  @property
722  def inputs(self):
723    return self._inputs
724
725  @property
726  def num_inputs(self):
727    return len(self._inputs)
728
729  def input(self, index):
730    assert len(self._inputs) > index, (index, self._inputs)
731    return self._inputs[index]
732
733  def stacked_input(self, index):
734    t, is_stacked, _ = self.input(index)
735    if not is_stacked:
736      op_type = self.op_type
737      op_def = getattr(self._op, "op_def", None)
738      if op_def is None:
739        input_name = "at index %d" % index
740      else:
741        input_name = "\"%s\"" % op_def.input_arg[index].name
742      raise ValueError("Input %s of op \"%s\" expected to be not loop invariant"
                       ".\nError while converting op %s "
744                       "with converted inputs\n%s" % (input_name, op_type,
745                                                      self._op, self.inputs))
746    return t
747
748  def unstacked_input(self, index):
749    t, is_stacked, _ = self.input(index)
750    if is_stacked:
751      op_type = self.op_type
752      op_def = getattr(self._op, "op_def", None)
753      if op_def is None:
754        input_name = "at index %d" % index
755      else:
756        input_name = "\"%s\"" % op_def.input_arg[index].name
757      raise ValueError("Input %s of op \"%s\" expected to be loop invariant"
                       ".\nError while converting op %s "
759                       "with converted inputs\n%s" % (input_name, op_type,
760                                                      self._op, self.inputs))
761    return t
762
763  @property
764  def op(self):
765    return self._op
766
767  @property
768  def op_type(self):
769    return self._op.type
770
771  def get_attr(self, attr):
772    return self._op.get_attr(attr)
773
774  @property
775  def outputs(self):
776    return self._op.outputs
777
778  def output(self, index):
779    assert index < len(self._op.outputs)
780    return self._op.outputs[index]
781
782
783_pfor_converter_registry = {}
784
785
786class RegisterPFor(object):
787  """Utility to register converters for pfor.
788
789  Usage:
790  @RegisterPFor(foo_op_type)
791  def _foo_converter(pfor_input):
792    ...
793
  The above will register conversion function `_foo_converter` for handling
  conversion of `foo_op_type`. During conversion, the registered function will
  be called with a single argument of type `_PforInput` which will contain
  state needed for the conversion. This registered function should output a
  list of WrappedTensor objects with the same length as the number of outputs
  of the op being converted. If the op has zero outputs, then it should return
  an ops.Operation object.
801  """
802
803  def __init__(self, op_type):
804    """Creates an object to register a converter for op with type `op_type`."""
805    self.op_type = op_type
806
807  def __call__(self, converter):
808    name = self.op_type
809    assert name not in _pfor_converter_registry, "Re-registering %s " % name
810    _pfor_converter_registry[name] = converter
811    return converter
812
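# A hypothetical converter sketch (illustrative only; the actual registrations
# for real ops live elsewhere in this file). For a simple elementwise op, the
# converter can apply the same op to the stacked input and mark the result as
# stacked:
#
#   @RegisterPFor("Relu")
#   def _convert_relu(pfor_input):
#     x = pfor_input.stacked_input(0)
#     return wrap(nn_ops.relu(x), True)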
813
814class RegisterPForWithArgs(RegisterPFor):
815  """Utility to register converters for pfor.
816
817  Usage:
  @RegisterPForWithArgs(foo_op_type, foo=value, ...)
  def _foo_converter(pfor_input, op_type, foo=None, ...):
820    ...
821
822  See RegisterPFor for details on the conversion function.
823  `RegisterPForWithArgs` allows binding extra arguments to the
824  conversion function at registration time.
825  """
826
827  def __init__(self, op_type, *args, **kw_args):
828    super(RegisterPForWithArgs, self).__init__(op_type)
829    self._args = args
830    self._kw_args = kw_args
831
832  def __call__(self, converter):
833
834    def _f(pfor_input):
835      return converter(pfor_input, self.op_type, *self._args, **self._kw_args)
836
837    super(RegisterPForWithArgs, self).__call__(_f)
838    return converter
839
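# A hypothetical sketch of registering one converter for several op types with
# bound keyword arguments (illustrative only; any overlap with registrations
# elsewhere in this file is unintended):
#
#   @RegisterPForWithArgs("Floor", math_fn=math_ops.floor)
#   @RegisterPForWithArgs("Ceil", math_fn=math_ops.ceil)
#   def _convert_rounding(pfor_input, op_type, math_fn=None):
#     del op_type
#     return wrap(math_fn(pfor_input.stacked_input(0)), True)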
840
841def _create_op(op_type, inputs, op_dtypes, attrs=None):
842  """Utility to create an op."""
843  return ops.get_default_graph().create_op(
844      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
845
846
847WrappedTensor = collections.namedtuple("WrappedTensor",
848                                       ["t", "is_stacked", "is_sparse_stacked"])
849"""Wrapper around the result of a Tensor conversion.
850
851The additional fields are useful for keeping track of the conversion state as
852data flows through the ops in the loop body. For every op whose output is a
853Tensor, its converter should return either a WrappedTensor or a list of
854WrappedTensors.
855
856Args:
857  t: The converted tensor
858  is_stacked: True if the tensor is stacked, i.e. represents the results of all
859    the iterations of the loop, where each row i of the tensor corresponds to
860    that op's output on iteration i of the loop. False if the tensor is not
    stacked, i.e. represents the result of the op for a single iteration of
862    the loop, where the result does not vary between iterations.
863  is_sparse_stacked: True if the tensor corresponds to a component tensor
864    (indices, values, or dense_shape) of a sparse tensor, and has been logically
865    stacked via a sparse conversion.
866"""
867
868
869def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
870  """Helper to create a WrappedTensor object."""
871  assert isinstance(is_stacked, bool)
872  assert isinstance(is_sparse_stacked, bool)
873  assert isinstance(tensor, ops.Tensor)
874  assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
875                                               "stacked via a sparse "
876                                               "conversion, it must also be "
877                                               "stacked.")
878  return WrappedTensor(tensor, is_stacked, is_sparse_stacked)
879
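# Illustrative examples of wrapping (not executed here). Assuming a pfor loop
# of length N:
#
#   # A value that varies across iterations: a tensor with leading dim N.
#   stacked = wrap(array_ops.ones([8, 2]), True)        # N == 8 assumed
#   # A loop invariant value, shared as-is by all iterations.
#   invariant = wrap(constant_op.constant([1., 2.]), False)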
880
881def _fallback_converter(pfor_input):
882  logging.warn("Using a while_loop for converting %s", pfor_input.op_type)
883  output_dtypes = [x.dtype for x in pfor_input.outputs]
884  iters = pfor_input.pfor.loop_len_vector[0]
885
886  def while_body(i, *ta_list):
887    """Body of while loop."""
888    inputs = [
889        x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
890    ]
891    op_outputs = _create_op(
892        pfor_input.op_type,
893        inputs,
894        output_dtypes,
895        attrs=pfor_input.op.node_def.attr).outputs
896
897    outputs = []
898    for out, ta in zip(op_outputs, ta_list):
899      assert isinstance(out, ops.Tensor)
900      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
901    return tuple([i + 1] + outputs)
902
903  ta_list = control_flow_ops.while_loop(
904      lambda i, *ta: i < iters, while_body, [0] + [
905          tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
906      ])[1:]
907  return tuple([wrap(ta.concat(), True) for ta in ta_list])
908
909
910class PForConfig(object):
911  """A configuration object used to communicate with loop body function."""
912
913  def __init__(self):
914    # This may be set to the number of iterations.
915    self._maybe_iters = None
916    # Map from output placeholder to the unvectorized tensor.
917    self._reduce_concat_map = {}
918    # Reverse map of `self._reduce_concat_map`.
919    self._reverse_reduce_concat_map = {}
920
921  def _has_reductions(self):
922    """True if some reductions where performed by loop body."""
923    return len(self._reduce_concat_map)
924
925  def _set_iters(self, iters):
926    """Set number of pfor iterations."""
927    self._maybe_iters = iters
928
929  # TODO(agarwal): handle reductions inside control flow constructs.
930  def reduce_concat(self, x):
931    """Performs a concat reduction on `x` across pfor iterations.
932
    Note that this currently may not work inside a control flow construct.

934    Args:
935      x: an unvectorized Tensor.
936
937    Returns:
938      A Tensor that has rank one higher than `x`. The value is the vectorized
939      version of `x`, i.e. stacking the value of `x` across different pfor
940      iterations.
941    """
942    assert not context.executing_eagerly()
943    assert isinstance(x, ops.Tensor)
944    if x not in self._reduce_concat_map:
945      out_shape = tensor_shape.TensorShape([self._maybe_iters]).concatenate(
946          x.shape)
947      with ops.control_dependencies([x]):
948        # Control dependency to make sure out is converted after x.
949        out = array_ops.placeholder(x.dtype, out_shape)
950      self._reduce_concat_map[out] = x
951      self._reverse_reduce_concat_map[x] = out
952      return out
953    else:
954      return self._reverse_reduce_concat_map[x]
955
956  def reduce_mean(self, x):
957    """Performs a mean reduction on `x` across pfor iterations.
958
    Note that this currently may not work inside a control flow construct.

960    Args:
961      x: an unvectorized Tensor.
962
963    Returns:
964      A Tensor that has same rank as `x`. The value is the mean of the values
965      of `x` across the pfor iterations.
966    """
967    y = self.reduce_concat(x)
968    return math_ops.reduce_mean(y, axis=0)
969
970  def reduce_sum(self, x):
971    """Performs a sum reduction on `x` across pfor iterations.
972
    Note that this currently may not work inside a control flow construct.

974    Args:
975      x: an unvectorized Tensor.
976
977    Returns:
978      A Tensor that has same rank as `x`. The value is the sum of the values
979      of `x` across the pfor iterations.
980    """
981    y = self.reduce_concat(x)
982    return math_ops.reduce_sum(y, axis=0)
983
984  def _lookup_reduction(self, pl):
985    """Lookups Placeholder `pl` in the reduction map."""
986    assert isinstance(pl, ops.Tensor)
987    return self._reduce_concat_map.get(pl, None)
988
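# A hypothetical sketch of how the reductions above are intended to be used
# from a loop body (illustrative only; the higher level pfor APIs are
# responsible for passing a PForConfig to the loop function, and `x` is some
# tensor defined outside the loop body):
#
#   def loop_fn(i, pfor_config):
#     x_i = array_ops.gather(x, i)
#     # Mean of x_i across all pfor iterations, visible inside each iteration.
#     mean = pfor_config.reduce_mean(x_i)
#     return x_i - mean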
989
990class PFor(object):
991  """Implementation of rewrite of parallel-for loops.
992
993  This class takes a DAG or a set of DAGs representing the body of a
  parallel-for loop, and adds new operations to the graph that implement
  functionality equivalent to running that loop body for a specified number of
  iterations. This new set of nodes may or may not use a TensorFlow loop
997  construct.
998
999  The process of conversion does not delete or change any existing operations.
1000  It only adds operations that efficiently implement the equivalent
1001  functionality. We refer to the added ops as "converted ops".
1002
1003  The conversion process uses a simple greedy heuristic. It walks the loop body
1004  and tries to express the functionality of running each node in a loop with a
  new set of nodes. When converting an op, several cases are possible:
1006  - The op is not inside the loop body. Hence it can be used as is.
1007  - The op does not depend on the iteration number and is stateless. In this
1008    case, it can be used as is.
1009  - The op is not stateful, and depends on iteration number only through control
    dependencies. In this case, we can create a single op with the same inputs
    and attributes, but with "converted" control dependencies.
  - The op is not stateful, and all its inputs are loop invariant. In this
    case, similar to above, we can create a single op with the same inputs and
1014    attributes, but with "converted" control dependencies.
1015  - The op is stateful or at least one of the inputs is not loop invariant. In
1016    this case, we run the registered converter for that op to create a set of
1017    converted ops. All nodes in the set will have converted control dependencies
1018    corresponding to control dependencies of the original op. If the op returned
1019    multiple outputs, "converted outputs" could be produced by different ops in
1020    this set.
1021  """
1022
1023  def __init__(self,
1024               loop_var,
1025               loop_len,
1026               pfor_ops,
1027               all_indices=None,
1028               all_indices_partitioned=False,
1029               pfor_config=None):
1030    """Creates an object to rewrite a parallel-for loop.
1031
1032    Args:
1033      loop_var: ops.Tensor output of a Placeholder operation. The value should
1034        be an int32 scalar representing the loop iteration number.
1035      loop_len: A scalar or scalar Tensor representing the number of iterations
1036        the loop is run for.
1037      pfor_ops: List of all ops inside the loop body.
1038      all_indices: If not None, an int32 vector with size `loop_len`
1039        representing the iteration ids that are still active. These values
1040        should be unique and sorted. However they may not be contiguous. This is
1041        typically the case when inside a control flow construct which has
1042        partitioned the indices of the iterations that are being converted.
      all_indices_partitioned: If True, this object is being constructed from a
        control flow construct where not all the pfor iterations are guaranteed
        to be active.
1046      pfor_config: PForConfig object used while constructing the loop body.
1047    """
1048    assert isinstance(loop_var, ops.Tensor)
1049    assert loop_var.op.type == "Placeholder"
1050    self._loop_var = loop_var
1051    loop_len_value = tensor_util.constant_value(loop_len)
1052    if loop_len_value is not None:
1053      loop_len = loop_len_value
1054    self._loop_len_vector = array_ops.reshape(loop_len, [1])
1055    self._all_indices_partitioned = all_indices_partitioned
1056    if all_indices_partitioned:
1057      assert all_indices is not None
1058    self.all_indices = (
1059        math_ops.range(loop_len) if all_indices is None else all_indices)
1060
1061    self._conversion_map = {}
1062    self._conversion_map[loop_var] = wrap(self.all_indices, True)
1063    self._pfor_ops = set(pfor_ops)
1064    self._pfor_op_ids = set([x._id for x in pfor_ops])
1065    self._pfor_config = pfor_config
1066
1067  def op_is_inside_loop(self, op):
1068    """True if op was created inside the pfor loop body."""
1069    assert isinstance(op, ops.Operation)
1070    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different Python
1072    # objects representing the same Operation node.
1073    return op._id in self._pfor_op_ids
1074
1075  def _convert_sparse(self, y):
1076    """Returns the converted value corresponding to SparseTensor y.
1077
1078    For SparseTensors, instead of stacking the component tensors separately,
1079    resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
1080    rank) respectively for indices, values, and dense_shape (where N is the loop
1081    length and m is the number of sparse tensor values per loop iter), we want
1082    to logically stack the SparseTensors, to create a SparseTensor whose
1083    components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
1084    respectively.
1085
1086    Here, we try to get the conversion of each component tensor.
1087    If the tensors are stacked via a sparse conversion, return the resulting
1088    SparseTensor composed of the converted components. Otherwise, the component
1089    tensors are either unstacked or stacked naively. In the latter case, we
1090    unstack the component tensors to reform loop_len SparseTensor elements,
1091    then correctly batch them.
1092
1093    The unstacked tensors must have the same rank. Each dimension of each
1094    SparseTensor will expand to be the largest among all SparseTensor elements
1095    for that dimension. For example, if there are N SparseTensors of rank 3
1096    being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
1097    the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).
1098
1099    Args:
1100      y: A tf.SparseTensor.
1101
1102    Returns:
1103      A tf.SparseTensor that is the converted value corresponding to y.
1104    """
1105    outputs = [
1106        self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
1107    ]
1108    assert all(isinstance(o, WrappedTensor) for o in outputs)
1109
1110    if all(w.is_sparse_stacked for w in outputs):
1111      return sparse_tensor.SparseTensor(*[w.t for w in outputs])
1112
1113    assert not any(w.is_sparse_stacked for w in outputs), (
1114        "Error converting SparseTensor. All components should be logically "
1115        "stacked, or none.")
1116
1117    # If component tensors were not sparsely stacked, they are either unstacked
1118    # or stacked without knowledge that they are components of sparse tensors.
1119    # In this case, we have to restack them.
1120    return self._restack_sparse_tensor_logically(
1121        *[self._unwrap_or_tile(w) for w in outputs])
1122
1123  def _restack_sparse_tensor_logically(self, indices, values, shape):
1124    sparse_tensor_rank = indices.get_shape().dims[-1].value
1125    if sparse_tensor_rank is not None:
1126      sparse_tensor_rank += 1
1127
1128    def fn(args):
1129      res = gen_sparse_ops.serialize_sparse(
1130          args[0], args[1], args[2], out_type=dtypes.variant)
1131      return res
1132
1133    # Applies a map function to the component tensors to serialize each
1134    # sparse tensor element and batch them all, then deserializes the batch.
1135    # TODO(rachelim): Try to do this without map_fn -- add the right offsets
1136    # to shape and indices tensors instead.
1137    result = map_fn.map_fn(
1138        fn, [indices, values, shape], dtype=dtypes.variant)
1139    return sparse_ops.deserialize_sparse(
1140        result, dtype=values.dtype, rank=sparse_tensor_rank)
1141
1142  def _unwrap_or_tile(self, wrapped_tensor):
1143    """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it."""
1144    output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
1145    if is_stacked:
1146      return output
1147    else:
1148      return _stack(output, self._loop_len_vector).t
1149
1150  def convert(self, y):
1151    """Returns the converted value corresponding to y.
1152
1153    Args:
      y: An ops.Tensor or an ops.Operation object. If the latter, y should not
        have any outputs.
1156
1157    Returns:
1158      If y does not need to be converted, it returns y as is. Else it returns
1159      the "converted value" corresponding to y.
1160    """
1161    if y is None:
1162      return None
1163    if isinstance(y, sparse_tensor.SparseTensor):
1164      return self._convert_sparse(y)
1165    output = self._convert_helper(y)
1166    if isinstance(output, WrappedTensor):
1167      assert isinstance(y, ops.Tensor)
1168      return self._unwrap_or_tile(output)
1169    else:
1170      assert isinstance(y, ops.Operation)
1171      assert not y.outputs
1172      assert isinstance(output, ops.Operation)
1173    return output
1174
1175  def _was_converted(self, t):
1176    """True if t is not a conversion of itself."""
1177    converted_t = self._conversion_map[t]
1178    return converted_t.t is not t
1179
1180  def _add_conversion(self, old_output, new_output):
1181    self._conversion_map[old_output] = new_output
1182
1183  def _convert_helper(self, op_or_tensor):
1184    stack = [op_or_tensor]
1185    while stack:
1186      y = stack[0]
1187      if y in self._conversion_map:
1188        assert isinstance(self._conversion_map[y],
1189                          (WrappedTensor, ops.Operation))
1190        stack.pop(0)
1191        continue
1192      if isinstance(y, ops.Operation):
1193        assert not y.outputs, (
1194            "We only support converting Operation objects with no outputs. "
            "Got %s" % y)
1196        y_op = y
1197      else:
1198        assert isinstance(y, ops.Tensor), y
1199        y_op = y.op
1200
1201      is_while_loop = y_op.type == "Exit"
1202      if is_while_loop:
1203        while_op = WhileOp(
1204            y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config)
1205        is_inside_loop = while_op.is_inside_loop
1206        # If all nodes in the while_loop graph were created inside the pfor, we
1207        # treat the whole loop subgraph as a single op (y_op) and try to convert
1208        # it. For while_loops that are created completely or partially outside,
1209        # we treat them as external and should be able to simply return the Exit
1210        # node output as is without needing any conversion. Note that for
1211        # while_loops that are partially constructed inside, we assume they will
1212        # be loop invariant. If that is not the case, it will create runtime
1213        # errors since the converted graph would depend on the self._loop_var
1214        # placeholder.
1215        if is_inside_loop:
1216          y_op = while_op
1217      else:
1218        is_inside_loop = self.op_is_inside_loop(y_op)
1219
1220      # If this op was not created inside the loop body, we will return as is.
1221      # 1. Convert inputs and control inputs.
1222
1223      def _add_to_stack(x):
1224        if x not in self._conversion_map:
1225          stack.insert(0, x)
1226          return True
1227        else:
1228          return False
1229
1230      if is_inside_loop:
1231        added_to_stack = False
1232        for inp in y_op.inputs:
1233          added_to_stack |= _add_to_stack(inp)
1234        for cinp in y_op.control_inputs:
1235          if cinp.outputs:
1236            for t in cinp.outputs:
1237              added_to_stack |= _add_to_stack(t)
1238          else:
1239            added_to_stack |= _add_to_stack(cinp)
1240        if added_to_stack:
1241          continue
1242
1243        converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
1244        some_input_converted = any(self._was_converted(x) for x in y_op.inputs)
1245        some_input_stacked = any(x.is_stacked for x in converted_inputs)
1246
1247        converted_control_ops = set()
1248        some_control_input_converted = False
1249        for cinp in y_op.control_inputs:
1250          if cinp.outputs:
1251            for t in cinp.outputs:
1252              converted_t = self._conversion_map[t]
1253              if self._was_converted(t):
1254                some_control_input_converted = True
1255              converted_control_ops.add(converted_t.t.op)
1256          else:
1257            converted_cinp = self._conversion_map[cinp]
1258            assert isinstance(converted_cinp, ops.Operation)
1259            if converted_cinp != cinp:
1260              some_control_input_converted = True
1261            converted_control_ops.add(converted_cinp)
1262        converted_control_ops = list(converted_control_ops)
1263        is_stateful = _is_stateful_pfor_op(y_op)
1264      else:
1265        converted_inputs = []
1266        converted_control_ops = []
1267      logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
1268                   converted_inputs, converted_control_ops)
1269
1270      # 2. Convert y_op
      # If converting a while_loop, we let the while_loop converter deal with
1272      # putting the control dependencies appropriately.
1273      control_dependencies = [] if is_while_loop else converted_control_ops
1274      with ops.control_dependencies(control_dependencies), ops.name_scope(
1275          y_op.name + "/pfor/"):
1276        # Op is a placeholder for a reduction.
1277        if (self._pfor_config is not None and
1278            self._pfor_config._lookup_reduction(y) is not None):
1279          # Handle reductions. Map the placeholder to the unvectorized input
1280          # that is being reduced.
1281          reduction_input = self._pfor_config._lookup_reduction(y)
1282          assert isinstance(reduction_input, ops.Tensor), reduction_input
1283          # Tensor being reduced should already be converted due to a control
1284          # dependency on the created placeholder.
1285          # Note that in cases where reduction_input is in an outer context, one
1286          # needs to locate the corresponding Enter node and use that to lookup
1287          # the conversion.
1288          # TODO(agarwal): handle reductions inside control flow constructs.
1289          assert reduction_input in self._conversion_map, (
1290              "Unable to handle reduction of %s, possibly as it was used "
1291              "inside a control flow construct. Note that reductions across "
1292              "pfor iterations are currently not supported inside control flow "
1293              "constructs." % reduction_input)
1294          output = self._conversion_map[reduction_input]
1295          # If original input is not stacked, we tile it. Also we always mark
1296          # output as unstacked.
1297          new_outputs = [wrap(self._unwrap_or_tile(output), False)]
1298        # None of the inputs and control inputs were converted.
1299        elif (not is_inside_loop or
1300              (not is_stateful and not some_input_converted and
1301               not some_control_input_converted)):
1302          if y == y_op:
1303            assert not isinstance(y_op, WhileOp)
1304            new_outputs = y_op
1305          else:
1306            new_outputs = [wrap(x, False) for x in y_op.outputs]
1307        elif not (is_stateful or is_while_loop or some_input_stacked):
          # All inputs are unstacked or unconverted, but some control inputs
          # are converted.
1310          # TODO(rachelim): Handle the case where some inputs are sparsely
1311          # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs))
1312          new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
1313                              [x.dtype for x in y_op.outputs],
1314                              y_op.node_def.attr)
1315          if y == y_op:
1316            new_outputs = new_op
1317          else:
1318            new_outputs = [wrap(x, False) for x in new_op.outputs]
1319        else:
1320          # Either some inputs are not loop invariant or op is stateful.
1321          if hasattr(y_op, "pfor_converter"):
1322            converter = y_op.pfor_converter
1323          else:
1324            converter = _pfor_converter_registry.get(y_op.type, None)
1325          if converter is None:
1326            if flags.FLAGS.op_conversion_fallback_to_while_loop:
1327              converter = _fallback_converter
1328            else:
1329              raise ValueError(
1330                  "No converter defined for %s\n%s\ninputs: %s. "
1331                  "\nEither add a converter or set "
1332                  "--op_conversion_fallback_to_while_loop=True, "
1333                  "which may run slower" % (y_op.type, y_op, converted_inputs))
1334          # TODO(rachelim): Handle the case where some inputs are sparsely
1335          # stacked. We should only call the converter if it supports handling
1336          # those inputs.
1337          new_outputs = converter(_PforInput(self, y_op, converted_inputs))
1338          if isinstance(new_outputs, WrappedTensor):
1339            new_outputs = [new_outputs]
1340          assert isinstance(new_outputs,
1341                            (list, tuple, ops.Operation)), new_outputs
1342        logging.vlog(2, "converted %s %s", y_op, new_outputs)
1343
1344        # Insert into self._conversion_map
1345        if y == y_op:
1346          assert isinstance(new_outputs, ops.Operation)
1347          self._add_conversion(y_op, new_outputs)
1348        else:
1349          for old_output, new_output in zip(y_op.outputs, new_outputs):
1350            assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
1351            self._add_conversion(old_output, new_output)
1352        stack.pop(0)
1353
1354    return self._conversion_map[op_or_tensor]
1355
1356  @property
1357  def loop_len_vector(self):
    """Returns a single-element vector containing the number of iterations."""
1359    return self._loop_len_vector
1360
1361  @property
1362  def loop_var(self):
1363    """Returns placeholder loop variable."""
1364    return self._loop_var
1365
1366  @property
1367  def pfor_ops(self):
1368    return self._pfor_ops
1369
1370  @property
1371  def all_indices_partitioned(self):
1372    """all_indices_partitioned property.
1373
1374    Returns:
      True if we are inside a control flow construct where not all pfor
      iterations may be active.
1377    """
1378    return self._all_indices_partitioned
1379
1380# nn_ops
1381
1382
1383def _flatten_first_two_dims(x):
1384  """Merges first two dimensions."""
1385  old_shape = array_ops.shape(x)
1386  new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0)
1387  return array_ops.reshape(x, new_shape)
1388
1389
1390def _unflatten_first_dim(x, first_dim):
1391  """Splits first dimension into [first_dim, -1]."""
1392  old_shape = array_ops.shape(x)
1393  new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0)
1394  return array_ops.reshape(x, new_shape)
1395
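
# Illustrative example (not used anywhere in this module); shapes and the
# function name below are hypothetical.
def _example_flatten_unflatten_roundtrip():
  """Illustrative sketch only; not part of the converter.

  Shows how `_flatten_first_two_dims` and `_unflatten_first_dim` undo each
  other for a stacked tensor, assuming a loop length of 3. This mirrors the
  pattern used by the flatten-batch converters below: merge the pfor dimension
  into the batch dimension, run the op once, then split the result back out.
  """
  x = array_ops.zeros([3, 4, 5])  # [loop_len, batch, features]
  flat = _flatten_first_two_dims(x)  # shape [12, 5]
  # `first_dim` is expected to be a length-1 vector, like `loop_len_vector`.
  back = _unflatten_first_dim(flat, constant_op.constant([3]))  # [3, 4, 5]
  return flat, back
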
1396
1397def _inputs_with_flattening(pfor_input, input_indices):
1398  """Stacks and flattens first dim of inputs at indices `input_indices`."""
1399  if input_indices is None:
1400    input_indices = []
1401  pfor_input.stack_inputs(stack_indices=input_indices)
1402  inputs = []
1403  for i in range(pfor_input.num_inputs):
1404    if i in input_indices:
1405      inp = pfor_input.stacked_input(i)
1406      inp = _flatten_first_two_dims(inp)
1407    else:
1408      inp = pfor_input.unstacked_input(i)
1409    inputs.append(inp)
1410  return inputs
1411
1412
1413@RegisterPForWithArgs("Conv2D", dims=[0])
1414@RegisterPForWithArgs("AvgPool", dims=[0])
1415@RegisterPForWithArgs("MaxPool", dims=[0])
1416@RegisterPForWithArgs("MaxPool3D", dims=[0])
1417@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
1418@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
1419@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
1420@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
1421@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
1422def _convert_flatten_batch(pfor_input, op_type, dims):
1423  del op_type
1424  inputs = _inputs_with_flattening(pfor_input, dims)
1425  outputs = _create_op(
1426      pfor_input.op_type,
1427      inputs, [x.dtype for x in pfor_input.outputs],
1428      attrs=pfor_input.op.node_def.attr).outputs
1429  n = pfor_input.pfor.loop_len_vector
1430  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1431  return [wrap(x, True) for x in outputs]
1432
1433
1434_channel_flatten_input_cache = {}
1435
1436
1437def _channel_flatten_input(x, data_format):
1438  """Merge the stack dimension with the channel dimension.
1439
  If S is pfor's stacking dimension, then:
    - for SNCHW, we transpose to NSCHW. If the N dimension has size 1, the
      transpose should be cheap.
    - for SNHWC, we transpose to NHWSC.
  We then merge the S and C dimensions.
1445
1446  Args:
1447    x: ops.Tensor to transform.
1448    data_format: "NCHW" or "NHWC".
1449
1450  Returns:
    A 3-element tuple containing the transformed value, along with the
    transpose order and reshape shape required to transform it back.
1453  """
1454
1455  graph = ops.get_default_graph()
1456  cache_key = (graph, x, data_format)
1457  if cache_key not in _channel_flatten_input_cache:
1458    x_shape = array_ops.shape(x)
1459    if data_format == b"NCHW":
1460      order = [1, 0, 2, 3, 4]
1461      shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0)
1462      reverse_order = order
1463    else:
1464      order = [1, 2, 3, 0, 4]
1465      shape = array_ops.concat([x_shape[1:4], [-1]], axis=0)
1466      reverse_order = [3, 0, 1, 2, 4]
1467    # Move S dimension next to C dimension.
1468    x = array_ops.transpose(x, order)
1469    reverse_shape = array_ops.shape(x)
1470    # Reshape to merge the S and C dimension.
1471    x = array_ops.reshape(x, shape)
1472    outputs = x, reverse_order, reverse_shape
1473    _channel_flatten_input_cache[cache_key] = outputs
1474  else:
1475    outputs = _channel_flatten_input_cache[cache_key]
1476  return outputs
1477
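
# Illustrative example (not used anywhere in this module); shapes and the
# function name below are hypothetical.
def _example_channel_flatten_roundtrip():
  """Illustrative sketch only; not part of the converter.

  Demonstrates, with hypothetical shapes, how `_channel_flatten_input` merges
  the pfor stacking dimension S into the channel dimension for an NHWC input,
  and how its outputs can be used to undo the transformation (as done by the
  FusedBatchNorm converters below).
  """
  x = array_ops.zeros([2, 3, 4, 4, 8])  # [S, N, H, W, C]
  y, reverse_order, reverse_shape = _channel_flatten_input(x, b"NHWC")
  # y has shape [3, 4, 4, 16]: S was moved next to C and the two were merged.
  x_back = array_ops.transpose(
      array_ops.reshape(y, reverse_shape), reverse_order)
  return y, x_back  # x_back has shape [2, 3, 4, 4, 8] again.
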
1478
1479# Note that with training=True, running FusedBatchNorm on individual examples
1480# is very different from running FusedBatchNorm on a batch of those examples.
# This is because, in the latter case, the operation can be considered as first
# computing the mean and variance over all the examples and then using these
# statistics to scale all those examples. This creates a data dependency
# between the different "iterations" since the inputs to the scaling step
# depend on the statistics computed from all of the inputs.
1486# As with other kernels, the conversion here effectively runs the kernel
1487# independently for each iteration, and returns outputs by stacking outputs from
1488# each of those iterations.
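# For example (illustrative): with a loop length of n and per-iteration inputs
# of shape [b, h, w, c], the conversion below behaves like running
# FusedBatchNorm independently on each of the n [b, h, w, c] slices and
# stacking the n results, not like a single FusedBatchNorm over an
# [n * b, h, w, c] batch.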
1489@RegisterPFor("FusedBatchNorm")
1490def _convert_fused_batch_norm(pfor_input):
1491  is_training = pfor_input.get_attr("is_training")
1492  # When BatchNorm is used with training=False, mean and variance are provided
1493  # externally and used as is by the op. Thus, we can merge the S and N
1494  # dimensions as we do for regular operations.
  # When BatchNorm is used with training=True, mean and variance are computed
  # for each channel across the batch dimension (the first one). If we merged
  # the S and N dimensions, the mean and variance would be computed over a
  # larger set. So, we merge the S and C dimensions instead.
1499  if not is_training:
    # We return zeros for the batch_mean and batch_variance outputs. Note that
    # CPU and GPU seem to behave differently for these two outputs: CPU outputs
    # zeros because these values are not used during inference, while GPU
    # outputs something, presumably the actual means and variances.
1504    inputs = _inputs_with_flattening(pfor_input, [0])
1505    outputs = _create_op(
1506        pfor_input.op_type,
1507        inputs, [x.dtype for x in pfor_input.outputs],
1508        attrs=pfor_input.op.node_def.attr).outputs
1509    y = outputs[0]
1510    n = pfor_input.pfor.loop_len_vector
1511    y = _unflatten_first_dim(y, n)
1512    mean = pfor_input.unstacked_input(3)
1513    zeros = array_ops.zeros_like(mean)
1514    return [wrap(y, True), wrap(zeros, False), wrap(zeros, False)]
1515
1516  pfor_input.stack_inputs()
1517  data_format = pfor_input.get_attr("data_format")
1518  # We merge the first dimension with the "C" dimension, run FusedBatchNorm, and
1519  # then transpose back.
1520  x = pfor_input.stacked_input(0)
1521  x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
1522  # Note that we stack all the other inputs as well so that they are the same
1523  # size as the new size of the channel dimension.
1524  inputs = [x] + [
1525      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1526      for i in range(1, pfor_input.num_inputs)
1527  ]
1528  outputs = _create_op(
1529      pfor_input.op_type,
1530      inputs, [x.dtype for x in pfor_input.outputs],
1531      attrs=pfor_input.op.node_def.attr).outputs
1532  y = outputs[0]
1533  y = array_ops.reshape(y, reverse_shape)
1534  y = array_ops.transpose(y, reverse_order)
1535  n = pfor_input.pfor.loop_len_vector
1536  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1537  outputs = [y] + outputs
1538  return [wrap(x, True) for x in outputs]
1539
1540
1541@RegisterPFor("FusedBatchNormGrad")
1542def _convert_fused_batch_norm_grad(pfor_input):
1543  pfor_input.stack_inputs()
1544  data_format = pfor_input.get_attr("data_format")
1545  y_backprop = pfor_input.stacked_input(0)
1546  y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format)
1547  x = pfor_input.stacked_input(1)
1548  x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format)
1549  inputs = [y_backprop, x] + [
1550      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1551      for i in range(2, pfor_input.num_inputs)
1552  ]
1553  outputs = _create_op(
1554      pfor_input.op_type,
1555      inputs, [x.dtype for x in pfor_input.outputs],
1556      attrs=pfor_input.op.node_def.attr).outputs
1557  x_backprop = outputs[0]
1558  x_backprop = array_ops.reshape(x_backprop, x_reverse_shape)
1559  x_backprop = array_ops.transpose(x_backprop, x_reverse_order)
1560  n = pfor_input.pfor.loop_len_vector
1561  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1562  outputs = [x_backprop] + outputs
1563  return [wrap(output, True) for output in outputs]
1564
1565
1566@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
1567@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
1568def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
1569                                       shape_dim):
1570  del op_type
1571  inputs = _inputs_with_flattening(pfor_input, flatten_dims)
1572  n = pfor_input.pfor.loop_len_vector
1573  # Adjust the `input_sizes` input.
1574  ones = array_ops.ones(
1575      [array_ops.shape(inputs[shape_dim])[0] - 1], dtype=n.dtype)
1576  inputs[shape_dim] *= array_ops.concat([n, ones], axis=0)
1577  outputs = _create_op(
1578      pfor_input.op_type,
1579      inputs, [x.dtype for x in pfor_input.outputs],
1580      attrs=pfor_input.op.node_def.attr).outputs
1581  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1582  return [wrap(x, True) for x in outputs]
1583
1584
1585@RegisterPFor("Conv2DBackpropFilter")
1586def _convert_conv2d_backprop_filter(pfor_input):
1587  pfor_input.stack_inputs(stack_indices=[2])
1588  inputs, inputs_stacked, _ = pfor_input.input(0)
1589  filter_sizes = pfor_input.unstacked_input(1)
1590  grads = pfor_input.stacked_input(2)
1591  strides = pfor_input.get_attr("strides")
1592  padding = pfor_input.get_attr("padding")
1593  use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu")
1594  data_format = pfor_input.get_attr("data_format")
1595  dilations = pfor_input.get_attr("dilations")
1596  if inputs_stacked:
1597    # TODO(agarwal): Implement this efficiently.
1598    logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!")
1599
1600    def while_body(i, ta):
1601      inp_i = inputs[i, ...]
1602      grad_i = grads[i, ...]
1603      output = nn_ops.conv2d_backprop_filter(
1604          inp_i,
1605          filter_sizes,
1606          grad_i,
1607          strides=strides,
1608          padding=padding,
1609          use_cudnn_on_gpu=use_cudnn_on_gpu,
1610          data_format=data_format,
1611          dilations=dilations)
1612      return i + 1, ta.write(i, array_ops.expand_dims(output, 0))
1613
1614    n = array_ops.reshape(pfor_input.pfor.loop_len_vector, [])
1615    _, ta = control_flow_ops.while_loop(
1616        lambda i, ta: i < n, while_body,
1617        (0, tensor_array_ops.TensorArray(inputs.dtype, n)))
1618    output = ta.concat()
1619    return wrap(output, True)
1620  else:
1621    # We merge the stack dimension with the channel dimension of the gradients
1622    # and pretend we had a larger filter (see change to filter_sizes below).
1623    # Once the filter backprop is computed, we reshape and transpose back
1624    # appropriately.
1625    grads, _, _ = _channel_flatten_input(grads, data_format)
1626    n = pfor_input.pfor.loop_len_vector
1627    old_filter_sizes = filter_sizes
1628    filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0)
1629    output = nn_ops.conv2d_backprop_filter(
1630        inputs,
1631        filter_sizes,
1632        grads,
1633        strides=strides,
1634        padding=padding,
1635        use_cudnn_on_gpu=use_cudnn_on_gpu,
1636        data_format=data_format,
1637        dilations=dilations)
1638    new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0)
1639    output = array_ops.reshape(output, new_filter_shape)
1640    output = array_ops.transpose(output, [3, 0, 1, 2, 4])
1641    return wrap(output, True)
1642
1643
1644# array_ops
1645
1646
1647@RegisterPForWithArgs("Identity", array_ops.identity)
1648@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
1649@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part)
1650def _convert_identity(pfor_input, op_type, op_func):
1651  del op_type
1652  return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
1653
1654
1655@RegisterPFor("IdentityN")
1656def _convert_identity_n(pfor_input):
1657  outputs = array_ops.identity_n([x.t for x in pfor_input.inputs])
1658  return [wrap(out, inp.is_stacked) for out, inp in
1659          zip(outputs, pfor_input.inputs)]
1660
1661
1662@RegisterPFor("Reshape")
1663def _convert_reshape(pfor_input):
1664  t = pfor_input.stacked_input(0)
1665  shape = pfor_input.unstacked_input(1)
1666  new_dim = array_ops.shape(t)[:1]
1667  new_shape = array_ops.concat([new_dim, shape], axis=0)
1668  return wrap(array_ops.reshape(t, new_shape), True)
1669
1670
1671@RegisterPFor("ExpandDims")
1672def _convert_expanddims(pfor_input):
1673  t = pfor_input.stacked_input(0)
1674  dim = pfor_input.unstacked_input(1)
1675  dim += math_ops.cast(dim >= 0, dtypes.int32)
1676  return wrap(array_ops.expand_dims(t, axis=dim), True)
1677
1678
1679@RegisterPFor("Slice")
1680def _convert_slice(pfor_input):
1681  t = pfor_input.stacked_input(0)
1682  begin = pfor_input.unstacked_input(1)
1683  size = pfor_input.unstacked_input(2)
1684  begin = array_ops.concat([[0], begin], axis=0)
1685  size = array_ops.concat([[-1], size], axis=0)
1686  return wrap(array_ops.slice(t, begin, size), True)
1687
1688
1689@RegisterPFor("Tile")
1690def _convert_tile(pfor_input):
1691  t = pfor_input.stacked_input(0)
1692  multiples = pfor_input.unstacked_input(1)
1693  multiples = array_ops.concat([[1], multiples], 0)
1694  return wrap(array_ops.tile(t, multiples), True)
1695
1696
1697@RegisterPFor("Pack")
1698def _convert_pack(pfor_input):
1699  pfor_input.stack_inputs()
1700  axis = pfor_input.get_attr("axis")
1701  if axis >= 0:
1702    axis += 1
1703  return wrap(
1704      array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True)
1705
1706
1707@RegisterPFor("Unpack")
1708def _convert_unpack(pfor_input):
1709  value = pfor_input.stacked_input(0)
1710  axis = pfor_input.get_attr("axis")
1711  if axis >= 0:
1712    axis += 1
1713  num = pfor_input.get_attr("num")
1714  return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)]
1715
1716
1717@RegisterPFor("Pad")
1718def _convert_pad(pfor_input):
1719  t = pfor_input.stacked_input(0)
1720  paddings = pfor_input.unstacked_input(1)
1721  paddings = array_ops.concat([[[0, 0]], paddings], 0)
1722  return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True)
1723
1724
1725@RegisterPFor("Split")
1726def _convert_split(pfor_input):
1727  split_dim = pfor_input.unstacked_input(0)
1728  t = pfor_input.stacked_input(1)
1729  num_split = pfor_input.get_attr("num_split")
1730  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
1731  return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
1732
1733
1734@RegisterPFor("SplitV")
1735def _convert_split_v(pfor_input):
1736  t = pfor_input.stacked_input(0)
1737  splits = pfor_input.unstacked_input(1)
1738  split_dim = pfor_input.unstacked_input(2)
1739  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
1740  return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
1741
1742
1743@RegisterPFor("Transpose")
1744def _convert_transpose(pfor_input):
1745  t = pfor_input.stacked_input(0)
1746  perm = pfor_input.unstacked_input(1)
1747  new_perm = array_ops.concat([[0], perm + 1], axis=0)
1748  return wrap(array_ops.transpose(t, new_perm), True)
1749
1750
1751@RegisterPFor("ZerosLike")
1752def _convert_zeroslike(pfor_input):
1753  t = pfor_input.stacked_input(0)
1754  shape = array_ops.shape(t)[1:]
1755  return wrap(array_ops.zeros(shape, dtype=t.dtype), False)
1756
1757
1758@RegisterPFor("Gather")
1759@RegisterPFor("GatherV2")
1760def _convert_gather(pfor_input):
1761  param, param_stacked, _ = pfor_input.input(0)
1762  indices, indices_stacked, _ = pfor_input.input(1)
1763  op_type = pfor_input.op_type
1764  if op_type == "Gather":
1765    validate_indices = pfor_input.get_attr("validate_indices")
1766    axis = 0
1767  else:
1768    validate_indices = None
1769    axis = pfor_input.unstacked_input(2)
1770    axis_value = tensor_util.constant_value(axis)
1771    if axis_value is not None:
1772      axis = axis_value
1773  if indices_stacked and not param_stacked:
1774    if indices == pfor_input.pfor.all_indices and axis == 0:
1775      param_shape0 = param.shape.dims[0].value
1776      indices_shape0 = indices.shape.dims[0].value
1777      if param_shape0 is not None and indices_shape0 == param_shape0:
1778        # Note that with loops and conditionals, indices may not be contiguous.
1779        # However they will be sorted and unique. So if the shape matches, then
1780        # it must be picking up all the rows of param.
1781        return wrap(param, True)
1782      # TODO(agarwal): use array_ops.slice here.
1783    output = array_ops.gather(
1784        param, indices, validate_indices=validate_indices, axis=axis)
1785    if axis != 0:
1786      axis = control_flow_ops.cond(
1787          axis < 0, lambda: axis + array_ops.rank(param), lambda: axis)
1788      order = array_ops.concat(
1789          [[axis],
1790           math_ops.range(axis),
1791           math_ops.range(axis + 1, array_ops.rank(output))],
1792          axis=0)
1793      output = control_flow_ops.cond(
1794          math_ops.equal(axis, 0), lambda: output,
1795          lambda: array_ops.transpose(output, order))
1796    return wrap(output, True)
1797  if param_stacked:
1798    loop_len_vector = pfor_input.pfor.loop_len_vector
1799    pfor_input.stack_inputs(stack_indices=[1])
1800    indices = pfor_input.stacked_input(1)
1801    param_flat = _flatten_first_two_dims(param)
1802
1803    # Recompute indices to handle stacked param.
1804    indices_offset = math_ops.range(
1805        loop_len_vector[0]) * array_ops.shape(param)[1]
1806    # Reshape indices_offset to allow broadcast addition
1807    ones = array_ops.ones([array_ops.rank(indices) - 1], dtype=dtypes.int32)
1808    new_shape = array_ops.concat([loop_len_vector, ones], axis=0)
1809    indices_offset = array_ops.reshape(indices_offset, new_shape)
1810    indices += indices_offset
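    # For example (illustrative): with a loop length of 2 and a stacked param
    # of shape [2, 5, ...], indices coming from iteration 1 are shifted by 5
    # so that they index into param_flat, which has shape [10, ...].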
1811
1812    # TODO(agarwal): handle axis != 0. May need to transpose param or
1813    # array_ops.gather_nd.
1814    if isinstance(axis, ops.Tensor):
1815      axis_value = tensor_util.constant_value(axis)
1816    else:
1817      try:
1818        axis_value = int(axis)
1819      except TypeError:
1820        axis_value = None
1821    msg = ("Gather, where indices and param are both loop dependent, currently "
1822           "requires axis=0")
1823    if axis_value is not None and axis_value != 0:
1824      raise ValueError("Error while converting %s. %s. Got axis=%d" %
1825                       (pfor_input.op, msg, axis))
1826    with ops.control_dependencies(
1827        [check_ops.assert_equal(axis, 0, message=msg)]):
1828      output = array_ops.gather(param_flat, indices)
1829    return wrap(output, True)
1830
1831
1832@RegisterPFor("ConcatV2")
1833def _convert_concatv2(pfor_input):
1834  n = pfor_input.num_inputs
1835  pfor_input.stack_inputs(stack_indices=range(n - 1))
1836  axis = pfor_input.unstacked_input(n - 1)
1837  axis += math_ops.cast(axis >= 0, axis.dtype)
1838  return wrap(
1839      array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis),
1840      True)
1841
1842
1843@RegisterPFor("StridedSlice")
1844def _convert_strided_slice(pfor_input):
1845  inp = pfor_input.stacked_input(0)
1846  begin = pfor_input.unstacked_input(1)
1847  end = pfor_input.unstacked_input(2)
1848  strides = pfor_input.unstacked_input(3)
1849  begin_mask = pfor_input.get_attr("begin_mask")
1850  end_mask = pfor_input.get_attr("end_mask")
1851  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
1852  new_axis_mask = pfor_input.get_attr("new_axis_mask")
1853  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
1854
1855  begin = array_ops.concat([[0], begin], axis=0)
1856  end = array_ops.concat([[0], end], axis=0)
1857  strides = array_ops.concat([[1], strides], axis=0)
1858  begin_mask = begin_mask << 1 | 1
1859  end_mask = end_mask << 1 | 1
1860  ellipsis_mask <<= 1
1861  new_axis_mask <<= 1
1862  shrink_axis_mask <<= 1
1863  return wrap(
1864      array_ops.strided_slice(
1865          inp,
1866          begin,
1867          end,
1868          strides,
1869          begin_mask=begin_mask,
1870          end_mask=end_mask,
1871          ellipsis_mask=ellipsis_mask,
1872          new_axis_mask=new_axis_mask,
1873          shrink_axis_mask=shrink_axis_mask), True)
1874
1875
1876@RegisterPFor("StridedSliceGrad")
1877def _convert_strided_slice_grad(pfor_input):
1878  shape = pfor_input.unstacked_input(0)
1879  begin = pfor_input.unstacked_input(1)
1880  end = pfor_input.unstacked_input(2)
1881  strides = pfor_input.unstacked_input(3)
1882  dy = pfor_input.stacked_input(4)
1883  begin_mask = pfor_input.get_attr("begin_mask")
1884  end_mask = pfor_input.get_attr("end_mask")
1885  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
1886  new_axis_mask = pfor_input.get_attr("new_axis_mask")
1887  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
1888
1889  shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
1890  begin = array_ops.concat([[0], begin], axis=0)
1891  end = array_ops.concat([[0], end], axis=0)
1892  strides = array_ops.concat([[1], strides], axis=0)
1893  begin_mask = begin_mask << 1 | 1
1894  end_mask = end_mask << 1 | 1
1895  ellipsis_mask <<= 1
1896  new_axis_mask <<= 1
1897  shrink_axis_mask <<= 1
1898  return wrap(
1899      array_ops.strided_slice_grad(
1900          shape,
1901          begin,
1902          end,
1903          strides,
1904          dy,
1905          begin_mask=begin_mask,
1906          end_mask=end_mask,
1907          ellipsis_mask=ellipsis_mask,
1908          new_axis_mask=new_axis_mask,
1909          shrink_axis_mask=shrink_axis_mask), True)
1910
1911
1912# math_ops
1913
1914
1915@RegisterPFor("MatMul")
1916def _convert_matmul(pfor_input):
1917  # TODO(agarwal): Check if tiling is faster than two transposes.
1918  a, a_stacked, _ = pfor_input.input(0)
1919  b, b_stacked, _ = pfor_input.input(1)
1920  tr_a = pfor_input.get_attr("transpose_a")
1921  tr_b = pfor_input.get_attr("transpose_b")
1922  if a_stacked and b_stacked:
1923    output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True)
1924    return output
1925  elif a_stacked:
1926    if tr_a:
1927      a = array_ops.transpose(a, [0, 2, 1])
1928    if a.shape.is_fully_defined():
1929      x, y, z = a.shape
1930    else:
1931      x, y, z = [
1932          array_ops.reshape(i, [])
1933          for i in array_ops.split(array_ops.shape(a), 3)
1934      ]
1935    a = array_ops.reshape(a, [x * y, z])
1936    prod = math_ops.matmul(a, b, transpose_b=tr_b)
1937    return wrap(array_ops.reshape(prod, [x, y, -1]), True)
1938  else:
1939    assert b_stacked
1940    if tr_b:
1941      perm = [2, 0, 1]
1942      b = array_ops.transpose(b, perm)
1943    else:
1944      # As an optimization, if one of the first two dimensions is 1, then we can
1945      # reshape instead of transpose.
1946      # TODO(agarwal): This check can be done inside Transpose kernel.
1947      b_shape = array_ops.shape(b)
1948      min_dim = math_ops.minimum(b_shape[0], b_shape[1])
1949      perm = control_flow_ops.cond(
1950          math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2])
1951      new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]])
1952      b = array_ops.transpose(b, perm)
1953      b = array_ops.reshape(b, new_shape)
1954
1955    if b.shape.is_fully_defined():
1956      x, y, z = b.shape
1957    else:
1958      x, y, z = [
1959          array_ops.reshape(i, [])
1960          for i in array_ops.split(array_ops.shape(b), 3)
1961      ]
1962    b = array_ops.reshape(b, [x, y * z])
1963    prod = math_ops.matmul(a, b, transpose_a=tr_a)
1964    prod = array_ops.reshape(prod, [-1, y, z])
1965    prod = array_ops.transpose(prod, [1, 0, 2])
1966    return wrap(prod, True)
1967
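# Illustrative note on the converter above (ignoring the transpose attrs): when
# only `a` is stacked, a matmul of shape [n, p, q] with [q, r] is computed by
# reshaping `a` to [n * p, q], doing a single matmul, and reshaping the product
# back to [n, p, r]. The stacked-`b` case is handled similarly but needs an
# extra transpose at the end so that the pfor dimension comes first.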
1968
1969@RegisterPFor("BatchMatMul")
1970def _convert_batch_mat_mul(pfor_input):
1971  # TODO(agarwal): There may be a more efficient way to do this instead of
1972  # stacking the inputs.
1973  pfor_input.stack_inputs()
1974  x = pfor_input.stacked_input(0)
1975  y = pfor_input.stacked_input(1)
1976  adj_x = pfor_input.get_attr("adj_x")
1977  adj_y = pfor_input.get_attr("adj_y")
1978
1979  x = _flatten_first_two_dims(x)
1980  y = _flatten_first_two_dims(y)
1981  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
1982  output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector)
1983  return wrap(output, True)
1984
1985
1986@RegisterPForWithArgs("Sum", math_ops.reduce_sum)
1987@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
1988@RegisterPForWithArgs("Max", math_ops.reduce_max)
1989@RegisterPForWithArgs("Min", math_ops.reduce_min)
1990@RegisterPForWithArgs("Mean", math_ops.reduce_mean)
1991def _convert_reduction(pfor_input, _, op_func):
1992  t = pfor_input.stacked_input(0)
1993  indices = pfor_input.unstacked_input(1)
1994  # Shift positive indices by one to account for the extra dimension.
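  # For example, a reduction over axis 0 of each per-iteration value becomes a
  # reduction over axis 1 of the stacked value; negative axes can be left as-is
  # since the extra stacking dimension is at the front.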
1995  indices += math_ops.cast(indices >= 0, dtypes.int32)
1996  keep_dims = pfor_input.get_attr("keep_dims")
1997  return wrap(op_func(t, indices, keepdims=keep_dims), True)
1998
1999
2000@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
2001@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
2002def _convert_cumfoo(pfor_input, _, op_func):
2003  t = pfor_input.stacked_input(0)
2004  axis = pfor_input.unstacked_input(1)
2005  # Shift positive indices by one to account for the extra dimension.
2006  axis += math_ops.cast(axis >= 0, dtypes.int32)
2007  exclusive = pfor_input.get_attr("exclusive")
2008  reverse = pfor_input.get_attr("reverse")
2009  return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True)
2010
2011
2012@RegisterPFor("BiasAdd")
2013def _convert_biasadd(pfor_input):
2014  t, t_stacked, _ = pfor_input.input(0)
2015  bias, bias_stacked, _ = pfor_input.input(1)
2016  data_format = pfor_input.get_attr("data_format").decode()
2017  if bias_stacked:
2018    # BiasAdd only supports 1-D biases, so cast bias to match value and use Add.
2019    pfor_input.expanddim_inputs_for_broadcast()
2020    t, _, _ = pfor_input.input(0)
2021    bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype)
2022    if compat.as_bytes(data_format) == b"NCHW":
2023      b_shape = array_ops.shape(bias)
2024      new_b_shape = array_ops.concat(
2025          [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0)
2026      bias = array_ops.reshape(bias, new_b_shape)
2027    return wrap(math_ops.add(t, bias), True)
2028  else:
2029    assert t_stacked, "At least one input to BiasAdd should be loop variant."
2030    if compat.as_bytes(data_format) == b"NCHW":
2031      shape = array_ops.shape(t)
2032      flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
2033      t = array_ops.reshape(t, flattened_shape)
2034      t = nn_ops.bias_add(t, bias, data_format="NCHW")
2035      t = array_ops.reshape(t, shape)
2036      return wrap(t, True)
2037    return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
2038
2039
2040@RegisterPFor("UnsortedSegmentSum")
2041def _convert_unsortedsegmentsum(pfor_input):
2042  data, data_stacked, _ = pfor_input.input(0)
2043  # TODO(agarwal): handle unstacked?
2044  segment_ids = pfor_input.stacked_input(1)
2045  # TODO(agarwal): handle stacked?
2046  num_segments = pfor_input.unstacked_input(2)
2047  if not data_stacked:
2048    data = _stack(data, pfor_input.pfor.loop_len_vector).t
2049  segment_shape = array_ops.shape(segment_ids)
2050  n = segment_shape[0]
2051  ones = array_ops.ones_like(segment_shape)[1:]
2052  segment_offset = num_segments * math_ops.range(n)
2053  segment_offset = array_ops.reshape(segment_offset,
2054                                     array_ops.concat([[n], ones], axis=0))
2055  segment_ids += segment_offset
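  # For example (illustrative), with 3 iterations and num_segments=4, the ids
  # of iteration i now fall in [4 * i, 4 * i + 4), so the single segment sum
  # below computes all 12 segments at once before the result is reshaped back
  # to shape [3, 4, ...].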
2056  num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast(
2057      n, dtypes.int64)
2058  output = math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
2059  new_output_shape = array_ops.concat(
2060      [[n, -1], array_ops.shape(output)[1:]], axis=0)
2061  output = array_ops.reshape(output, new_output_shape)
2062  return wrap(output, True)
2063
2064
2065@RegisterPFor("Cast")
2066def _convert_cast(pfor_input):
2067  inp = pfor_input.stacked_input(0)
2068  dtype = pfor_input.get_attr("DstT")
2069  return wrap(math_ops.cast(inp, dtype), True)
2070
2071
2072@RegisterPForWithArgs("Abs", math_ops.abs)
2073@RegisterPForWithArgs("Acos", math_ops.acos)
2074@RegisterPForWithArgs("Acosh", math_ops.acosh)
2075@RegisterPForWithArgs("Add", math_ops.add)
2076@RegisterPForWithArgs("AddV2", math_ops.add_v2)
2077@RegisterPForWithArgs("Angle", math_ops.angle)
2078@RegisterPForWithArgs("Asin", math_ops.asin)
2079@RegisterPForWithArgs("Asinh", math_ops.asinh)
2080@RegisterPForWithArgs("Atan", math_ops.atan)
2081@RegisterPForWithArgs("Atan2", math_ops.atan2)
2082@RegisterPForWithArgs("Atanh", math_ops.atanh)
2083@RegisterPForWithArgs("BesselI0e", math_ops.bessel_i0e)
2084@RegisterPForWithArgs("BesselI1e", math_ops.bessel_i1e)
2085@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and)
2086@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or)
2087@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor)
2088@RegisterPForWithArgs("Ceil", math_ops.ceil)
2089@RegisterPForWithArgs("Complex", math_ops.complex)
2090@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs)
2091@RegisterPForWithArgs("Conj", math_ops.conj)
2092@RegisterPForWithArgs("Cos", math_ops.cos)
2093@RegisterPForWithArgs("Cosh", math_ops.cosh)
2094@RegisterPForWithArgs("Digamma", math_ops.digamma)
2095@RegisterPForWithArgs("Div", math_ops.div)
2096@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
2097@RegisterPForWithArgs("Elu", nn_ops.elu)
2098@RegisterPForWithArgs("Equal", math_ops.equal)
2099@RegisterPForWithArgs("Erf", math_ops.erf)
2100@RegisterPForWithArgs("Erfc", math_ops.erfc)
2101@RegisterPForWithArgs("Exp", math_ops.exp)
2102@RegisterPForWithArgs("Expm1", math_ops.expm1)
2103@RegisterPForWithArgs("Floor", math_ops.floor)
2104@RegisterPForWithArgs("FloorDiv", math_ops.floor_div)
2105@RegisterPForWithArgs("FloorMod", math_ops.floor_mod)
2106@RegisterPForWithArgs("Greater", math_ops.greater)
2107@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal)
2108@RegisterPForWithArgs("Igamma", math_ops.igamma)
2109@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a)
2110@RegisterPForWithArgs("Igammac", math_ops.igammac)
2111@RegisterPForWithArgs("Imag", math_ops.imag)
2112@RegisterPForWithArgs("Inv", math_ops.inv)
2113@RegisterPForWithArgs("Invert", bitwise_ops.invert)
2114@RegisterPForWithArgs("IsFinite", math_ops.is_finite)
2115@RegisterPForWithArgs("IsInf", math_ops.is_inf)
2116@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift)
2117@RegisterPForWithArgs("Less", math_ops.less)
2118@RegisterPForWithArgs("LessEqual", math_ops.less_equal)
2119@RegisterPForWithArgs("Lgamma", math_ops.lgamma)
2120@RegisterPForWithArgs("Log", math_ops.log)
2121@RegisterPForWithArgs("Log1p", math_ops.log1p)
2122@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and)
2123@RegisterPForWithArgs("LogicalNot", math_ops.logical_not)
2124@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
2125@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor)
2126@RegisterPForWithArgs("Maximum", math_ops.maximum)
2127@RegisterPForWithArgs("Minimum", math_ops.minimum)
2128@RegisterPForWithArgs("Mod", math_ops.mod)
2129@RegisterPForWithArgs("Mul", math_ops.multiply)
2130@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
2131@RegisterPForWithArgs("Neg", math_ops.negative)
2132@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
2133@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
2134@RegisterPForWithArgs("Pow", math_ops.pow)
2135@RegisterPForWithArgs("Real", math_ops.real)
2136@RegisterPForWithArgs("RealDiv", math_ops.divide)
2137@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
2138@RegisterPForWithArgs("Relu", nn_ops.relu)
2139@RegisterPForWithArgs("Relu6", nn_ops.relu6)
2140@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
2141@RegisterPForWithArgs("Rint", math_ops.rint)
2142@RegisterPForWithArgs("Round", math_ops.round)
2143@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
2144@RegisterPForWithArgs("Selu", nn_ops.selu)
2145@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
2146@RegisterPForWithArgs("Sign", math_ops.sign)
2147@RegisterPForWithArgs("Sin", math_ops.sin)
2148@RegisterPForWithArgs("Sinh", math_ops.sinh)
2149@RegisterPForWithArgs("Softplus", nn_ops.softplus)
2150@RegisterPForWithArgs("Softsign", nn_ops.softsign)
2151@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
2152@RegisterPForWithArgs("Square", math_ops.square)
2153@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
2154@RegisterPForWithArgs("Sub", math_ops.subtract)
2155@RegisterPForWithArgs("Tan", math_ops.tan)
2156@RegisterPForWithArgs("Tanh", math_ops.tanh)
2157@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div)
2158@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod)
2159@RegisterPForWithArgs("Xdivy", math_ops.xdivy)
2160@RegisterPForWithArgs("Xlogy", math_ops.xlogy)
2161@RegisterPForWithArgs("Zeta", math_ops.zeta)
2162def _convert_cwise(pfor_input, op_type, op_func):
  # Note that the ops handled here have no attributes other than "T" and
  # "Tout", and hence don't need any extra arguments passed to the op_func call
  # below.
2165  for attr in pfor_input.op.node_def.attr.keys():
2166    assert attr in [u"T", u"Tout"], (op_type, attr)
2167  pfor_input.expanddim_inputs_for_broadcast()
2168  return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
2169
2170
2171@RegisterPFor("ApproximateEqual")
2172def _convert_approximate_equal(pfor_input):
2173  pfor_input.expanddim_inputs_for_broadcast()
2174  x = pfor_input.input(0)[0]
2175  y = pfor_input.input(1)[0]
2176  tolerance = pfor_input.get_attr("tolerance")
2177  return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
2178
2179
2180@RegisterPFor("Shape")
2181def _convert_shape(pfor_input):
2182  out_type = pfor_input.get_attr("out_type")
2183  return wrap(
2184      array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:],
2185      False)
2186
2187
2188@RegisterPFor("ShapeN")
2189def _convert_shape_n(pfor_input):
2190  out_type = pfor_input.get_attr("out_type")
2191  shapes = [
2192      array_ops.shape(x, out_type=out_type)[1:]
2193      if stacked else array_ops.shape(x) for x, stacked, _ in pfor_input.inputs
2194  ]
2195  return [wrap(x, False) for x in shapes]
2196
2197
2198@RegisterPFor("Size")
2199def _convert_size(pfor_input):
2200  out_type = pfor_input.get_attr("out_type")
2201  n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type)
2202  return wrap(
2203      array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n,
2204      False)
2205
2206
2207@RegisterPFor("Rank")
2208def _convert_rank(pfor_input):
2209  return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False)
2210
2211
2212@RegisterPFor("AddN")
2213def _convert_addn(pfor_input):
2214  # AddN does not support broadcasting.
2215  pfor_input.stack_inputs()
2216  return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True)
2217
2218
2219@RegisterPFor("BiasAddGrad")
2220def _convert_biasaddgrad(pfor_input):
2221  grad = pfor_input.stacked_input(0)
2222  fmt = pfor_input.get_attr("data_format")
2223  if fmt == b"NCHW":
2224    output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
2225  else:
2226    grad_shape = array_ops.shape(grad)
2227    last_dim_shape = grad_shape[-1]
2228    first_dim_shape = grad_shape[0]
2229    output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
2230    output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
2231  return wrap(output, True)
2232
2233
2234# Some required ops are not exposed under the tf namespace. Hence relying on
2235# _create_op to create them.
2236@RegisterPForWithArgs("EluGrad")
2237@RegisterPForWithArgs("Relu6Grad")
2238@RegisterPForWithArgs("ReluGrad")
2239@RegisterPForWithArgs("SeluGrad")
2240@RegisterPForWithArgs("SigmoidGrad")
2241@RegisterPForWithArgs("SoftplusGrad")
2242@RegisterPForWithArgs("SoftsignGrad")
2243@RegisterPForWithArgs("TanhGrad")
2244@RegisterPForWithArgs("SqrtGrad")
2245@RegisterPForWithArgs("RsqrtGrad")
2246@RegisterPForWithArgs("ReciprocalGrad")
2247def _convert_grads(pfor_input, op_type, *args, **kw_args):
2248  del args
2249  del kw_args
2250  # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
2251  # have to use tiling here.
2252  pfor_input.stack_inputs()
2253  outputs = _create_op(
2254      op_type, [x.t for x in pfor_input.inputs],
2255      [x.dtype for x in pfor_input.outputs],
2256      attrs=pfor_input.op.node_def.attr).outputs
2257  return [wrap(x, True) for x in outputs]
2258
2259
2260@RegisterPFor("Select")
2261def _convert_select(pfor_input):
2262  pfor_input.stack_inputs()
2263  cond = pfor_input.stacked_input(0)
2264  t = pfor_input.stacked_input(1)
2265  e = pfor_input.stacked_input(2)
2266  cond_rank = array_ops.rank(cond)
2267  cond, t, e = control_flow_ops.cond(
2268      cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]),
2269      lambda: [cond, t, e])
2270  outputs = _create_op(
2271      pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs],
2272      attrs=pfor_input.op.node_def.attr).outputs
2273  n = pfor_input.pfor.loop_len_vector
2274  out = control_flow_ops.cond(cond_rank > 1,
2275                              lambda: _unflatten_first_dim(outputs[0], n),
2276                              lambda: outputs[0])
2277  return [wrap(out, True) for x in outputs]
2278
2279
2280# random_ops
2281
2282
2283@RegisterPForWithArgs("RandomUniform")
2284@RegisterPForWithArgs("RandomUniformInt")
2285@RegisterPForWithArgs("RandomStandardNormal")
2286@RegisterPForWithArgs("TruncatedNormal")
2287@RegisterPForWithArgs("RandomGamma")
2288@RegisterPForWithArgs("RandomPoissonV2")
2289def _convert_random(pfor_input, op_type, *args, **kw_args):
2290  del args
2291  del kw_args
2292  inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)]
2293  # inputs[0] is "shape"
2294  inputs[0] = array_ops.concat(
2295      [pfor_input.pfor.loop_len_vector, inputs[0]], axis=0)
2296  logging.warning(
2297      "Note that %s inside pfor op may not give same output as "
2298      "inside a sequential loop.", op_type)
2299  outputs = _create_op(
2300      op_type,
2301      inputs, [x.dtype for x in pfor_input.outputs],
2302      attrs=pfor_input.op.node_def.attr).outputs
2303  return [wrap(x, True) for x in outputs]
2304
2305
2306# logging_ops
2307
2308
2309@RegisterPFor("Assert")
2310def _convert_assert(pfor_input):
2311  cond, cond_stacked, _ = pfor_input.input(0)
2312  if cond_stacked:
2313    cond = math_ops.reduce_all(cond)
2314
2315  data_list = [x.t for x in pfor_input.inputs][1:]
2316  return _create_op("Assert", [cond] + data_list, [],
2317                    attrs=pfor_input.op.node_def.attr)
2318
2319
2320@RegisterPFor("Print")
2321def _convert_print(pfor_input):
2322  # Note that we don't stack all the inputs. Hence unstacked values are printed
2323  # once here vs multiple times in a while_loop.
2324  pfor_input.stack_inputs([0])
2325  outputs = _create_op(
2326      "Print", [x.t for x in pfor_input.inputs],
2327      [x.dtype for x in pfor_input.outputs],
2328      attrs=pfor_input.op.node_def.attr).outputs
2329  return [wrap(x, True) for x in outputs]
2330
2331
2332# data_flow_ops
2333
2334# TensorArray conversion is tricky since we don't support arrays of
2335# TensorArrays. For converting them, we consider two distinct cases:
2336#
2337# 1. The array is constructed outside the pfor call, and read/written inside the
2338# loop.
2339# This is an easier case since we don't need to make an array of TensorArrays.
2340# A correctness requirement is that these parallel iterations shouldn't attempt
2341# to write to the same location. Hence at conversion time we disallow indices to
# be loop-invariant as that would guarantee a collision. Even if the indices are
# not loop-invariant, they could still collide, and that will trigger runtime
# errors.
2344#
2345# 2. The array is constructed and used entirely inside each pfor iteration.
2346# For simplicity, here we require that the indices used for write/scatter are
2347# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
# different pfor iterations. We consider two sub-cases:
2349#
2350# 2a Elements written to the array are "stacked"
2351# To simulate multiple TensorArrays, we may increase the dimension of each
2352# element of the array. i.e. the i_th row of the j_th entry of the converted
2353# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
2354# pfor iteration.
2355#
2356# 2b Elements written to the array are "unstacked"
2357# In this case we don't increase the dimensions to avoid redundant tiling. Each
2358# iteration is trying to write the same value. So we convert that to a single
2359# write.
2360#
2361# Here are some tricks used to implement the above:
2362# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of
2363# trying to trace whether future writes are stacked or unstacked in order to set
2364# this attr, we set it to correspond to unknown shape.
2365# - We use the "flow" output of the different ops to track whether the array
2366# elements are stacked or unstacked. If a stacked write/scatter is done, we make
2367# the flow stacked as well.
2368# - We use some heuristic traversal of the graph to track whether the
2369# TensorArray handle was created inside or outside the pfor loop.
2370
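# As a rough illustration of the two cases above: a TensorArray created
# outside the pfor loop (case 1) keeps unstacked elements and requires write
# indices that vary across iterations, while a TensorArray created inside the
# loop body (case 2) requires loop-invariant write indices and may hold either
# stacked (2a) or unstacked (2b) elements.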
2371
2372@RegisterPFor("TensorArrayV3")
2373def _convert_tensor_array_v3(pfor_input):
2374  size = pfor_input.unstacked_input(0)
2375  dtype = pfor_input.get_attr("dtype")
2376  dynamic_size = pfor_input.get_attr("dynamic_size")
2377  clear_after_read = pfor_input.get_attr("clear_after_read")
2378  identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
2379  tensor_array_name = pfor_input.get_attr("tensor_array_name")
2380  handle, flow = data_flow_ops.tensor_array_v3(
2381      size,
2382      dtype=dtype,
2383      # We don't set element shape since we don't know if writes are stacked or
2384      # not yet.
2385      element_shape=None,
2386      dynamic_size=dynamic_size,
2387      clear_after_read=clear_after_read,
2388      identical_element_shapes=identical_element_shapes,
2389      tensor_array_name=tensor_array_name)
2390  # Note we keep flow unstacked for now since we don't know if writes will be
2391  # stacked or not.
2392  return wrap(handle, False), wrap(flow, False)
2393
2394
2395@RegisterPFor("TensorArraySizeV3")
2396def _convert_tensor_array_size_v3(pfor_input):
2397  handle = pfor_input.unstacked_input(0)
2398  flow, flow_stacked, _ = pfor_input.input(1)
2399  if flow_stacked:
2400    flow = _unstack_flow(flow)
2401  size = data_flow_ops.tensor_array_size_v3(handle, flow)
2402  return wrap(size, False)
2403
2404
2405def _handle_inside_pfor(pfor_input, handle):
2406  """Returns True if handle was created inside the pfor loop."""
2407  # We use some heuristic to find the original TensorArray creation op.
2408  # The logic should handle the common cases (except cond based subgraphs).
  # In theory the user could perform different operations on the handle (like
  # Reshape, stacking multiple handles, etc.) which could break this logic.
2411  # TODO(agarwal): handle Switch/Merge.
2412  while handle.op.type in ("Enter", "Identity"):
2413    handle = handle.op.inputs[0]
2414  if handle.op.type not in [
2415      "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"]:
2416    raise ValueError("Unable to find source for handle %s" % handle)
2417  else:
2418    return pfor_input.pfor.op_is_inside_loop(handle.op)
2419
2420
2421def _unstack_flow(value):
2422  # TODO(agarwal): consider looking if this is a Tile op then get its input.
2423  # This may avoid running the Tile operations.
2424  return array_ops.gather(value, 0)
2425
2426
2427@RegisterPFor("TensorArrayReadV3")
2428def _convert_tensor_array_read_v3(pfor_input):
2429  handle = pfor_input.unstacked_input(0)
2430  index, index_stacked, _ = pfor_input.input(1)
2431  dtype = pfor_input.get_attr("dtype")
2432  flow, flow_stacked, _ = pfor_input.input(2)
2433  if flow_stacked:
2434    flow = _unstack_flow(flow)
2435
2436  is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
2437  if is_inside_pfor:
2438    # Note that if we are inside a control flow construct inside the pfor, and
2439    # only some of the iterations are doing the read (i.e.
2440    # `all_indices_partitioned` is True), then the read operation should only
2441    # return values for the currently active pfor iterations (`all_indices`
2442    # below). Hence, whenever the returned value is stacked (i.e. `flow` is
2443    # stacked), we may need to do an extra gather after reading the values. Also
    # note that if `is_inside_pfor` is False, then values in the TensorArray
    # are unstacked. So the check is only needed in this branch.
2446    all_indices = pfor_input.pfor.all_indices
2447    all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
2448    # Note: flow_stacked indicates if values in the TensorArray are stacked or
2449    # not.
2450    if index_stacked:
2451      if flow_stacked:
2452        raise ValueError(
2453            "It looks like TensorArrayReadV3 was called on a TensorArray whose"
2454            " values are not loop-invariant, and the read indices were also"
2455            " not loop invariant. This is currently unsupported.")
2456      value = data_flow_ops.tensor_array_gather_v3(
2457          handle, index, flow, dtype=dtype)
2458      return wrap(value, True)
2459    value = data_flow_ops.tensor_array_read_v3(
2460        handle, index, flow, dtype=dtype)
2461    if flow_stacked and all_indices_partitioned:
2462      value = array_ops.gather(value, all_indices)
2463    return wrap(value, flow_stacked)
2464  # Values in the TensorArray should be unstacked (since different iterations
2465  # couldn't write to the same location). So whether output is stacked or not
2466  # depends on index_stacked.
2467  if index_stacked:
2468    value = data_flow_ops.tensor_array_gather_v3(
2469        handle, index, flow, dtype=dtype)
2470  else:
2471    value = data_flow_ops.tensor_array_read_v3(
2472        handle, index, flow, dtype=dtype)
2473  return wrap(value, index_stacked)
2474
2475
2476@RegisterPFor("TensorArrayWriteV3")
2477def _convert_tensor_array_write_v3(pfor_input):
2478  handle = pfor_input.unstacked_input(0)
2479  index, index_stacked, _ = pfor_input.input(1)
2480  value, value_stacked, _ = pfor_input.input(2)
2481  flow, flow_stacked, _ = pfor_input.input(3)
2482  if value_stacked and pfor_input.pfor.all_indices_partitioned:
    # It looks like we are inside a control flow construct in a pfor where not
    # all iterations are currently active. We don't allow that since it could
    # lead to different indices having different shapes, which would be hard to
    # merge later.
    raise ValueError("Writing non loop invariant values to TensorArray from "
                     "inside a while_loop/cond is not supported.")
2488  if flow_stacked:
2489    flow = _unstack_flow(flow)
2490  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
2491  if is_inside:
2492    if index_stacked:
2493      raise ValueError("Need indices for %s to be loop invariant" % handle)
2494    if not flow_stacked and not value_stacked:
2495      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
2496      return wrap(flow_out, False)
2497    else:
2498      if not value_stacked:
2499        value = _stack(value, pfor_input.pfor.loop_len_vector).t
2500      # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
2501      # this may or may not be a safe situation. flow is unstacked both for a
2502      # freshly created TensorArray, as well as after unstacked values are
2503      # written to it. If it is the latter, then we cannot write a stacked value
2504      # now since that may cause runtime errors due to different shapes in the
2505      # array. At the moment we are not able to handle this gracefully and
2506      # distinguish between the two cases. That would require some heuristic
2507      # traversal of the graph to figure out whether all the writes are
2508      # unstacked or not.
2509      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
2510      return _stack(flow_out, pfor_input.pfor.loop_len_vector)
2511  else:
2512    if not index_stacked:
      raise ValueError("Need indices for %s to not be loop invariant" % handle)
    # Note that even when index_stacked is true, the actual values in index
    # may still not be unique. However, that will cause a runtime error when
    # executing the scatter operation below.
2517    if not value_stacked:
2518      value = _stack(value, pfor_input.pfor.loop_len_vector).t
2519    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
2520    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
2521
2522
2523def _transpose_first_two_dims(value):
2524  # TODO(agarwal): optimize if one of the dims == 1.
2525  value_shape = array_ops.shape(value)
2526  v0 = value_shape[0]
2527  v1 = value_shape[1]
2528  value = array_ops.reshape(value, [v0, v1, -1])
2529  value = array_ops.transpose(value, [1, 0, 2])
2530  new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
2531  return array_ops.reshape(value, new_shape)
2532
2533
2534@RegisterPFor("TensorArrayGatherV3")
2535def _convert_tensor_array_gather_v3(pfor_input):
2536  handle = pfor_input.unstacked_input(0)
2537  indices, indices_stacked, _ = pfor_input.input(1)
2538  indices = array_ops.reshape(indices, [-1])
2539  flow, flow_stacked, _ = pfor_input.input(2)
2540  if flow_stacked:
2541    flow = _unstack_flow(flow)
2542  dtype = pfor_input.get_attr("dtype")
2543  # TODO(agarwal): support element_shape attr?
2544
2545  n = pfor_input.pfor.loop_len_vector
2546  value = data_flow_ops.tensor_array_gather_v3(
2547      handle, indices, flow, dtype=dtype)
2548  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
2549  if is_inside:
2550    # flow_stacked indicates if values in the TensorArray are stacked or not.
2551    if indices_stacked:
2552      if flow_stacked:
2553        raise ValueError(
2554            "It looks like TensorArrayGatherV3 was called on a TensorArray "
2555            "whose values are not loop-invariant, and the indices were also "
2556            "not loop invariant. This is currently unsupported.")
2557      else:
2558        value = _unflatten_first_dim(value, n)
2559        return wrap(value, True)
2560    else:
2561      if flow_stacked:
2562        # Since elements in this array are stacked and `value` was produced by
2563        # gather, its first two dims are "gathered elements" and "stack
2564        # dimension". Our semantics require these two to be flipped.
2565        value = _transpose_first_two_dims(value)
2566      return wrap(value, flow_stacked)
2567  else:
2568    # Values in the TensorArray should be unstacked (since different iterations
2569    # couldn't write to the same location). So whether output is stacked or not
2570    # depends on indices_stacked.
2571    if indices_stacked:
2572      value = _unflatten_first_dim(value, n)
2573    return wrap(value, indices_stacked)
2574
2575
2576@RegisterPFor("TensorArrayScatterV3")
2577def _convert_tensor_array_scatter_v3(pfor_input):
2578  handle = pfor_input.unstacked_input(0)
2579  indices, indices_stacked, _ = pfor_input.input(1)
2580  indices = array_ops.reshape(indices, [-1])
2581  value, value_stacked, _ = pfor_input.input(2)
2582  flow, flow_stacked, _ = pfor_input.input(3)
2583
2584  if flow_stacked:
2585    flow = _unstack_flow(flow)
2586
2587  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
2588  if is_inside:
2589    if indices_stacked:
2590      raise ValueError("Need indices for %s to be loop invariant" % handle)
2591    # Note that flow_stacked indicates if existing values in the array are
2592    # stacked or not.
2593    if not flow_stacked and not value_stacked:
2594      flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
2595                                                       flow)
2596      return wrap(flow_out, False)
2597    if not value_stacked:
2598      # TODO(agarwal): tile in the second dimension directly instead of
2599      # transposing below.
2600      value = _stack(value, pfor_input.pfor.loop_len_vector).t
2601
2602    value = _transpose_first_two_dims(value)
    # TODO(agarwal): Note that if a previous write was unstacked, flow will be
    # unstacked, and a stacked value may be written here, which may cause a
    # runtime error due to different elements having different shapes. We do
    # not try to prevent that.
2607    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
2608                                                     flow)
2609    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
2610  if not indices_stacked:
2611    raise ValueError("Need indices for %s to be not loop invariant" % handle)
2612  if not value_stacked:
2613    value = _stack(value, pfor_input.pfor.loop_len_vector).t
2614  value = _flatten_first_two_dims(value)
2615  flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
2616                                                   flow)
2617  return _stack(flow_out, pfor_input.pfor.loop_len_vector)
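# A minimal sketch (as comments; assumes the `pfor` helper from this package's
# control_flow_ops module, and illustrative names and shapes) of the in-loop
# case handled above: the TensorArray is created inside the loop, so the
# scatter indices must be loop invariant while the scattered values may vary
# across iterations.
#
#   def loop_fn(i):
#     ta = tensor_array_ops.TensorArray(dtypes.float32, size=2)
#     ta = ta.scatter(
#         [0, 1], math_ops.cast(array_ops.stack([i, i + 1]), dtypes.float32))
#     return ta.read(0)
#
#   # Expected to vectorize to a [3] float32 tensor under these assumptions.
#   out = pfor(loop_fn, 3)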


@RegisterPFor("TensorArrayGradV3")
def _convert_tensor_array_grad_v3(pfor_input):
  handle = pfor_input.unstacked_input(0)
  flow, flow_stacked, _ = pfor_input.input(1)
  if flow_stacked:
    flow = _unstack_flow(flow)
  source = pfor_input.get_attr("source")
  # TODO(agarwal): For now, we assume that gradients are stacked if the
  # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
  # will give a runtime error, due to an incorrect shape being written to the
  # accumulator. It is difficult to know in advance whether the gradients
  # written will be stacked or not. Note that flow being stacked is not
  # indicative of the gradient being stacked or not. Revisit this later.
  shape_to_prepend = pfor_input.pfor.loop_len_vector
  grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
      handle=handle,
      flow_in=flow,
      shape_to_prepend=shape_to_prepend,
      source=source)
  flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
  return [wrap(grad_handle, False), wrap(flow_out, True)]
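# TensorArrayGradV3 is normally created by the gradient machinery rather than
# by user code. A hedged sketch (as comments; assumes the `pfor` helper from
# this package's control_flow_ops module and `tf.gradients`; names and shapes
# are illustrative) of how it can appear inside a pfor conversion:
#
#   x = array_ops.ones([3, 4])
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=3).unstack(x)
#
#   def loop_fn(i):
#     y = math_ops.reduce_sum(math_ops.square(ta.read(i)))
#     return tf.gradients(y, x)
#
#   out = pfor(loop_fn, 3)
#
# Differentiating through the read creates a gradient accumulator via
# TensorArrayGradV3 in the loop body; the converter above widens it with
# `shape_to_prepend` so that each iteration writes its (assumed stacked)
# gradient into its own leading slice.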


# StackV2 conversion is tricky since we don't have arrays of StackV2. So,
# similar to TensorArrays, we convert them by changing the dimension of the
# elements inside the stack.
#
# We consider two cases:
#
# 1. StackV2 is constructed and used entirely inside the pfor loop.
# We keep a single Stack and perform the push/pop operations of all the
# iterations in lock-step. We also assume that all the iterations perform these
# operations. In case of dynamic control flow, if only some of the iterations
# try to perform a push/pop, then the conversion may not work correctly and may
# cause undefined behavior.
# TODO(agarwal): test StackV2 with dynamic control flow.
#
# 2. StackV2 is constructed outside the pfor loop.
# Performing stack push/pop in a parallel fashion is ill-defined. However,
# given that reading stacks created externally is a common operation when
# computing jacobians, we provide the following special semantics:
#  - Push operations to the stack are disallowed.
#  - Pop operations are performed in lock-step by all iterations, similar to
#  the case when the stack is created inside. A single value is popped during
#  the lock-step operation and broadcast to all the iterations. Values in the
#  stack are assumed to be loop-invariant.
# A sketch of this case is given after this comment block.
#
# Some other implementation details:
# We use some ugly logic to determine whether the values in the Stack data
# structure are loop invariant or not. When converting push/pop operations, we
# keep track of whether the last conversion used a stacked value or not (see
# _stack_cache below). As a result, if an unstacked value is written first,
# subsequent stacked writes are disallowed even when they could have been
# allowed in theory.
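
# A hedged illustration of case 2 above (as comments; `cond`, `body`, `x` and
# `loop_len` are placeholders, and StackV2 ops are normally created by
# while_loop gradient code rather than directly by users): when computing a
# jacobian of a while_loop output, the forward pass pushes intermediate values
# onto a stack outside pfor, and each pfor iteration pops them in lock-step
# during backprop, roughly:
#
#   y = control_flow_ops.while_loop(cond, body, [x])  # pushes to StackV2
#
#   def loop_fn(i):
#     return tf.gradients(y[i], x)                    # pops from StackV2
#
#   jacobian_rows = pfor(loop_fn, loop_len)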

# Map from cache key based on StackV2 handle to a bool indicating whether
# values are stacked or not.
# TODO(agarwal): move _stack_cache inside pfor?
_stack_cache = {}


def _stack_cache_key(pfor_input):
  """Create cache key corresponding to a stack handle."""
  op_type = pfor_input.op_type
  assert op_type in ["StackPushV2", "StackPopV2"], op_type
  orig_handle = pfor_input.op.inputs[0]
  while orig_handle.op.type in ["Identity", "Enter"]:
    orig_handle = orig_handle.op.inputs[0]
  assert orig_handle.op.type == "StackV2", orig_handle.op
  return ops.get_default_graph(), pfor_input.pfor, orig_handle


def _stack_handle_inside_pfor(handle, pfor_input):
  while handle.op.type in ["Identity", "Enter"]:
    handle = handle.op.inputs[0]
  assert handle.op.type == "StackV2", (
      "Unable to find StackV2 op. Got %s" % handle.op)
  return pfor_input.pfor.op_is_inside_loop(handle.op)


@RegisterPFor("StackPushV2")
def _convert_stack_push_v2(pfor_input):
  handle = pfor_input.unstacked_input(0)
  elem, elem_stacked, _ = pfor_input.input(1)
  swap_memory = pfor_input.get_attr("swap_memory")

  if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
    raise ValueError("StackPushV2 not allowed on stacks created outside pfor")
  stack_cache_key = _stack_cache_key(pfor_input)
  stacked = _stack_cache.get(stack_cache_key, None)
  if stacked is None:
    stacked = elem_stacked
    _stack_cache[stack_cache_key] = stacked
  else:
    # If we previously made it unstacked then we can't revert to being stacked.
    if not stacked and elem_stacked:
      raise ValueError(
          "It looks like the stack was previously determined to be loop"
          " invariant, but we are now trying to push a loop dependent value"
          " to it. This is currently unsupported.")
    if stacked and not elem_stacked:
      elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
  out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
  return wrap(out, stacked)
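# A hedged sketch of the invariant enforced above (placeholders, not real
# code): for a stack created inside the loop, the first converted push fixes
# whether the stack holds stacked values. Given a push sequence like
#
#   stack_push_v2(handle, c)     # c loop invariant => cache records unstacked
#   stack_push_v2(handle, f(i))  # f(i) loop dependent => stacked push
#
# the second push hits the ValueError above, even though treating the stack as
# stacked from the start would have been valid in theory.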


# Note that the inputs to this converter will be unstacked. However, it should
# still get called since it is a stateful op.
@RegisterPFor("StackPopV2")
def _convert_stack_pop_v2(pfor_input):
  handle = pfor_input.unstacked_input(0)
  stack_cache_key = _stack_cache_key(pfor_input)
  stacked = _stack_cache.get(stack_cache_key, None)
  # If a StackPushV2 has not been converted yet, we default to unstacked, since
  # the push could be outside of pfor, or the converter may not be called if
  # the inputs are unconverted.
  if stacked is None:
    stacked = False
    _stack_cache[stack_cache_key] = False
  elem_type = pfor_input.get_attr("elem_type")
  out = data_flow_ops.stack_pop_v2(handle, elem_type)
  return wrap(out, stacked)
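# A hedged sketch of the default above (placeholders, not real code): if the
# matching StackPushV2 ran outside pfor, so nothing was recorded in
# _stack_cache, the pop is treated as unstacked and all iterations observe the
# same popped value:
#
#   value = stack_pop_v2(handle, elem_type)  # converted once, wrapped unstacked
#
# which matches the broadcast semantics described for externally created
# stacks in the comment block above.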


# parsing_ops


@RegisterPFor("DecodeCSV")
def _convert_decode_csv(pfor_input):
  lines = pfor_input.stacked_input(0)
  record_defaults = [
      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
  ]
  field_delim = pfor_input.get_attr("field_delim")
  use_quote_delim = pfor_input.get_attr("use_quote_delim")
  select_cols = pfor_input.get_attr("select_cols")
  if not select_cols:
    select_cols = None
  return [
      wrap(t, True) for t in parsing_ops.decode_csv(
          lines,
          record_defaults,
          field_delim=field_delim,
          use_quote_delim=use_quote_delim,
          select_cols=select_cols)
  ]
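# A minimal sketch (as comments; assumes the `pfor` helper from this package's
# control_flow_ops module, and illustrative names) of vectorizing DecodeCSV:
# the records are stacked per iteration, while record_defaults must be loop
# invariant.
#
#   lines = constant_op.constant(["1,2", "3,4", "5,6"])
#
#   def loop_fn(i):
#     return parsing_ops.decode_csv(lines[i], record_defaults=[[0], [0]])
#
#   # Expected to vectorize to two int32 tensors of shape [3] each under
#   # these assumptions.
#   out = pfor(loop_fn, 3)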


@RegisterPFor("ParseSingleExample")
def _convert_parse_single_example(pfor_input):
  serialized = pfor_input.stacked_input(0)
  dense_defaults = [
      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
  ]
  sparse_keys = pfor_input.get_attr("sparse_keys")
  dense_keys = pfor_input.get_attr("dense_keys")
  sparse_types = pfor_input.get_attr("sparse_types")
  dense_shapes = pfor_input.get_attr("dense_shapes")
  output = gen_parsing_ops.parse_example(
      serialized=serialized,
      names=[],
      dense_defaults=dense_defaults,
      sparse_keys=sparse_keys,
      dense_keys=dense_keys,
      sparse_types=sparse_types,
      dense_shapes=dense_shapes)
  return [wrap(t, True, True) for t in nest.flatten(output)]
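# A minimal sketch (as comments; assumes the `pfor` helper from this package's
# control_flow_ops module, and illustrative names) of vectorizing
# ParseSingleExample: each iteration parses one serialized tf.Example, and the
# converter above rewrites the batch of iterations into a single ParseExample
# call.
#
#   serialized = ...  # a [num_examples] vector of serialized tf.Example protos
#
#   def loop_fn(i):
#     return parsing_ops.parse_single_example(
#         serialized[i],
#         {"x": parsing_ops.FixedLenFeature([], dtypes.int64)})
#
#   out = pfor(loop_fn, num_examples)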