1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Compiled parallel-for loop."""
16# pylint: disable=missing-docstring,g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import string
24import sys
25import traceback
26
27import numpy as np
28import six
29
30from tensorflow.compiler.tf2xla.python import xla
31from tensorflow.core.framework import types_pb2
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.eager import execute
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import func_graph
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import sparse_tensor
40from tensorflow.python.framework import tensor_shape
41from tensorflow.python.framework import tensor_spec
42from tensorflow.python.framework import tensor_util
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import bitwise_ops
45from tensorflow.python.ops import control_flow_ops
46from tensorflow.python.ops import custom_gradient
47from tensorflow.python.ops import data_flow_ops
48from tensorflow.python.ops import gen_array_ops
49from tensorflow.python.ops import gen_dataset_ops
50from tensorflow.python.ops import gen_image_ops
51from tensorflow.python.ops import gen_linalg_ops
52from tensorflow.python.ops import gen_list_ops
53from tensorflow.python.ops import gen_math_ops
54from tensorflow.python.ops import gen_nn_ops
55from tensorflow.python.ops import gen_parsing_ops
56from tensorflow.python.ops import gen_random_ops
57from tensorflow.python.ops import gen_sparse_ops
58from tensorflow.python.ops import gen_spectral_ops
59from tensorflow.python.ops import linalg_ops
60from tensorflow.python.ops import list_ops
61from tensorflow.python.ops import map_fn
62from tensorflow.python.ops import math_ops
63from tensorflow.python.ops import nn_ops
64from tensorflow.python.ops import parsing_ops
65from tensorflow.python.ops import resource_variable_ops
66from tensorflow.python.ops import sparse_ops
67from tensorflow.python.ops import special_math_ops
68from tensorflow.python.ops import tensor_array_ops
69from tensorflow.python.platform import flags
70from tensorflow.python.platform import tf_logging as logging
71from tensorflow.python.util import compat
72from tensorflow.python.util import nest
73from tensorflow.python.util import object_identity
74
75
76# TODO(agarwal): remove flag.
77flags.DEFINE_bool(
78    "op_conversion_fallback_to_while_loop", True,
79    "DEPRECATED: Flag is ignored.")
80
81
82def _variant_handle_data(t):
83  """Fetches handle data for a variant tensor `t`, or None if unavailable."""
84  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
85  if not handle_data.is_set:
86    return None
87  return handle_data.shape_and_type
88
89
90def _is_variant_with_internal_stacking(t):
91  """Identifies variant tensors which pfor always maintains as scalars.
92
93  For these, the pfor tensor is recorded as "stacked" if the content of the
94  variant tensor (e.g. the elements of a TensorList) are all stacked.
95
96  Args:
97    t: A tensor to identify.
98  Returns:
99    True if `t` is a TensorList/Optional, False not, None if unknown.
100  """
101  if t.dtype != dtypes.variant:
102    return False
103  shapes_and_types = _variant_handle_data(t)
104  if shapes_and_types is None or not shapes_and_types:
105    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
106    # make this an error instead of assuming TensorLists have handle data.
107    return None  # Presumed not a TensorList/Optional
108  return (shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST or
109          shapes_and_types[0].specialized_type == types_pb2.ST_OPTIONAL)
110
111
112def _parse_variant_shapes_and_types(t):
113  """Extracts shape and dtype information from a variant tensor `t`."""
114  shapes_and_types = _variant_handle_data(t)
115  if shapes_and_types is None or not shapes_and_types:
116    raise ValueError("Required handle data not set for {!r}".format(t))
117  if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
118    return shapes_and_types
119  else:
120    if shapes_and_types[0].specialized_type != types_pb2.ST_INVALID:
121      return shapes_and_types
122    else:
123      raise ValueError(
124          "Attempted to stack a variant-dtype tensor with no type set ({!r})"
125          .format(t))
126
127
128def _stack(t, length):
129  """stacks `t` `length` times."""
130  # Note that this stacking may currently be triggered, for example, when a
131  # loop invariant tensor with dtype variant is input to a while_loop which then
132  # produces a loop dependent output. Simply stacking the variants may not be
133  # suitable since operations on stacked handles may expect a vectorized version
134  # of the variant.
135  if t.dtype == dtypes.variant:
136    shapes_and_types = _parse_variant_shapes_and_types(t)
137    if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
138      if len(shapes_and_types) != 1:
139        raise ValueError(
140            "Expected handle data of length 1, got {!r} of length {}"
141            .format(shapes_and_types, len(shapes_and_types)))
142      return wrap(
143          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
144          True)
145    else:
146      raise ValueError(
147          ("Attempted to stack an unhandled variant-dtype tensor of "
148           "type {!r} ({!r})").format(shapes_and_types[0].specialized_type, t))
149  ones = array_ops.ones_like(array_ops.shape(t))
150  ones = array_ops.reshape(ones, [-1])
151  length = array_ops.reshape(length, [-1])
152  multiples = array_ops.concat([length, ones], 0)
153  t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
154  return wrap(t, True)
155
156
157# The following stateful ops can be safely called once, and with the same
158# signature as the unconverted version, if their inputs are loop invariant.
159# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
160# plan is to map each read/write in the loop_fn to a corresponding merged
161# read/write in the converted graph. Writes need to be mergeable (e.g.
162# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
163# loop_fn, doing a one-to-one conversion will simulate executing such
164# instructions in lock-step across all iterations.
165passthrough_stateful_ops = set([
166    "VariableV2",
167    "VarHandleOp",
168    "VariableShape",
169    "ReadVariableOp",
170    "StackV2",
171    "TensorArrayWriteV3",
172    "TensorArrayReadV3",
173    "TensorArraySizeV3",
174])
175
176
177# Ops which we will treat like stateful for the purpose of vectorization.
178# Typically this is used to force pfor converters to run for these ops.
179force_stateful_ops = set([
180    # We vectorize this since we need to change the element shape set on the
181    # list.
182    "TensorListReserve",
183])
184
185
186def _is_stateful_pfor_op(op):
187  if isinstance(op, WhileOp):
188    return op.is_stateful
189  if op.type == "Const":
190    # Const didn't have an op_def.
191    return False
192  if op.type in passthrough_stateful_ops:
193    return False
194  if op.type in force_stateful_ops:
195    return True
196  assert hasattr(op, "op_def") and op.op_def is not None, op
197  return op.op_def.is_stateful
198
199
200# pylint: disable=protected-access
201class WhileOp(object):
202  """Object for storing state for converting the outputs of a while_loop."""
203
204  def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config):
205    """Initializer.
206
207    Args:
208      exit_node: A tensor output from the while_loop.
209      pfor_ops: list of ops inside the current pfor loop.
210      fallback_to_while_loop: If True, fallback to while loop when conversion of
211        an op is not supported
212      pfor_config: PForConfig object used while constructing loop body.
213    """
214    self._fallback_to_while_loop = fallback_to_while_loop
215    self._pfor_config = pfor_config
216    self._pfor_ops = set(pfor_ops)
217    self._pfor_op_ids = set(x._id for x in pfor_ops)
218    assert isinstance(exit_node, ops.Tensor)
219    self._while_context = exit_node.op._get_control_flow_context()
220    assert isinstance(self._while_context, control_flow_ops.WhileContext)
221    self._context_name = self._while_context.name
222    self._condition = self._while_context.pivot.op.inputs[0]
223    # Parts of an external while_loop could be created inside a pfor loop.
224    # However for the purpose here, we declare such loops to be external. Also
225    # note that we check if the condition was created inside or outside to
226    # determine if the while_loop was first created inside or outside.
227    # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
228    self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
229    if self._is_inside_loop:
230      for e in self._while_context.loop_exits:
231        assert self.op_is_inside_loop(e.op)
232
233    # Note the code below tries to reverse engineer an existing while_loop graph
234    # by assuming the following pattern of nodes.
235    #
236    #          NextIteration <---- Body <--- Enter
237    #              |                ^
238    #              V             ___| Y
239    #    Enter -> Merge -> Switch___
240    #                       ^       | N
241    #                       |       V
242    #                  LoopCond    Exit
243
244    # Node that elements in the list below correspond one-to-one with each
245    # other. i.e. these lists are the same size, and the i_th entry corresponds
246    # to different Operations/Tensors of a single cycle as illustrated above.
247    # List of Switch ops (ops.Operation) that feed into an Exit Node.
248    self._exit_switches = []
249    # List of inputs (ops.Tensor) to NextIteration.
250    self._body_outputs = []
251    # List of list of control inputs of the NextIteration nodes.
252    self._next_iter_control_inputs = []
253    # List of Merge ops (ops.Operation).
254    self._enter_merges = []
255    # List of output (ops.Tensor) of Exit nodes.
256    self._outputs = []
257
258    # List of Enter Tensors.
259    # There are two types of Enter nodes:
260    # - The Enter nodes that are used in the `loop_vars` argument to
261    # `while_loop` (see
262    # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
263    # these Enter nodes immediately below by tracing backwards from the Exit
264    # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
265    # diagram above. This allows us to have a 1:1 correspondence between the
266    # self._outputs and the first elements in self._enters.
267    # - The Enter nodes that are used only by the body. They don't appear in the
268    # `loop_vars` and are not returned from the `while_loop`. In Python code,
269    # they are usually captured by the body lambda. We collect them below by
270    # iterating over all the ops in the graph. They are appended to the end of
271    # self._enters or self._direct_enters, and don't correspond to any outputs
272    # in self._outputs. Note that we keep the resource/variant Enter nodes in
273    # self._direct_enters and the constructed while_loop's body uses them
274    # directly as opposed to passing them as loop variables. This is done
275    # because the while_body cannot partition the resource/variant Tensors, so
276    # it has to leave them unchanged.
277    self._enters = []
278    self._direct_enters = []
279
280    for e in self._while_context.loop_exits:
281      self._outputs.append(e.op.outputs[0])
282      switch = e.op.inputs[0].op
283      assert switch.type == "Switch", switch
284      self._exit_switches.append(switch)
285      merge = switch.inputs[0].op
286      assert merge.type == "Merge", merge
287      self._enter_merges.append(merge)
288      enter = merge.inputs[0].op
289      assert enter.type == "Enter", enter
290      self._enters.append(enter.outputs[0])
291      next_iter = merge.inputs[1].op
292      assert next_iter.type == "NextIteration", next_iter
293      self._body_outputs.append(next_iter.inputs[0])
294      self._next_iter_control_inputs.append(next_iter.control_inputs)
295
296    # Collect all the Enter nodes that are not part of `loop_vars`, the second
297    # category described above.
298    # Also track whether the loop body has any stateful ops.
299    self._is_stateful = False
300    for op in ops.get_default_graph().get_operations():
301      # TODO(agarwal): make sure this works with nested case.
302      control_flow_context = op._get_control_flow_context()
303      if control_flow_context is None:
304        continue
305      if control_flow_context.name == self._context_name:
306        self._is_stateful |= _is_stateful_pfor_op(op)
307        if op.type == "Enter":
308          output = op.outputs[0]
309          if output not in self._enters:
310            if output.dtype in (dtypes.resource, dtypes.variant):
311              if output not in self._direct_enters:
312                self._direct_enters.append(output)
313            else:
314              self._enters.append(output)
315
316  def __str__(self):
317    """String representation."""
318    return "while_loop(%s)" % self.name
319
320  @property
321  def inputs(self):
322    """Input to all the Enter nodes."""
323    return [x.op.inputs[0] for x in self._enters + self._direct_enters]
324
325  @property
326  def control_inputs(self):
327    """Control input to all the Enter nodes."""
328    control_inputs = []
329    for x in self._enters + self._direct_enters:
330      control_inputs.extend(x.op.control_inputs)
331    return control_inputs
332
333  @property
334  def outputs(self):
335    """Outputs of all the Exit nodes."""
336    return self._outputs
337
338  @property
339  def name(self):
340    """Context name for the while loop."""
341    return self._context_name
342
343  @property
344  def is_inside_loop(self):
345    """Returns true if the while_loop was created inside the pfor."""
346    return self._is_inside_loop
347
348  def op_is_inside_loop(self, op):
349    """True if op was created inside the pfor loop body."""
350    assert isinstance(op, ops.Operation)
351    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
352    # since it appears there tensorflow API could return different python
353    # objects representing the same Operation node.
354    return op._id in self._pfor_op_ids
355
356  @property
357  def is_stateful(self):
358    return self._is_stateful
359
360  @property
361  def pfor_converter(self):
362    """Return a converter for the while loop."""
363    return self
364
365  def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
366                 inputs_stacked):
367    """Create a PFor object for converting parts of the while_loop.
368
369    Args:
370      parent_pfor: PFor object being used for converting the while_loop.
371      indices: int32 Tensor of ids for the iterations that are still active
372        (i.e. did not exit the while_loop).
373      cond_stacked: True if the while_loop condition is stacked.
374      inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
375        that these Tensors are a subset of the loop variables for the generated
376        while_loop.
377      inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
378        indicating if the value is stacked or not.
379
380    Returns:
381      A PFor instance. The instance is initialized by adding conversion mappings
382        of nodes that will be external to the conversion that the returned
383        instance will be used for. e.g. Enter nodes as well as Merge and Switch
384        outputs are mapped to converted values.
385    """
386    num_outputs = len(self._outputs)
387    assert len(inputs) == len(self._enters)
388    assert len(inputs_stacked) == len(self._enters)
389    loop_var = parent_pfor.loop_var
390    loop_len = array_ops.size(indices)
391    pfor = PFor(
392        loop_var,
393        loop_len,
394        pfor_ops=self._pfor_ops,
395        all_indices=indices,
396        all_indices_partitioned=cond_stacked,
397        fallback_to_while_loop=self._fallback_to_while_loop,
398        pfor_config=self._pfor_config)
399    # Map all inputs of Enter nodes in self._direct_enters to their converted
400    # values.
401    for enter in self._direct_enters:
402      enter_input = enter.op.inputs[0]
403      converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
404          enter_input)
405      # Since these are resources / variants, they should be unstacked.
406      assert not stacked and not is_sparse_stacked, (enter, converted_enter)
407      pfor._add_conversion(enter, wrap(converted_enter, False))
408
409    # Map all Enter nodes to the inputs.
410    for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
411      pfor._add_conversion(enter, wrap(inp, stacked))
412    # Map outputs of Switch and Merge.
413    for i in range(num_outputs):
414      wrapped_inp = wrap(inputs[i], inputs_stacked[i])
415      merge = self._enter_merges[i]
416      pfor._add_conversion(merge.outputs[0], wrapped_inp)
417      # Note that second output of Merge is typically not used, except possibly
418      # as a control dependency. To avoid trying to output the correct value, we
419      # employ a hack here. We output a dummy invalid value with an incorrect
420      # dtype. This will allow control dependency to work but if using it as an
421      # input, it should typically lead to errors during graph construction due
422      # to dtype mismatch.
423      # TODO(agarwal): Check in the original graph to see if there are any
424      # consumers of this Tensor that use it as an input.
425      pfor._add_conversion(merge.outputs[1],
426                           wrap(constant_op.constant(-1.0), False))
427      switch = self._exit_switches[i]
428      # Don't need to worry about switch.output[0] which will feed to Exit node.
429      pfor._add_conversion(switch.outputs[1], wrapped_inp)
430    return pfor
431
432  def _convert_enter(self, parent_pfor, enter):
433    """Converts an Enter node."""
434    inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
435    control_inputs = []
436    for x in enter.op.control_inputs:
437      converted = parent_pfor._convert_helper(x)
438      if not isinstance(converted, ops.Operation):
439        converted = converted.t
440      control_inputs.append(converted)
441    if control_inputs:
442      with ops.control_dependencies(control_inputs):
443        inp = array_ops.identity(inp)
444    return inp, stacked
445
446  def _maybe_stacked(self, cache, inp):
447    """Heuristic to figure out if the converting inp leads to a stacked value.
448
449
450    Args:
451      cache: map from Tensor to boolean indicating stacked/unstacked.
452      inp: input Tensor.
453
454    Returns:
455      True if `inp` could get stacked. If the function returns False, the
456      converted value should be guaranteed to be unstacked. If returning True,
457      it may or may not be stacked.
458    """
459    if inp in cache:
460      return cache[inp]
461    if not self.op_is_inside_loop(inp.op):
462      return False
463    op = inp.op
464    output = False
465    if op.type in [
466        "Shape",
467        "Rank",
468        "ShapeN",
469        "ZerosLike",
470        "TensorArrayV3",
471        "TensorArraySizeV3",
472    ]:
473      output = False
474    elif _is_stateful_pfor_op(op):
475      # This may be fairly aggressive.
476      output = True
477    elif op.type == "Exit":
478      # This may be fairly aggressive.
479      output = True
480    else:
481      for t in op.inputs:
482        if self._maybe_stacked(cache, t):
483          output = True
484          break
485    cache[inp] = output
486    return output
487
488  def _create_init_values(self, pfor_input):
489    """Create arguments passed to converted while_loop."""
490    with ops.name_scope("while_init"):
491      loop_len_vector = pfor_input.pfor.loop_len_vector
492      loop_len = loop_len_vector[0]
493      num_outputs = len(self._outputs)
494
495      inputs = []
496      maybe_stacked_cache = {}
497      # Convert all the Enters. Need to do this before checking for stacking
498      # below.
499      for i, enter in enumerate(self._enters):
500        inp, stacked = self._convert_enter(pfor_input.pfor, enter)
501        inputs.append(inp)
502        maybe_stacked_cache[enter] = stacked
503        # Since this enter node is part of the `loop_vars`, it corresponds to an
504        # output and its preceding switch. We mark this switch's output the same
505        # stackness, to act at the base case for the logic below. Below, we will
506        # be going through the body figuring out which inputs might need to be
507        # stacked and which inputs can safely remain unstacked.
508        if i < num_outputs:
509          maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked
510
511      # Shape invariants for init_values corresponding to self._enters.
512      input_shape_invariants = []
513      # TensorArrays for outputs of converted while loop
514      output_tas = []
515      # Shape invariants for output TensorArrays.
516      ta_shape_invariants = []
517      # List of booleans indicating stackness of inputs, i.e. tensors
518      # corresponding to self._enters.
519      inputs_stacked = []
520      for i, inp in enumerate(inputs):
521        enter = self._enters[i]
522        inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
523        # Note that even when an input is unstacked, the body could make it
524        # stacked. we use a heuristic below to figure out if body may be making
525        # it stacked.
526        if i < num_outputs:
527          body_output = self._body_outputs[i]
528          if enter.op in self._pfor_ops:
529            body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
530                                                      body_output)
531          else:
532            # If constructed outside of pfor loop, then the output would not be
533            # stacked.
534            body_output_stacked = False
535          if body_output_stacked and not inp_stacked:
536            inp = _stack(inp, loop_len_vector).t
537            inputs[i] = inp
538            inp_stacked = True
539          # TODO(agarwal): other attributes for the TensorArray ?
540          output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
541          ta_shape_invariants.append(tensor_shape.TensorShape(None))
542
543        inputs_stacked.append(inp_stacked)
544        input_shape_invariants.append(tensor_shape.TensorShape(None))
545
546      # See documentation for __call__ for the structure of init_values.
547      init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
548      # TODO(agarwal): try stricter shape invariants
549      shape_invariants = (
550          [tensor_shape.TensorShape(None),
551           tensor_shape.TensorShape(None)] + input_shape_invariants +
552          ta_shape_invariants)
553
554      return init_values, inputs_stacked, shape_invariants
555
556  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
557    """Handles case when condition is unstacked.
558
559    Note that all iterations end together. So we don't need to partition the
560    inputs. When all iterations are done, we write the inputs to the
561    TensorArrays. Note that we only write to index 0 of output_tas. Since all
562    iterations end together, they can all be output together.
563    """
564    not_all_done = array_ops.reshape(conditions, [])
565    new_output_tas = []
566    # pylint: disable=cell-var-from-loop
567    for i, out_ta in enumerate(output_tas):
568      inp = inputs[i]
569      new_output_tas.append(
570          control_flow_ops.cond(not_all_done, lambda: out_ta,
571                                lambda: out_ta.write(0, inp)))
572    # pylint: enable=cell-var-from-loop
573    return not_all_done, indices, inputs, new_output_tas
574
575  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
576                            output_tas):
577    num_outputs = len(self._outputs)
578    # Compute if all iterations are done.
579    not_all_done = math_ops.reduce_any(conditions)
580    conditions_int = math_ops.cast(conditions, dtypes.int32)
581    # Partition the indices.
582    done_indices, new_indices = data_flow_ops.dynamic_partition(
583        indices, conditions_int, 2)
584
585    new_inputs = []
586    new_output_tas = []
587    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
588      # Partition the inputs.
589      if stacked:
590        done_inp, new_inp = data_flow_ops.dynamic_partition(
591            inp, conditions_int, 2)
592      else:
593        # TODO(agarwal): avoid this stacking. See TODO earlier in
594        # _process_cond_unstacked.
595        done_inp = _stack(inp, [array_ops.size(done_indices)]).t
596        new_inp = inp
597      new_inputs.append(new_inp)
598      # For iterations that are done, write them to TensorArrays.
599      if i < num_outputs:
600        out_ta = output_tas[i]
601        # Note that done_indices can be empty. done_inp should also be empty in
602        # that case.
603        new_output_tas.append(out_ta.scatter(done_indices, done_inp))
604    return not_all_done, new_indices, new_inputs, new_output_tas
605
606  def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked,
607                    new_inputs, not_all_done):
608    """Convert the body function."""
609
610    def true_fn(control_inputs, body_pfor, body_output, stacked):
611      """Converts the body function for all but last iteration.
612
613      This essentially converts body_output. Additionally, it needs to handle
614      any control dependencies on the NextIteration node. So it creates another
615      Identity node with the converted dependencies.
616      """
617      converted_control_inp = []
618      for x in control_inputs:
619        for t in x.outputs:
620          converted_control_inp.append(body_pfor._convert_helper(t).t)
621      if stacked:
622        # Note convert always does the stacking.
623        output = body_pfor.convert(body_output)
624      else:
625        output, convert_stacked, _ = body_pfor._convert_helper(body_output)
626        assert convert_stacked == stacked, body_output
627      with ops.control_dependencies(converted_control_inp):
628        return array_ops.identity(output)
629
630    body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked,
631                                new_inputs, inputs_stacked)
632    new_outputs = []
633
634    for i, (body_output,
635            stacked) in enumerate(zip(self._body_outputs, inputs_stacked)):
636      control_inp = self._next_iter_control_inputs[i]
637      out_dtype = body_output.dtype
638      # Note that we want to run the body only if not all pfor iterations are
639      # done. If all are done, we return empty tensors since these values will
640      # not be used. Notice that the value returned by the loop is based on
641      # TensorArrays and not directly on these returned values.
642      # pylint: disable=cell-var-from-loop
643      new_output = control_flow_ops.cond(
644          not_all_done,
645          lambda: true_fn(control_inp, body_pfor, body_output, stacked),
646          lambda: constant_op.constant([], dtype=out_dtype))
647      # pylint: enable=cell-var-from-loop
648      new_outputs.append(new_output)
649    return new_outputs
650
651  def __call__(self, pfor_input):
652    """Converter for the while_loop.
653
654    The conversion of a while_loop is another while_loop.
655
656    The arguments to this converted while_loop are as follows:
657    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
658      are done.
659    indices: int32 1-D Tensor storing the id of the iterations that are not
660      done.
661    args: Remaining arguments. These can be divided into 3 categories:
662      - First set of arguments are the tensors that correspond to the initial
663        elements of self._enters. The elements that appear in original while
664        loop's `loop_vars`.
665      - The second set of arguments are the tensors that correspond to the
666        remaining elements of self._enters. These are the tensors that directly
667        enter the original while loop body.
668       - Finally, the last set of arguments are TensorArrays. These TensorArrays
669         correspond to the outputs of the original while_loop, i.e. to the
670         elements in self._outputs. Each TensorArray has `PFor.loop_len`
671         elements, i.e. the number of pfor iterations. At the end, the i'th
672         element of each TensorArray will contain the output computed by the
673         i'th iteration of pfor. Note that elements can be written into these
674         tensors arrays in any order, depending on when the corresponding pfor
675         iteration is done.
676      If the original while_loop had `k` tensors in its `loop_vars` and its body
677      directly captured `m` tensors, the `args` will contain `2 * k + m` values.
678
679    In each iteration, the while_loop body recomputes the condition for all
680    active pfor iterations to see which of them are now done. It then partitions
681    all the inputs and passes them along to the converted body. Values for all
682    the iterations that are done are written to TensorArrays indexed by the pfor
683    iteration number. When all iterations are done, the TensorArrays are stacked
684    to get the final value.
685
686    Args:
687      pfor_input: A PForInput object corresponding to the output of any Exit
688        node from this while loop.
689
690    Returns:
691      List of converted outputs.
692    """
693    # Create init_values that will be passed to the while_loop.
694    init_values, inputs_stacked, shape_invariants = self._create_init_values(
695        pfor_input)
696    # Note that we use a list as a hack since we need the nested function body
697    # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
698    # variables.
699    cond_is_stacked = [None]
700
701    def cond(not_all_done, *_):
702      return not_all_done
703
704    def body(not_all_done, indices, *args):
705      # See documentation for __call__ for the structure of *args.
706      num_enters = len(self._enters)
707      inputs = args[:num_enters]
708      output_tas = args[num_enters:]
709      # TODO(agarwal): see which outputs have consumers and only populate the
710      # TensorArrays corresponding to those. Or do those paths get trimmed out
711      # from inside the while_loop body?
712      assert len(inputs) >= len(output_tas)
713      assert len(inputs) == len(inputs_stacked)
714
715      # Convert condition
716      with ops.name_scope("while_cond"):
717        # Note that we set cond_stacked to True here. At this point we don't
718        # know if it could be loop invariant, hence the conservative value is
719        # to assume stacked.
720        cond_pfor = self._init_pfor(
721            pfor_input.pfor,
722            indices,
723            cond_stacked=True,
724            inputs=inputs,
725            inputs_stacked=inputs_stacked)
726        conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
727        cond_is_stacked[0] = cond_stacked
728
729      # Recompute the new condition, write outputs of done iterations, and
730      # partition the inputs if needed.
731      if not cond_stacked:
732        (not_all_done, new_indices, new_inputs,
733         new_output_tas) = self._process_cond_unstacked(conditions, indices,
734                                                        inputs, output_tas)
735      else:
736        (not_all_done, new_indices, new_inputs,
737         new_output_tas) = self._process_cond_stacked(conditions, indices,
738                                                      inputs, inputs_stacked,
739                                                      output_tas)
740
741      # Convert body
742      with ops.name_scope("while_body"):
743        #  Compute the outputs from the body.
744        new_outputs = self._process_body(pfor_input, inputs_stacked,
745                                         new_indices, cond_stacked, new_inputs,
746                                         not_all_done)
747
748      # Note that the first num_outputs new values of inputs are computed using
749      # the body. Rest of them were direct Enters into the condition/body and
750      # the partitioning done earlier is sufficient to give the new value.
751      num_outputs = len(self._outputs)
752      new_args = ([not_all_done, new_indices] + new_outputs +
753                  list(new_inputs[num_outputs:]) + new_output_tas)
754      return tuple(new_args)
755
756    while_outputs = control_flow_ops.while_loop(
757        cond, body, init_values, shape_invariants=shape_invariants)
758    output_tas = while_outputs[-len(self._outputs):]
759    outputs = []
760    assert cond_is_stacked[0] is not None
761    for inp_stacked, ta in zip(inputs_stacked, output_tas):
762      if cond_is_stacked[0]:
763        outputs.append(wrap(ta.stack(), True))
764      else:
765        # Note that if while_loop condition is unstacked, all iterations exit at
766        # the same time and we wrote those outputs in index 0 of the tensor
767        # array.
768        outputs.append(wrap(ta.read(0), inp_stacked))
769    return outputs
770
771
772class ConversionNotImplementedError(Exception):
773  pass
774
775
776class _PforInput(object):
777  """Input object passed to registered pfor converters."""
778
779  __slots__ = ["pfor", "_op", "_inputs"]
780
781  def __init__(self, pfor, op, inputs):
782    """Creates a _PforInput object.
783
784    Args:
785      pfor: PFor converter object.
786      op: the Operation object that is being converted.
787      inputs: list of WrappedTensor objects representing converted values of the
788        inputs of `op`.
789    """
790    self.pfor = pfor
791    self._op = op
792    self._inputs = inputs
793
794  def stack_inputs(self, stack_indices=None, tile_variants=False):
795    """Stacks unstacked inputs at `stack_indices`.
796
797    Args:
798      stack_indices: indices of inputs at which stacking is done. If None,
799        stacking is done at all indices.
800      tile_variants: If True, affected indices which have a variant dtype will
801        be tiled after this operation to match the expected shape of a
802        vectorized tensor. Variants generally need to be un-tiled when they are
803        inputs to operations and tiled when returned.
804    """
805    if stack_indices is None:
806      stack_indices = range(len(self._inputs))
807    length = self.pfor.loop_len_vector
808    for i in stack_indices:
809      inp = self._inputs[i]
810      is_variant = inp.t.dtype == dtypes.variant
811      if not inp.is_stacked:
812        self._inputs[i] = _stack(inp.t, length)
813        if tile_variants and is_variant:
814          self._inputs[i] = wrap(
815              _tile_variant_with_length(self._inputs[i].t, length), True)
816      elif not tile_variants and is_variant:
817        self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True)
818
819  def expanddim_inputs_for_broadcast(self):
820    """Reshapes stacked inputs to prepare them for broadcast.
821
822    Since stacked inputs have an extra leading dimension, automatic broadcasting
823    rules could incorrectly try to expand dimensions before that leading
824    dimension. To avoid that, we reshape these stacked inputs to the maximum
825    rank they will need to be broadcasted to.
826    """
827    if not self._inputs:
828      return
829
830    # Find max rank
831    def _get_rank(x):
832      rank = array_ops.rank(x.t)
833      if not x.is_stacked:
834        rank += 1
835      return rank
836
837    ranks = [_get_rank(x) for x in self._inputs]
838    max_rank = ranks[0]
839    for rank in ranks[1:]:
840      max_rank = math_ops.maximum(rank, max_rank)
841
842    for i, inp in enumerate(self._inputs):
843      if inp.is_stacked:
844        shape = array_ops.shape(inp.t)
845        rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
846        ones = array_ops.tile([1], rank_diff)
847        new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
848        self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)
849
850  @property
851  def inputs(self):
852    return self._inputs
853
854  @property
855  def num_inputs(self):
856    return len(self._inputs)
857
858  def input(self, index):
859    assert len(self._inputs) > index, (index, self._inputs)
860    return self._inputs[index]
861
862  def stacked_input(self, index):
863    t, is_stacked, _ = self.input(index)
864    if not is_stacked:
865      op_type = self.op_type
866      op_def = getattr(self._op, "op_def", None)
867      if op_def is None:
868        input_name = "at index %d" % index
869      else:
870        input_name = "\"%s\"" % op_def.input_arg[index].name
871      raise ConversionNotImplementedError(
872          "Input %s of op \"%s\" expected to be not loop invariant" %
873          (input_name, op_type))
874    return t
875
876  def unstacked_input(self, index):
877    t, is_stacked, _ = self.input(index)
878    if is_stacked:
879      op_type = self.op_type
880      op_def = getattr(self._op, "op_def", None)
881      if op_def is None:
882        input_name = "at index %d" % index
883      else:
884        input_name = "\"%s\"" % op_def.input_arg[index].name
885      raise ConversionNotImplementedError(
886          "Input %s of op \"%s\" expected to be loop invariant" %
887          (input_name, op_type))
888    return t
889
890  @property
891  def op(self):
892    return self._op
893
894  @property
895  def op_type(self):
896    return self._op.type
897
898  def get_attr(self, attr):
899    return self._op.get_attr(attr)
900
901  @property
902  def outputs(self):
903    return self._op.outputs
904
905  def output(self, index):
906    assert index < len(self._op.outputs)
907    return self._op.outputs[index]
908
909
910_pfor_converter_registry = {}
911
912
913class RegisterPFor(object):
914  """Utility to register converters for pfor.
915
916  Usage:
917  @RegisterPFor(foo_op_type)
918  def _foo_converter(pfor_input):
919    ...
920
921  The above will register conversion function `_foo_converter` for handling
922  conversion of `foo_op_type`. These converters are called during vectorization
923  of a `pfor` loop body. For each operation node in this loop body,
924  the vectorization process will call the converter corresponding to the
925  operation type of the node.
926
927  During conversion, the registered function will be called with a single
928  argument `pfor_input`, of type `PForInput`, which will contain state needed
929  for the conversion.  When the converter is called for a node, all its inputs
930  should already have been converted and these converted values are stored in
931  `pfor_input.inputs`.  This registered function should output a list of
932  WrappedTensor objects with the same length as the number of outputs of the
933  node being converted. If the node had zero outputs, then it should return an
934  ops.Operation object.  These new sets of nodes should implement the
935  functionality of running that operation for the number of iterations specified
936  by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each
937  iteration are picked from `pfor_inputs.inputs()`.
938
939  One tricky aspect of the conversion process is keeping track of, and
940  leveraging loop invariance of computation. Each converted input is a
941  WrappedTensor which indicates whether the input was loop invariant or not. If
942  the converted value is loop invariant, its rank should match the rank of the
943  corresponding tensor in the loop body, else its rank is larger by 1. The
944  converter should look at the loop invariance of the inputs and generate new
945  nodes based on that. Note that the converter will not be called if all inputs
946  are loop invariant and the operation is not stateful. The converter should
947  determine if its own output is loop invariant and `wrap` its output
948  accordingly.
949
950  Example:
951
952  Here, the converter is trying to convert a Reshape node in the loop body. This
953  node will have two inputs: the tensor to reshape, and the new shape.  The
954  example here only handles the case where the shape is loop invariant.
955
956  @RegisterPFor("Reshape")
957  def _convert_reshape(pfor_input):
958    # We assume that input is not loop invariant. Call to `stacked_input`
959    # asserts that and returns the converted value. This value will have a rank
960    # larger by 1 compared to the rank of the input in the loop body.
961    t = pfor_input.stacked_input(0)
962
963    # We assume that shape input is loop invariant. Call to `unstacked_input`
964    # asserts that and returns the converted value.
965    shape = pfor_input.unstacked_input(1)
966
967    # We compute `new_shape` by prepending the number of iterations to the
968    # original shape.
969    new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
970                                 axis=0)
971
972    # The vectorized output involves reshaping the converted input `t` using
973    # `new_shape`.
974    new_output = array_ops.reshape(t, new_shape)
975
976    # The converted output is marked as not loop invariant using the call to
977    # wrap.
978    return wrap(new_output, True)
979  """
980
981  def __init__(self, op_type):
982    """Creates an object to register a converter for op with type `op_type`."""
983    self.op_type = op_type
984
985  def __call__(self, converter):
986    name = self.op_type
987    assert name not in _pfor_converter_registry, "Re-registering %s " % name
988    _pfor_converter_registry[name] = converter
989    return converter
990
991
992class RegisterPForWithArgs(RegisterPFor):
993  """Utility to register converters for pfor.
994
995  Usage:
996  @RegisteRPFor(foo_op_type, foo=value, ....)
997  def _foo_converter(pfor_input, foo=None, ....):
998    ...
999
1000  See RegisterPFor for details on the conversion function.
1001  `RegisterPForWithArgs` allows binding extra arguments to the
1002  conversion function at registration time.
1003  """
1004
1005  def __init__(self, op_type, *args, **kw_args):
1006    super(RegisterPForWithArgs, self).__init__(op_type)
1007    self._args = args
1008    self._kw_args = kw_args
1009
1010  def __call__(self, converter):
1011
1012    def _f(pfor_input):
1013      return converter(pfor_input, self.op_type, *self._args, **self._kw_args)
1014
1015    super(RegisterPForWithArgs, self).__call__(_f)
1016    return converter
1017
1018
1019# TODO(agarwal): call raw_ops instead of calling these low level routines.
1020def _create_op(op_type, inputs, op_dtypes, attrs=None):
1021  """Utility to create an op."""
1022  op = ops.get_default_graph().create_op(
1023      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
1024  flat_attrs = []
1025  # The tape expects an alternating flat list of names and attribute values.
1026  for a in attrs:
1027    flat_attrs.append(str(a))
1028    flat_attrs.append(op.get_attr(str(a)))
1029  execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
1030  return op
1031
1032
1033WrappedTensor = collections.namedtuple("WrappedTensor",
1034                                       ["t", "is_stacked", "is_sparse_stacked"])
1035"""Wrapper around the result of a Tensor conversion.
1036
1037The additional fields are useful for keeping track of the conversion state as
1038data flows through the ops in the loop body. For every op whose output is a
1039Tensor, its converter should return either a WrappedTensor or a list of
1040WrappedTensors.
1041
1042Args:
1043  t: The converted tensor
1044  is_stacked: True if the tensor is stacked, i.e. represents the results of all
1045    the iterations of the loop, where each row i of the tensor corresponds to
1046    that op's output on iteration i of the loop. False if the tensor is not
1047    stacked, i.e. represents the result of the op on of a single iteration of
1048    the loop, where the result does not vary between iterations.
1049  is_sparse_stacked: True if the tensor corresponds to a component tensor
1050    (indices, values, or dense_shape) of a sparse tensor, and has been logically
1051    stacked via a sparse conversion.
1052"""
1053
1054
1055def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
1056  """Helper to create a WrappedTensor object."""
1057  assert isinstance(is_stacked, bool)
1058  assert isinstance(is_sparse_stacked, bool)
1059  assert isinstance(tensor, ops.Tensor)
1060  assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
1061                                               "stacked via a sparse "
1062                                               "conversion, it must also be "
1063                                               "stacked.")
1064  return WrappedTensor(tensor, is_stacked, is_sparse_stacked)
1065
1066
1067def _wrap_and_tile_variants(tensor, length):
1068  if tensor.dtype == dtypes.variant:
1069    tensor = _tile_variant_with_length(tensor, length)
1070  return wrap(tensor)
1071
1072
1073def _fallback_converter(pfor_input, warn=True):
1074  if warn:
1075    logging.warn("Using a while_loop for converting %s", pfor_input.op_type)
1076  output_dtypes = [x.dtype for x in pfor_input.outputs]
1077  iters = pfor_input.pfor.loop_len_vector[0]
1078
1079  def while_body(i, *ta_list):
1080    """Body of while loop."""
1081    inputs = [
1082        x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
1083    ]
1084    op_outputs = _create_op(
1085        pfor_input.op_type,
1086        inputs,
1087        output_dtypes,
1088        attrs=pfor_input.op.node_def.attr).outputs
1089
1090    outputs = []
1091    # TODO(agarwal): Add tf.debugging asserts to check that the shapes across
1092    # the different iterations are the same.
1093    for out, ta in zip(op_outputs, ta_list):
1094      assert isinstance(out, ops.Tensor)
1095      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
1096    return tuple([i + 1] + outputs)
1097
1098  ta_list = control_flow_ops.while_loop(
1099      lambda i, *ta: i < iters, while_body, [0] +
1100      [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
1101      ])[1:]
1102  return tuple([wrap(ta.concat(), True) for ta in ta_list])
1103
1104
1105class PForConfig(object):
1106  """A configuration object used to communicate with loop body function."""
1107
1108  def __init__(self):
1109    # This may be set to the number of iterations.
1110    self._maybe_iters = None
1111    # Map from reduction node, created by `reduce`, to the bundle of reduction
1112    # function and arguments.
1113    self._reduce_map = {}
1114
1115  def _has_reductions(self):
1116    """True if some reductions where performed by loop body."""
1117    return len(self._reduce_map)
1118
1119  def _set_iters(self, iters):
1120    """Set number of pfor iterations."""
1121    if isinstance(iters, ops.Tensor):
1122      iters = tensor_util.constant_value(iters)
1123    self._maybe_iters = iters
1124
1125  def reduce(self, fn, *args):
1126    """Performs reduction `fn` on `args` vectorized across pfor iterations.
1127
1128    Note that `fn` is traced once inside the loop function context. Hence any
1129    captures or side-effects will happen in that context. Call to the traced
1130    version of `fn` happens during the construction of the vectorized code.
1131
1132    Note that this currently may not work inside a control flow construct.
1133    Args:
1134      fn: a reduction function. It will be called with arguments that have the
1135        same structure as *args but with individual values whose rank may be
1136        higher by 1 since they represent loop invariant vectorized versions of
1137        the corresponding Tensors in *args.
1138      *args: unvectorized Tensors.
1139
1140    Returns:
1141      The result of running `fn` on the vectorized versions of `*args`. These
1142      outputs will be available as loop invariant values to all the iterations.
1143    """
1144    assert not context.executing_eagerly()
1145    # Creates a concrete function that will be used for reduction.
1146    tensor_specs = []
1147    for arg in args:
1148      if not isinstance(arg, ops.Tensor):
1149        raise ValueError("Got a non-Tensor argument %s in reduce" % arg)
1150      batched_shape = tensor_shape.TensorShape([self._maybe_iters
1151                                               ]).concatenate(arg.shape)
1152      tensor_specs.append(
1153          tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype))
1154    concrete_function = def_function.function(fn).get_concrete_function(
1155        *tensor_specs)
1156
1157    # Creates PlaceholderWithDefault and IdentityN nodes corresponding the
1158    # reduction.
1159    pl_outputs = []
1160    with ops.control_dependencies(args):
1161      for output in concrete_function.outputs:
1162        if not isinstance(output, ops.Tensor):
1163          raise ValueError("Got a non-Tensor output %s while running reduce" %
1164                           output)
1165        # Note that we use placeholder_with_default just to make XLA happy since
1166        # it does not like placeholder ops.
1167        if output.shape.is_fully_defined():
1168          dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype)
1169          pl_outputs.append(
1170              array_ops.placeholder_with_default(dummy, shape=output.shape))
1171        else:
1172          # TODO(agarwal): support case when under XLA and output.shape is not
1173          # fully defined.
1174          pl_outputs.append(
1175              array_ops.placeholder(output.dtype, shape=output.shape))
1176
1177      reduction_op = array_ops.identity_n(pl_outputs)[0].op
1178    self._reduce_map[reduction_op] = (concrete_function, args)
1179    if len(reduction_op.outputs) == 1:
1180      return reduction_op.outputs[0]
1181    else:
1182      return tuple(reduction_op.outputs)
1183
1184  # TODO(agarwal): handle reductions inside control flow constructs.
1185  def reduce_concat(self, x):
1186    """Performs a concat reduction on `x` across pfor iterations.
1187
1188    Note that this currently may not work inside a control flow construct.
1189    Args:
1190      x: an unvectorized Tensor.
1191
1192    Returns:
1193      A Tensor that has rank one higher than `x`. The value is the vectorized
1194      version of `x`, i.e. stacking the value of `x` across different pfor
1195      iterations.
1196    """
1197    return self.reduce(lambda y: y, x)
1198
1199  def reduce_mean(self, x):
1200    """Performs a mean reduction on `x` across pfor iterations.
1201
1202    Note that this currently may not work inside a control flow construct.
1203    Args:
1204      x: an unvectorized Tensor.
1205
1206    Returns:
1207      A Tensor that has same rank as `x`. The value is the mean of the values
1208      of `x` across the pfor iterations.
1209    """
1210    return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x)
1211
1212  def reduce_sum(self, x):
1213    """Performs a sum reduction on `x` across pfor iterations.
1214
1215    Note that this currently may not work inside a control flow construct.
1216    Args:
1217      x: an unvectorized Tensor.
1218
1219    Returns:
1220      A Tensor that has same rank as `x`. The value is the sum of the values
1221      of `x` across the pfor iterations.
1222    """
1223    return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x)
1224
1225  def _lookup_reduction(self, t):
1226    """Lookups Tensor `t` in the reduction maps."""
1227    assert isinstance(t, ops.Tensor), t
1228    return self._reduce_map.get(t.op)
1229
1230
1231class PFor(object):
1232  """Implementation of rewrite of parallel-for loops.
1233
1234  This class takes a DAG or a set of DAGs representing the body of a
1235  parallel-for loop, and adds new operations to the graph that implements
1236  functionality equivalent to running that loop body for a specified number of
1237  iterations. This new set of nodes may or may not use a tensorflow loop
1238  construct.
1239
1240  The process of conversion does not delete or change any existing operations.
1241  It only adds operations that efficiently implement the equivalent
1242  functionality. We refer to the added ops as "converted ops".
1243
1244  The conversion process uses a simple greedy heuristic. It walks the loop body
1245  and tries to express the functionality of running each node in a loop with a
1246  new set of nodes. When converting an op several cases are possible:
1247  - The op is not inside the loop body. Hence it can be used as is.
1248  - The op does not depend on the iteration number and is stateless. In this
1249    case, it can be used as is.
1250  - The op is not stateful, and depends on iteration number only through control
1251    dependencies. In this case, we can create a single op with same inputs and
1252    attributes, but with "converted" control dependencies.
1253  - The op is not stateful, and all its inputs are loop invariant. In this
1254    case, similar to above, we can create a single op with same inputs and
1255    attributes, but with "converted" control dependencies.
1256  - The op is stateful or at least one of the inputs is not loop invariant. In
1257    this case, we run the registered converter for that op to create a set of
1258    converted ops. All nodes in the set will have converted control dependencies
1259    corresponding to control dependencies of the original op. If the op returned
1260    multiple outputs, "converted outputs" could be produced by different ops in
1261    this set.
1262  """
1263
1264  def __init__(self,
1265               loop_var,
1266               loop_len,
1267               pfor_ops,
1268               fallback_to_while_loop,
1269               all_indices=None,
1270               all_indices_partitioned=False,
1271               pfor_config=None):
1272    """Creates an object to rewrite a parallel-for loop.
1273
1274    Args:
1275      loop_var: ops.Tensor output of a Placeholder operation. The value should
1276        be an int32 scalar representing the loop iteration number.
1277      loop_len: A scalar or scalar Tensor representing the number of iterations
1278        the loop is run for.
1279      pfor_ops: List of all ops inside the loop body.
1280      fallback_to_while_loop: If True, on failure to vectorize an op, a while
1281        loop is used to sequentially execute that op.
1282      all_indices: If not None, an int32 vector with size `loop_len`
1283        representing the iteration ids that are still active. These values
1284        should be unique and sorted. However they may not be contiguous. This is
1285        typically the case when inside a control flow construct which has
1286        partitioned the indices of the iterations that are being converted.
1287      all_indices_partitioned: If True, this object is being constructed from a
1288        control flow construct where not all the pfor iterations are guaranteed
1289        to be active.
1290      pfor_config: PForConfig object used while constructing the loop body.
1291    """
1292    assert isinstance(loop_var, ops.Tensor)
1293    assert loop_var.op.type == "PlaceholderWithDefault"
1294    self._loop_var = loop_var
1295    loop_len_value = tensor_util.constant_value(loop_len)
1296    if loop_len_value is not None:
1297      loop_len = loop_len_value
1298    self._loop_len_vector = array_ops.reshape(loop_len, [1])
1299    self._all_indices_partitioned = all_indices_partitioned
1300    if all_indices_partitioned:
1301      assert all_indices is not None
1302    self.all_indices = (
1303        math_ops.range(loop_len) if all_indices is None else all_indices)
1304
1305    self._conversion_map = object_identity.ObjectIdentityDictionary()
1306    self._conversion_map[loop_var] = wrap(self.all_indices, True)
1307    self._pfor_ops = set(pfor_ops)
1308    self._pfor_op_ids = set(x._id for x in pfor_ops)
1309    self._fallback_to_while_loop = fallback_to_while_loop
1310    self._pfor_config = pfor_config
1311
1312  def op_is_inside_loop(self, op):
1313    """True if op was created inside the pfor loop body."""
1314    assert isinstance(op, ops.Operation)
1315    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
1316    # since it appears there tensorflow API could return different python
1317    # objects representing the same Operation node.
1318    return op._id in self._pfor_op_ids
1319
1320  def _convert_sparse(self, y):
1321    """Returns the converted value corresponding to SparseTensor y.
1322
1323    For SparseTensors, instead of stacking the component tensors separately,
1324    resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
1325    rank) respectively for indices, values, and dense_shape (where N is the loop
1326    length and m is the number of sparse tensor values per loop iter), we want
1327    to logically stack the SparseTensors, to create a SparseTensor whose
1328    components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
1329    respectively.
1330
1331    Here, we try to get the conversion of each component tensor.
1332    If the tensors are stacked via a sparse conversion, return the resulting
1333    SparseTensor composed of the converted components. Otherwise, the component
1334    tensors are either unstacked or stacked naively. In the latter case, we
1335    unstack the component tensors to reform loop_len SparseTensor elements,
1336    then correctly batch them.
1337
1338    The unstacked tensors must have the same rank. Each dimension of each
1339    SparseTensor will expand to be the largest among all SparseTensor elements
1340    for that dimension. For example, if there are N SparseTensors of rank 3
1341    being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
1342    the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).
1343
1344    Args:
1345      y: A tf.sparse.SparseTensor.
1346
1347    Returns:
1348      A tf.sparse.SparseTensor that is the converted value corresponding to y.
1349    """
1350    outputs = [
1351        self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
1352    ]
1353    assert all(isinstance(o, WrappedTensor) for o in outputs)
1354
1355    if all(w.is_sparse_stacked for w in outputs):
1356      return sparse_tensor.SparseTensor(*[w.t for w in outputs])
1357
1358    assert not any(w.is_sparse_stacked for w in outputs), (
1359        "Error converting SparseTensor. All components should be logically "
1360        "stacked, or none.")
1361
1362    # If component tensors were not sparsely stacked, they are either unstacked
1363    # or stacked without knowledge that they are components of sparse tensors.
1364    # In this case, we have to restack them.
1365    return self._restack_sparse_tensor_logically(
1366        *[self._unwrap_or_tile(w) for w in outputs])
1367
1368  def _restack_sparse_tensor_logically(self, indices, values, shape):
1369    sparse_tensor_rank = indices.get_shape().dims[-1].value
1370    if sparse_tensor_rank is not None:
1371      sparse_tensor_rank += 1
1372
1373    def fn(args):
1374      res = gen_sparse_ops.serialize_sparse(
1375          args[0], args[1], args[2], out_type=dtypes.variant)
1376      return res
1377
1378    # Applies a map function to the component tensors to serialize each
1379    # sparse tensor element and batch them all, then deserializes the batch.
1380    # TODO(rachelim): Try to do this without map_fn -- add the right offsets
1381    # to shape and indices tensors instead.
1382    result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant)
1383    return sparse_ops.deserialize_sparse(
1384        result, dtype=values.dtype, rank=sparse_tensor_rank)
1385
1386  def _unwrap_or_tile(self, wrapped_tensor):
1387    """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it."""
1388    output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
1389    if is_stacked:
1390      return output
1391    else:
1392      return _stack(output, self._loop_len_vector).t
1393
1394  def convert(self, y):
1395    """Returns the converted value corresponding to y.
1396
1397    Args:
1398      y: A ops.Tensor or a ops.Operation object. If latter, y should not have
1399        any outputs.
1400
1401    Returns:
1402      If y does not need to be converted, it returns y as is. Else it returns
1403      the "converted value" corresponding to y.
1404    """
1405    if y is None:
1406      return None
1407    if isinstance(y, sparse_tensor.SparseTensor):
1408      return self._convert_sparse(y)
1409    assert isinstance(y, (ops.Tensor, ops.Operation)), y
1410    output = self._convert_helper(y)
1411    if isinstance(output, WrappedTensor):
1412      assert isinstance(y, ops.Tensor)
1413      return self._unwrap_or_tile(output)
1414    else:
1415      assert isinstance(y, ops.Operation)
1416      assert not y.outputs
1417      assert isinstance(output, ops.Operation)
1418    return output
1419
1420  def _was_converted(self, t):
1421    """True if t is not a conversion of itself."""
1422    converted_t = self._conversion_map[t]
1423    return converted_t.t is not t
1424
1425  def _add_conversion(self, old_output, new_output):
1426    assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output
1427    assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output
1428    self._conversion_map[old_output] = new_output
1429
1430  def _convert_reduction(self, y):
1431    # Handle reductions.
1432    if self._pfor_config is None or isinstance(y, ops.Operation):
1433      return None
1434    reduction = self._pfor_config._lookup_reduction(y)
1435    if reduction is None:
1436      return None
1437    (reduction_fn, reduction_args) = reduction
1438    batched_args = []
1439    for reduction_arg in reduction_args:
1440      assert isinstance(reduction_arg, ops.Tensor), reduction_arg
1441      # Tensor being reduced should already be converted due to a control
1442      # dependency on the created placeholder.
1443      # Note that in cases where reduction_arg is in an outer context, one
1444      # needs to locate the corresponding Enter node and use that to lookup
1445      # the conversion.
1446      # TODO(agarwal): handle reductions inside control flow constructs.
1447      assert reduction_arg in self._conversion_map, (
1448          "Unable to handle reduction of %s, possibly as it was used "
1449          "inside a control flow construct. Note that reductions across "
1450          "pfor iterations are currently not supported inside control flow "
1451          "constructs." % reduction_arg)
1452      batched_arg = self._conversion_map[reduction_arg]
1453      batched_args.append(self._unwrap_or_tile(batched_arg))
1454    outputs = reduction_fn(*batched_args)
1455    return [wrap(output, False) for output in nest.flatten(outputs)]
1456
1457  def _convert_helper(self, op_or_tensor):
1458    stack = collections.deque([op_or_tensor])
1459    while stack:
1460      y = stack[0]
1461      if y in self._conversion_map:
1462        assert isinstance(self._conversion_map[y],
1463                          (WrappedTensor, ops.Operation))
1464        stack.popleft()
1465        continue
1466      if isinstance(y, ops.Operation):
1467        assert not y.outputs, (
1468            "We only support converting Operation objects with no outputs. "
1469            "Got %s", y)
1470        y_op = y
1471      else:
1472        assert isinstance(y, ops.Tensor), y
1473        y_op = y.op
1474
1475      is_while_loop = y_op.type == "Exit"
1476      if is_while_loop:
1477        while_op = WhileOp(
1478            y, pfor_ops=self._pfor_ops,
1479            fallback_to_while_loop=self.fallback_to_while_loop,
1480            pfor_config=self._pfor_config)
1481        is_inside_loop = while_op.is_inside_loop
1482        # If all nodes in the while_loop graph were created inside the pfor, we
1483        # treat the whole loop subgraph as a single op (y_op) and try to convert
1484        # it. For while_loops that are created completely or partially outside,
1485        # we treat them as external and should be able to simply return the Exit
1486        # node output as is without needing any conversion. Note that for
1487        # while_loops that are partially constructed inside, we assume they will
1488        # be loop invariant. If that is not the case, it will create runtime
1489        # errors since the converted graph would depend on the self._loop_var
1490        # placeholder.
1491        if is_inside_loop:
1492          y_op = while_op
1493      else:
1494        is_inside_loop = self.op_is_inside_loop(y_op)
1495
1496      # If this op was not created inside the loop body, we will return as is.
1497      # 1. Convert inputs and control inputs.
1498
1499      def _add_to_stack(x):
1500        if x not in self._conversion_map:
1501          stack.appendleft(x)
1502          return True
1503        else:
1504          return False
1505
1506      if is_inside_loop:
1507        added_to_stack = False
1508        for inp in y_op.inputs:
1509          added_to_stack |= _add_to_stack(inp)
1510        for cinp in y_op.control_inputs:
1511          if cinp.outputs:
1512            for t in cinp.outputs:
1513              added_to_stack |= _add_to_stack(t)
1514          else:
1515            added_to_stack |= _add_to_stack(cinp)
1516        if added_to_stack:
1517          continue
1518
1519        converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
1520        some_input_converted = any(self._was_converted(x) for x in y_op.inputs)
1521        some_input_stacked = any(x.is_stacked for x in converted_inputs)
1522
1523        converted_control_ops = set()
1524        some_control_input_converted = False
1525        for cinp in y_op.control_inputs:
1526          if cinp.outputs:
1527            for t in cinp.outputs:
1528              converted_t = self._conversion_map[t]
1529              if self._was_converted(t):
1530                some_control_input_converted = True
1531              converted_control_ops.add(converted_t.t.op)
1532          else:
1533            converted_cinp = self._conversion_map[cinp]
1534            assert isinstance(converted_cinp, ops.Operation)
1535            if converted_cinp != cinp:
1536              some_control_input_converted = True
1537            converted_control_ops.add(converted_cinp)
1538        converted_control_ops = list(converted_control_ops)
1539        is_stateful = _is_stateful_pfor_op(y_op)
1540      else:
1541        converted_inputs = []
1542        converted_control_ops = []
1543      logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
1544                   converted_inputs, converted_control_ops)
1545
1546      # 2. Convert y_op
1547      # If converting a while_loop, we let the while_loop convertor deal with
1548      # putting the control dependencies appropriately.
1549      control_dependencies = [] if is_while_loop else converted_control_ops
1550      with ops.control_dependencies(control_dependencies), ops.name_scope(
1551          y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op):
1552        # Op is a placeholder for a reduction.
1553        reduce_output = self._convert_reduction(y)
1554        if reduce_output is not None:
1555          new_outputs = reduce_output
1556        # None of the inputs and control inputs were converted.
1557        elif ((not is_inside_loop or
1558               (not is_stateful and not some_input_converted and
1559                not some_control_input_converted)) and
1560              y.graph == ops.get_default_graph()):
1561          if y is y_op:
1562            assert not isinstance(y_op, WhileOp)
1563            new_outputs = y_op
1564          else:
1565            new_outputs = [wrap(x, False) for x in y_op.outputs]
1566        elif not (is_stateful or is_while_loop or some_input_stacked):
1567          # All inputs are unstacked or unconverted but some control inputs are
1568          # converted.
1569          # TODO(rachelim): Handle the case where some inputs are sparsely
1570          # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs))
1571          new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
1572                              [x.dtype for x in y_op.outputs],
1573                              y_op.node_def.attr)
1574          if y is y_op:
1575            new_outputs = new_op
1576          else:
1577            new_outputs = []
1578            for old_output, new_output in zip(y_op.outputs, new_op.outputs):
1579              custom_gradient.copy_handle_data(old_output, new_output)
1580              new_outputs.append(wrap(new_output, False))
1581        else:
1582          # Either some inputs are not loop invariant or op is stateful.
1583          if hasattr(y_op, "pfor_converter"):
1584            converter = y_op.pfor_converter
1585          else:
1586            converter = _pfor_converter_registry.get(y_op.type, None)
1587          if converter is None:
1588            has_variant_outputs = any(x.dtype == dtypes.variant for x in
1589                                      y_op.outputs)
1590            if self._fallback_to_while_loop and not has_variant_outputs:
1591              converter = _fallback_converter
1592            else:
1593              message = ("No pfor vectorization defined for %s\n"
1594                         "%s\n"
1595                         "inputs: %s. " %
1596                         (y_op.type, y_op, converted_inputs))
1597              if not self._fallback_to_while_loop:
1598                message += ("Consider enabling the fallback_to_while_loop "
1599                            "option to pfor, which may run slower.")
1600              raise ValueError(message)
1601          # TODO(rachelim): Handle the case where some inputs are sparsely
1602          # stacked. We should only call the converter if it supports handling
1603          # those inputs.
1604          pfor_inputs = _PforInput(self, y_op, converted_inputs)
1605          try:
1606            try:
1607              new_outputs = converter(pfor_inputs)
1608            except ConversionNotImplementedError as e:
1609              if self._fallback_to_while_loop:
1610                new_outputs = _fallback_converter(pfor_inputs)
1611              else:
1612                six.reraise(ValueError, ValueError(str(e)), sys.exc_info()[2])
1613          except Exception as e:  # pylint: disable=broad-except
1614            logging.error(
1615                "Got error while pfor was converting op %s"
1616                "with inputs %s\n, converted inputs %s\n"
1617                "%s\n"
1618                "Here are the pfor conversion stack traces:", y_op,
1619                y_op.inputs[:], pfor_inputs.inputs, str(e))
1620            original_op = y_op
1621            while isinstance(original_op, ops.Operation):
1622              logging.error(
1623                  "%s\ncreated at:\n  %s", original_op,
1624                  "  ".join(traceback.format_list(original_op.traceback)))
1625              original_op = original_op._original_op
1626            six.reraise(e.__class__, e, sys.exc_info()[2])
1627
1628          if isinstance(new_outputs, WrappedTensor):
1629            new_outputs = [new_outputs]
1630          assert isinstance(new_outputs,
1631                            (list, tuple, ops.Operation)), new_outputs
1632        logging.vlog(2, "converted %s %s", y_op, new_outputs)
1633
1634        # Insert into self._conversion_map
1635        if y is y_op:
1636          assert isinstance(new_outputs, ops.Operation)
1637          self._add_conversion(y_op, new_outputs)
1638        else:
1639          assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs,
1640                                                         new_outputs)
1641          for old_output, new_output in zip(y_op.outputs, new_outputs):
1642            assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
1643            assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op)
1644            # Set shape for converted output.
1645            output_shape = old_output.shape
1646            if not new_output.is_sparse_stacked:
1647              if new_output.is_stacked:
1648                loop_len = tensor_util.constant_value(self.loop_len_vector)
1649                if loop_len is None:
1650                  batch_dim = tensor_shape.TensorShape([None])
1651                else:
1652                  batch_dim = tensor_shape.TensorShape(loop_len)
1653                output_shape = batch_dim.concatenate(output_shape)
1654              if _is_variant_with_internal_stacking(new_output.t):
1655                new_output.t.set_shape([])
1656              else:
1657                new_output.t.set_shape(output_shape)
1658            self._add_conversion(old_output, new_output)
1659        stack.popleft()
1660
1661    return self._conversion_map[op_or_tensor]
1662
1663  @property
1664  def loop_len_vector(self):
1665    """Returns a single element vector whose value is number of iterations."""
1666    return self._loop_len_vector
1667
1668  @property
1669  def loop_var(self):
1670    """Returns placeholder loop variable."""
1671    return self._loop_var
1672
1673  @property
1674  def pfor_ops(self):
1675    return self._pfor_ops
1676
1677  @property
1678  def pfor_config(self):
1679    return self._pfor_config
1680
1681  @property
1682  def all_indices_partitioned(self):
1683    """all_indices_partitioned property.
1684
1685    Returns:
1686      True if we are inside a control flow construct and not all pfor iterations
1687      may be active.
1688    """
1689    return self._all_indices_partitioned
1690
1691  @property
1692  def fallback_to_while_loop(self):
1693    return self._fallback_to_while_loop
1694
1695
1696# The code below defines converters for different operations. Please see comment
1697# for RegisterPFor to see how converters should be defined.
1698
1699
1700# image_ops
1701
1702
1703@RegisterPFor("AdjustContrastv2")
1704def _convert_adjust_contrastv2(pfor_input):
1705  images = pfor_input.stacked_input(0)
1706  contrast_factor = pfor_input.unstacked_input(1)
1707  return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True)
1708
1709
1710@RegisterPFor("AdjustHue")
1711def _convert_adjust_hue(pfor_input):
1712  images = pfor_input.stacked_input(0)
1713  delta = pfor_input.unstacked_input(1)
1714  return wrap(gen_image_ops.adjust_hue(images, delta), True)
1715
1716
1717@RegisterPFor("AdjustSaturation")
1718def _convert_adjust_saturation(pfor_input):
1719  images = pfor_input.stacked_input(0)
1720  scale = pfor_input.unstacked_input(1)
1721  return wrap(gen_image_ops.adjust_saturation(images, scale), True)
1722
1723
1724# nn_ops
1725
1726
1727def _flatten_first_two_dims(x):
1728  """Merges first two dimensions."""
1729  old_shape = array_ops.shape(x)
1730  new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0)
1731  return array_ops.reshape(x, new_shape)
1732
1733
1734def _unflatten_first_dim(x, first_dim):
1735  """Splits first dimension into [first_dim, -1]."""
1736  old_shape = array_ops.shape(x)
1737  new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0)
1738  return array_ops.reshape(x, new_shape)
1739
1740
1741def _inputs_with_flattening(pfor_input, input_indices):
1742  """Stacks and flattens first dim of inputs at indices `input_indices`."""
1743  if input_indices is None:
1744    input_indices = []
1745  pfor_input.stack_inputs(stack_indices=input_indices)
1746  inputs = []
1747  for i in range(pfor_input.num_inputs):
1748    if i in input_indices:
1749      inp = pfor_input.stacked_input(i)
1750      inp = _flatten_first_two_dims(inp)
1751    else:
1752      inp = pfor_input.unstacked_input(i)
1753    inputs.append(inp)
1754  return inputs
1755
1756
1757@RegisterPForWithArgs("Conv2D", dims=[0])
1758@RegisterPForWithArgs("DepthToSpace", dims=[0])
1759@RegisterPForWithArgs("AvgPool", dims=[0])
1760@RegisterPForWithArgs("AvgPool3D", dims=[0])
1761@RegisterPForWithArgs("MaxPool", dims=[0])
1762@RegisterPForWithArgs("MaxPoolV2", dims=[0])
1763@RegisterPForWithArgs("MaxPool3D", dims=[0])
1764@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
1765@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
1766@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2])
1767@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
1768@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
1769@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2])
1770@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
1771@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1])
1772@RegisterPForWithArgs("SpaceToDepth", dims=[0])
1773def _convert_flatten_batch(pfor_input, op_type, dims):
1774  del op_type
1775  inputs = _inputs_with_flattening(pfor_input, dims)
1776  outputs = _create_op(
1777      pfor_input.op_type,
1778      inputs, [x.dtype for x in pfor_input.outputs],
1779      attrs=pfor_input.op.node_def.attr).outputs
1780  n = pfor_input.pfor.loop_len_vector
1781  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1782  return [wrap(x, True) for x in outputs]
1783
1784
1785_channel_flatten_input_cache = {}
1786
1787
1788@RegisterPFor("BatchToSpaceND")
1789def _convert_batch_to_space_nd(pfor_input):
1790  inp = pfor_input.stacked_input(0)
1791  block_shape = pfor_input.unstacked_input(1)
1792  crops = pfor_input.unstacked_input(2)
1793
1794  inp_shape = array_ops.shape(inp)
1795  n = pfor_input.pfor.loop_len_vector
1796
1797  # Reshape and transpose to move the vectorization axis inside the axes that
1798  # will move to space.
1799  # Reshape to 4D and transpose
1800  block_size = math_ops.reduce_prod(block_shape)
1801  new_shape = [n[0], block_size, inp_shape[1] // block_size, -1]
1802  inp = array_ops.reshape(inp, new_shape)
1803  inp = array_ops.transpose(inp, [1, 0, 2, 3])
1804  # Reshape back to merge the block, vectorization and batch dimension, and
1805  # restore the other dimensions.
1806  new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0)
1807  inp = array_ops.reshape(inp, new_shape)
1808  # Call batch_to_space and then split the new batch axis.
1809  output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops)
1810  output = _unflatten_first_dim(output, n)
1811  return wrap(output, True)
1812
1813
1814@RegisterPFor("SpaceToBatchND")
1815def _convert_space_to_batch_nd(pfor_input):
1816  inp = pfor_input.stacked_input(0)
1817  block_shape = pfor_input.unstacked_input(1)
1818  paddings = pfor_input.unstacked_input(2)
1819
1820  n = pfor_input.pfor.loop_len_vector
1821  inp_shape = array_ops.shape(inp)
1822  inp = _flatten_first_two_dims(inp)
1823  output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings)
1824  output_shape = array_ops.shape(output)
1825  block_size = math_ops.reduce_prod(block_shape)
1826  new_shape = [block_size, n[0], -1]
1827  output = array_ops.reshape(output, new_shape)
1828  output = array_ops.transpose(output, [1, 0, 2])
1829  new_shape = array_ops.concat(
1830      [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0)
1831  output = array_ops.reshape(output, new_shape)
1832  return wrap(output, True)
1833
1834
1835def _channel_flatten_input(x, data_format):
1836  """Merge the stack dimension with the channel dimension.
1837
1838  If S is pfor's stacking dimension, then,
1839    - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose
1840      should be cheap.
1841    - for SNHWC, we transpose to NHWCS.
1842  We then merge the S and C dimension.
1843
1844  Args:
1845    x: ops.Tensor to transform.
1846    data_format: "NCHW" or "NHWC".
1847
1848  Returns:
1849    A 3-element tuple with the transformed value, along with the shape for
1850    reshape and order for transpose required to transform back.
1851  """
1852
1853  graph = ops.get_default_graph()
1854  cache_key = (graph, x.ref(), data_format)
1855  if cache_key not in _channel_flatten_input_cache:
1856    x_shape = array_ops.shape(x)
1857    if data_format == b"NCHW":
1858      order = [1, 0, 2, 3, 4]
1859      shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0)
1860      reverse_order = order
1861    else:
1862      order = [1, 2, 3, 0, 4]
1863      shape = array_ops.concat([x_shape[1:4], [-1]], axis=0)
1864      reverse_order = [3, 0, 1, 2, 4]
1865    # Move S dimension next to C dimension.
1866    x = array_ops.transpose(x, order)
1867    reverse_shape = array_ops.shape(x)
1868    # Reshape to merge the S and C dimension.
1869    x = array_ops.reshape(x, shape)
1870    outputs = x, reverse_order, reverse_shape
1871    _channel_flatten_input_cache[cache_key] = outputs
1872  else:
1873    outputs = _channel_flatten_input_cache[cache_key]
1874  return outputs
1875
1876
1877# Note that with training=True, running FusedBatchNormV3 on individual examples
1878# is very different from running FusedBatchNormV3 on a batch of those examples.
1879# This is because, for the latter case, the operation can be considered as first
1880# computing the mean and variance over all the examples and then using these
1881# to scale all those examples. This creates a data dependency between these
1882# different "iterations" since the inputs to the scaling step depends on the
1883# statistics coming from all these inputs.
1884# As with other kernels, the conversion here effectively runs the kernel
1885# independently for each iteration, and returns outputs by stacking outputs from
1886# each of those iterations.
1887@RegisterPFor("FusedBatchNormV3")
1888def _convert_fused_batch_norm(pfor_input):
1889  is_training = pfor_input.get_attr("is_training")
1890  # When BatchNorm is used with training=False, mean and variance are provided
1891  # externally and used as is by the op. Thus, we can merge the S and N
1892  # dimensions as we do for regular operations.
1893  # When BatchNorm is used with training=True, mean and variance are computed
1894  # for each channel across the batch dimension (first one). If we merge S and N
1895  # dimensions, mean and variances will be computed over a larger set. So, we
1896  # merge the S and C dimensions instead.
1897  if not is_training:
1898    # We return zeros for batch_mean and batch_variance output. Note that CPU
1899    # and GPU seem to have different behavior for those two outputs. CPU outputs
1900    # zero because these values are not used during inference. GPU outputs
1901    # something, probably real means and variances.
1902    inputs = _inputs_with_flattening(pfor_input, [0])
1903    outputs = _create_op(
1904        pfor_input.op_type,
1905        inputs, [x.dtype for x in pfor_input.outputs],
1906        attrs=pfor_input.op.node_def.attr).outputs
1907    y = outputs[0]
1908    n = pfor_input.pfor.loop_len_vector
1909    y = _unflatten_first_dim(y, n)
1910    mean = pfor_input.unstacked_input(3)
1911    zeros = array_ops.zeros_like(mean)
1912    return [wrap(y, True)] + [wrap(zeros, False)] * 5
1913
1914  pfor_input.stack_inputs()
1915  data_format = pfor_input.get_attr("data_format")
1916  # We merge the first dimension with the "C" dimension, run FusedBatchNormV3,
1917  # and then transpose back.
1918  x = pfor_input.stacked_input(0)
1919  x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
1920  # Note that we stack all the other inputs as well so that they are the same
1921  # size as the new size of the channel dimension.
1922  inputs = [x] + [
1923      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1924      for i in range(1, pfor_input.num_inputs)
1925  ]
1926  outputs = _create_op(
1927      pfor_input.op_type,
1928      inputs, [x.dtype for x in pfor_input.outputs],
1929      attrs=pfor_input.op.node_def.attr).outputs
1930  y = outputs[0]
1931  y = array_ops.reshape(y, reverse_shape)
1932  y = array_ops.transpose(y, reverse_order)
1933  n = pfor_input.pfor.loop_len_vector
1934  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1935  outputs = [y] + outputs
1936  return [wrap(x, True) for x in outputs]
1937
1938
1939@RegisterPFor("FusedBatchNormGradV3")
1940def _convert_fused_batch_norm_grad(pfor_input):
1941  pfor_input.stack_inputs()
1942  data_format = pfor_input.get_attr("data_format")
1943  y_backprop = pfor_input.stacked_input(0)
1944  y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format)
1945  x = pfor_input.stacked_input(1)
1946  x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format)
1947  inputs = [y_backprop, x] + [
1948      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1949      for i in range(2, pfor_input.num_inputs)
1950  ]
1951  outputs = _create_op(
1952      pfor_input.op_type,
1953      inputs, [x.dtype for x in pfor_input.outputs],
1954      attrs=pfor_input.op.node_def.attr).outputs
1955  x_backprop = outputs[0]
1956  x_backprop = array_ops.reshape(x_backprop, x_reverse_shape)
1957  x_backprop = array_ops.transpose(x_backprop, x_reverse_order)
1958  n = pfor_input.pfor.loop_len_vector
1959  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1960  outputs = [x_backprop] + outputs
1961  return [wrap(output, True) for output in outputs]
1962
1963
1964@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
1965@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
1966@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0)
1967def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
1968                                       shape_dim):
1969  del op_type
1970  inputs = _inputs_with_flattening(pfor_input, flatten_dims)
1971  n = pfor_input.pfor.loop_len_vector
1972  # Adjust the `input_sizes` input.
1973  ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1],
1974                        dtype=n.dtype)
1975  inputs[shape_dim] *= array_ops.concat([n, ones], axis=0)
1976  outputs = _create_op(
1977      pfor_input.op_type,
1978      inputs, [x.dtype for x in pfor_input.outputs],
1979      attrs=pfor_input.op.node_def.attr).outputs
1980  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1981  return [wrap(x, True) for x in outputs]
1982
1983
1984@RegisterPFor("Conv2DBackpropFilter")
1985def _convert_conv2d_backprop_filter(pfor_input):
1986  pfor_input.stack_inputs(stack_indices=[2])
1987  inputs, inputs_stacked, _ = pfor_input.input(0)
1988  filter_sizes = pfor_input.unstacked_input(1)
1989  grads = pfor_input.stacked_input(2)
1990  strides = pfor_input.get_attr("strides")
1991  padding = pfor_input.get_attr("padding")
1992  use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu")
1993  data_format = pfor_input.get_attr("data_format")
1994  dilations = pfor_input.get_attr("dilations")
1995  if inputs_stacked:
1996    # TODO(agarwal): Implement this efficiently.
1997    logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!")
1998
1999    def while_body(i, ta):
2000      inp_i = inputs[i, ...]
2001      grad_i = grads[i, ...]
2002      output = nn_ops.conv2d_backprop_filter(
2003          inp_i,
2004          filter_sizes,
2005          grad_i,
2006          strides=strides,
2007          padding=padding,
2008          use_cudnn_on_gpu=use_cudnn_on_gpu,
2009          data_format=data_format,
2010          dilations=dilations)
2011      return i + 1, ta.write(i, array_ops.expand_dims(output, 0))
2012
2013    n = array_ops.reshape(pfor_input.pfor.loop_len_vector, [])
2014    _, ta = control_flow_ops.while_loop(
2015        lambda i, ta: i < n, while_body,
2016        (0, tensor_array_ops.TensorArray(inputs.dtype, n)))
2017    output = ta.concat()
2018    return wrap(output, True)
2019  else:
2020    # We merge the stack dimension with the channel dimension of the gradients
2021    # and pretend we had a larger filter (see change to filter_sizes below).
2022    # Once the filter backprop is computed, we reshape and transpose back
2023    # appropriately.
2024    grads, _, _ = _channel_flatten_input(grads, data_format)
2025    n = pfor_input.pfor.loop_len_vector
2026    old_filter_sizes = filter_sizes
2027    filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0)
2028    output = nn_ops.conv2d_backprop_filter(
2029        inputs,
2030        filter_sizes,
2031        grads,
2032        strides=strides,
2033        padding=padding,
2034        use_cudnn_on_gpu=use_cudnn_on_gpu,
2035        data_format=data_format,
2036        dilations=dilations)
2037    new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0)
2038    output = array_ops.reshape(output, new_filter_shape)
2039    output = array_ops.transpose(output, [3, 0, 1, 2, 4])
2040    return wrap(output, True)
2041
2042
2043@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax)
2044@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax)
2045def _convert_softmax(pfor_input, op_type, op_func):
2046  del op_type
2047  return wrap(op_func(pfor_input.stacked_input(0)), True)
2048
2049
2050# array_ops
2051
2052
2053@RegisterPForWithArgs("Identity", array_ops.identity)
2054@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
2055@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag)
2056@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part)
2057def _convert_identity(pfor_input, op_type, op_func):
2058  del op_type
2059  return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
2060
2061
2062@RegisterPFor("IdentityN")
2063def _convert_identity_n(pfor_input):
2064  outputs = array_ops.identity_n([x.t for x in pfor_input.inputs])
2065  return [
2066      wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs)
2067  ]
2068
2069
2070@RegisterPFor("Reshape")
2071def _convert_reshape(pfor_input):
2072  t = pfor_input.stacked_input(0)
2073  shape = pfor_input.unstacked_input(1)
2074  new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2075  return wrap(array_ops.reshape(t, new_shape), True)
2076
2077
2078@RegisterPFor("Fill")
2079def _convert_fill(pfor_input):
2080  dims = pfor_input.unstacked_input(0)
2081  value = pfor_input.stacked_input(1)
2082  # Expand the rank of `value`
2083  new_shape = array_ops.concat(
2084      [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)],
2085      axis=0)
2086  value = array_ops.reshape(value, new_shape)
2087  # Compute the new output shape
2088  new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0)
2089  # Broadcast
2090  return wrap(array_ops.broadcast_to(value, new_dims), True)
2091
2092
2093@RegisterPFor("BroadcastTo")
2094def _convert_broadcast_to(pfor_input):
2095  t = pfor_input.stacked_input(0)
2096  shape = pfor_input.unstacked_input(1)
2097  new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2098
2099  # Expand dims of stacked t to broadcast against the new shape.
2100  # TODO(davmre): consider factoring out common code with
2101  # `expanddim_inputs_for_broadcast`, which has similar logic but with
2102  # implicit shapes (of input Tensors) rather than explicit shapes.
2103  rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t)
2104  ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1]))
2105  t_shape = array_ops.shape(t)
2106  t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0)
2107
2108  return wrap(
2109      array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape),
2110      True)
2111
2112
2113@RegisterPFor("ExpandDims")
2114def _convert_expanddims(pfor_input):
2115  t = pfor_input.stacked_input(0)
2116  dim = pfor_input.unstacked_input(1)
2117  dim += math_ops.cast(dim >= 0, dim.dtype)
2118  return wrap(array_ops.expand_dims(t, axis=dim), True)
2119
2120
2121@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound)
2122@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound)
2123def _convert_searchsorted(pfor_input, _, op_func):
2124  pfor_input.stack_inputs()
2125  sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0))
2126  values = _flatten_first_two_dims(pfor_input.stacked_input(1))
2127  out_type = pfor_input.get_attr("out_type")
2128  output = op_func(sorted_inputs, values, out_type)
2129  return wrap(
2130      _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True)
2131
2132
2133@RegisterPFor("MatrixBandPart")
2134def _convert_matrix_band_part(pfor_input):
2135  t = pfor_input.stacked_input(0)
2136  num_lower = pfor_input.unstacked_input(1)
2137  num_upper = pfor_input.unstacked_input(2)
2138  return wrap(
2139      array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper),
2140      True)
2141
2142
2143@RegisterPFor("MatrixSetDiag")
2144def _convert_matrix_set_diag(pfor_input):
2145  pfor_input.stack_inputs()
2146  t = pfor_input.stacked_input(0)
2147  diag = pfor_input.stacked_input(1)
2148  return wrap(array_ops.matrix_set_diag(t, diag), True)
2149
2150
2151# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3.
2152# The input orders defined in the OpKernel and the actual python API are
2153# different (for compatibility with V1), so we cannot use _convert_identity.
2154# v2 is not compatible with v3 and is never exposed on the public API.
2155@RegisterPFor("MatrixDiagV2")
2156@RegisterPFor("MatrixDiagV3")
2157def _convert_matrix_diag_v2(pfor_input):
2158  params = {
2159      "diagonal": pfor_input.stacked_input(0),
2160      "k": pfor_input.unstacked_input(1),
2161      "num_rows": pfor_input.unstacked_input(2),
2162      "num_cols": pfor_input.unstacked_input(3),
2163      "padding_value": pfor_input.unstacked_input(4)
2164  }
2165  if pfor_input.op_type == "MatrixDiagV2":
2166    return wrap(array_ops.matrix_diag_v2(**params), True)
2167  params["align"] = pfor_input.get_attr("align")
2168  return wrap(array_ops.matrix_diag(**params), True)
2169
2170
2171@RegisterPFor("Diag")
2172def _convert_diag(pfor_input):
2173  diag = pfor_input.stacked_input(0)
2174  if diag.shape.ndims == 2:
2175    # We can use matrix_diag.
2176    return wrap(array_ops.matrix_diag(diag), True)
2177  else:
2178    # It is not clear if we can do better than a while loop here with existing
2179    # kernels.
2180    return _fallback_converter(pfor_input, warn=False)
2181
2182
2183# See notes for MatrixDiagV2
2184@RegisterPFor("MatrixDiagPartV2")
2185@RegisterPFor("MatrixDiagPartV3")
2186def _convert_matrix_diag_part_v2(pfor_input):
2187  params = {
2188      "input": pfor_input.stacked_input(0),
2189      "k": pfor_input.unstacked_input(1),
2190      "padding_value": pfor_input.unstacked_input(2)
2191  }
2192  if pfor_input.op_type == "MatrixDiagPartV2":
2193    return wrap(array_ops.matrix_diag_part_v2(**params), True)
2194  params["align"] = pfor_input.get_attr("align")
2195  return wrap(array_ops.matrix_diag_part(**params), True)
2196
2197
2198# See notes for MatrixDiagV2
2199@RegisterPFor("MatrixSetDiagV2")
2200@RegisterPFor("MatrixSetDiagV3")
2201def _convert_matrix_set_diag_v2(pfor_input):
2202  pfor_input.stack_inputs([0, 1])
2203  params = {
2204      "input": pfor_input.stacked_input(0),
2205      "diagonal": pfor_input.stacked_input(1),
2206      "k": pfor_input.unstacked_input(2)
2207  }
2208  if pfor_input.op_type == "MatrixSetDiagV2":
2209    return wrap(array_ops.matrix_set_diag_v2(**params), True)
2210  params["align"] = pfor_input.get_attr("align")
2211  return wrap(array_ops.matrix_set_diag(**params), True)
2212
2213
2214@RegisterPFor("DiagPart")
2215def _convert_diag_part(pfor_input):
2216  inp = pfor_input.stacked_input(0)
2217  if inp.shape.ndims == 3:
2218    # We can use matrix_diag_part.
2219    return wrap(array_ops.matrix_diag_part(inp), True)
2220  else:
2221    # It is not clear if we can do better than a while loop here with existing
2222    # kernels.
2223    return _fallback_converter(pfor_input, warn=False)
2224
2225
2226@RegisterPFor("OneHot")
2227def _convert_one_hot(pfor_input):
2228  indices = pfor_input.stacked_input(0)
2229  depth = pfor_input.unstacked_input(1)
2230  on_value = pfor_input.unstacked_input(2)
2231  off_value = pfor_input.unstacked_input(3)
2232  axis = pfor_input.get_attr("axis")
2233  if axis >= 0:
2234    axis += 1
2235  return wrap(
2236      array_ops.one_hot(indices, depth, on_value, off_value, axis), True)
2237
2238
2239@RegisterPFor("Slice")
2240def _convert_slice(pfor_input):
2241  t = pfor_input.stacked_input(0)
2242  begin = pfor_input.unstacked_input(1)
2243  size = pfor_input.unstacked_input(2)
2244  begin = array_ops.concat([[0], begin], axis=0)
2245  size = array_ops.concat([[-1], size], axis=0)
2246  return wrap(array_ops.slice(t, begin, size), True)
2247
2248
2249@RegisterPFor("Tile")
2250def _convert_tile(pfor_input):
2251  t = pfor_input.stacked_input(0)
2252  multiples = pfor_input.unstacked_input(1)
2253  multiples = array_ops.concat([[1], multiples], 0)
2254  return wrap(array_ops.tile(t, multiples), True)
2255
2256
2257@RegisterPFor("Pack")
2258def _convert_pack(pfor_input):
2259  pfor_input.stack_inputs()
2260  axis = pfor_input.get_attr("axis")
2261  if axis >= 0:
2262    axis += 1
2263  return wrap(
2264      array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True)
2265
2266
2267@RegisterPFor("Unpack")
2268def _convert_unpack(pfor_input):
2269  value = pfor_input.stacked_input(0)
2270  axis = pfor_input.get_attr("axis")
2271  if axis >= 0:
2272    axis += 1
2273  num = pfor_input.get_attr("num")
2274  return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)]
2275
2276
2277@RegisterPFor("Pad")
2278def _convert_pad(pfor_input):
2279  t = pfor_input.stacked_input(0)
2280  paddings = pfor_input.unstacked_input(1)
2281  paddings = array_ops.concat([[[0, 0]], paddings], 0)
2282  return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True)
2283
2284
2285@RegisterPFor("Split")
2286def _convert_split(pfor_input):
2287  split_dim = pfor_input.unstacked_input(0)
2288  t = pfor_input.stacked_input(1)
2289  num_split = pfor_input.get_attr("num_split")
2290  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2291  return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
2292
2293
2294@RegisterPFor("SplitV")
2295def _convert_split_v(pfor_input):
2296  t = pfor_input.stacked_input(0)
2297  splits = pfor_input.unstacked_input(1)
2298  split_dim = pfor_input.unstacked_input(2)
2299  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2300  return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
2301
2302
2303@RegisterPFor("Squeeze")
2304def _convert_squeeze(pfor_input):
2305  t = pfor_input.stacked_input(0)
2306  squeeze_dims = pfor_input.get_attr("squeeze_dims")
2307  squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims]
2308  return wrap(array_ops.squeeze(t, axis=squeeze_dims), True)
2309
2310
2311@RegisterPFor("ReverseV2")
2312def _convert_reverse(pfor_input):
2313  value = pfor_input.stacked_input(0)
2314  axis = pfor_input.unstacked_input(1)
2315  new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis)
2316  return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True)
2317
2318
2319@RegisterPForWithArgs("Transpose", gen_array_ops.transpose)
2320@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose)
2321def _convert_transpose(pfor_input, _, op_func):
2322  t = pfor_input.stacked_input(0)
2323  perm = pfor_input.unstacked_input(1)
2324  new_perm = array_ops.concat([[0], perm + 1], axis=0)
2325  return wrap(op_func(t, new_perm), True)
2326
2327
2328@RegisterPFor("ZerosLike")
2329def _convert_zeroslike(pfor_input):
2330  t = pfor_input.stacked_input(0)
2331  shape = array_ops.shape(t)[1:]
2332  return wrap(array_ops.zeros(shape, dtype=t.dtype), False)
2333
2334
2335@RegisterPFor("Gather")
2336@RegisterPFor("GatherV2")
2337def _convert_gather(pfor_input):
2338  param, param_stacked, _ = pfor_input.input(0)
2339  indices, indices_stacked, _ = pfor_input.input(1)
2340  batch_dims = pfor_input.get_attr("batch_dims")
2341
2342  op_type = pfor_input.op_type
2343  if op_type == "Gather":
2344    validate_indices = pfor_input.get_attr("validate_indices")
2345    axis = 0
2346  else:
2347    validate_indices = None
2348    # Assume we will never have a Tensor with rank > 2**32.
2349    axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32)
2350    axis_value = tensor_util.constant_value(axis)
2351    if axis_value is not None:
2352      axis = axis_value
2353  if indices_stacked and not param_stacked:
2354    if indices is pfor_input.pfor.all_indices and axis == 0:
2355      param_shape0 = tensor_shape.dimension_value(param.shape[0])
2356      indices_shape0 = tensor_shape.dimension_value(indices.shape[0])
2357      if param_shape0 is not None and indices_shape0 == param_shape0:
2358        # Note that with loops and conditionals, indices may not be contiguous.
2359        # However they will be sorted and unique. So if the shape matches, then
2360        # it must be picking up all the rows of param.
2361        return wrap(param, True)
2362
2363    if batch_dims != 0:
2364      # Convert `batch_dims` to its positive equivalent if necessary.
2365      batch_dims_pos = batch_dims
2366      if batch_dims < 0:
2367        batch_dims_pos += array_ops.rank(indices)
2368      # In order to maintain
2369      #   indices.shape[:batch_dims] == params.shape[:batch_dims]
2370      # with stacked indices, we move the first dimension of `indices` to the
2371      # `batch_dims + 1`th position. The (non-batch) index dimensions will be
2372      # inserted into the shape of `output` at the `axis` dimension, which is
2373      # then transposed to the front (below).
2374      order = array_ops.concat([
2375          math_ops.range(1, batch_dims_pos + 1),
2376          [0],
2377          math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0)
2378      indices = array_ops.transpose(indices, order)
2379
2380    output = array_ops.gather(
2381        param, indices, validate_indices=validate_indices, axis=axis,
2382        batch_dims=batch_dims)
2383    if axis != 0:
2384      axis = control_flow_ops.cond(axis < 0,
2385                                   lambda: axis + array_ops.rank(param),
2386                                   lambda: axis)
2387      order = array_ops.concat(
2388          [[axis],
2389           math_ops.range(axis),
2390           math_ops.range(axis + 1, array_ops.rank(output))],
2391          axis=0)
2392      output = control_flow_ops.cond(
2393          math_ops.equal(axis, 0), lambda: output,
2394          lambda: array_ops.transpose(output, order))
2395    return wrap(output, True)
2396  if param_stacked:
2397    pfor_input.stack_inputs(stack_indices=[1])
2398    indices = pfor_input.stacked_input(1)
2399
2400    output = array_ops.gather(
2401        param, indices,
2402        axis=array_ops.where(axis >= 0, axis + 1, axis),
2403        batch_dims=(batch_dims + 1 if batch_dims >= 0 else batch_dims))
2404    return wrap(output, True)
2405
2406
2407@RegisterPFor("GatherNd")
2408def _convert_gather_nd(pfor_input):
2409  # TODO(jmenick): Add support for unstacked params.
2410  pfor_input.stack_inputs(stack_indices=[1])
2411  params = pfor_input.stacked_input(0)
2412  indices = pfor_input.stacked_input(1)
2413  stacked_result = array_ops.gather_nd(params, indices, batch_dims=1)
2414  return wrap(stacked_result, True)
2415
2416
2417@RegisterPFor("ConcatV2")
2418def _convert_concatv2(pfor_input):
2419  n = pfor_input.num_inputs
2420  pfor_input.stack_inputs(stack_indices=range(n - 1))
2421  axis = pfor_input.unstacked_input(n - 1)
2422  axis += math_ops.cast(axis >= 0, axis.dtype)
2423  return wrap(
2424      array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis),
2425      True)
2426
2427
2428@RegisterPFor("StridedSlice")
2429def _convert_strided_slice(pfor_input):
2430  inp = pfor_input.stacked_input(0)
2431  begin = pfor_input.unstacked_input(1)
2432  end = pfor_input.unstacked_input(2)
2433  strides = pfor_input.unstacked_input(3)
2434  begin_mask = pfor_input.get_attr("begin_mask")
2435  end_mask = pfor_input.get_attr("end_mask")
2436  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2437  new_axis_mask = pfor_input.get_attr("new_axis_mask")
2438  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2439
2440  begin = array_ops.concat([[0], begin], axis=0)
2441  end = array_ops.concat([[0], end], axis=0)
2442  strides = array_ops.concat([[1], strides], axis=0)
2443  begin_mask = begin_mask << 1 | 1
2444  end_mask = end_mask << 1 | 1
2445  ellipsis_mask <<= 1
2446  new_axis_mask <<= 1
2447  shrink_axis_mask <<= 1
2448  return wrap(
2449      array_ops.strided_slice(
2450          inp,
2451          begin,
2452          end,
2453          strides,
2454          begin_mask=begin_mask,
2455          end_mask=end_mask,
2456          ellipsis_mask=ellipsis_mask,
2457          new_axis_mask=new_axis_mask,
2458          shrink_axis_mask=shrink_axis_mask), True)
2459
2460
2461@RegisterPFor("StridedSliceGrad")
2462def _convert_strided_slice_grad(pfor_input):
2463  shape = pfor_input.unstacked_input(0)
2464  begin = pfor_input.unstacked_input(1)
2465  end = pfor_input.unstacked_input(2)
2466  strides = pfor_input.unstacked_input(3)
2467  dy = pfor_input.stacked_input(4)
2468  begin_mask = pfor_input.get_attr("begin_mask")
2469  end_mask = pfor_input.get_attr("end_mask")
2470  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2471  new_axis_mask = pfor_input.get_attr("new_axis_mask")
2472  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2473
2474  shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2475  begin = array_ops.concat([[0], begin], axis=0)
2476  end = array_ops.concat([[0], end], axis=0)
2477  strides = array_ops.concat([[1], strides], axis=0)
2478  begin_mask = begin_mask << 1 | 1
2479  end_mask = end_mask << 1 | 1
2480  ellipsis_mask <<= 1
2481  new_axis_mask <<= 1
2482  shrink_axis_mask <<= 1
2483  return wrap(
2484      array_ops.strided_slice_grad(
2485          shape,
2486          begin,
2487          end,
2488          strides,
2489          dy,
2490          begin_mask=begin_mask,
2491          end_mask=end_mask,
2492          ellipsis_mask=ellipsis_mask,
2493          new_axis_mask=new_axis_mask,
2494          shrink_axis_mask=shrink_axis_mask), True)
2495
2496
2497@RegisterPFor("CheckNumerics")
2498def _convert_check_numerics(pfor_input):
2499  t = pfor_input.stacked_input(0)
2500  message = pfor_input.get_attr("message")
2501  return wrap(gen_array_ops.check_numerics(t, message), True)
2502
2503
2504# math_ops
2505
2506
2507@RegisterPFor("MatMul")
2508def _convert_matmul(pfor_input):
2509  # TODO(agarwal): Check if tiling is faster than two transposes.
2510  a, a_stacked, _ = pfor_input.input(0)
2511  b, b_stacked, _ = pfor_input.input(1)
2512  tr_a = pfor_input.get_attr("transpose_a")
2513  tr_b = pfor_input.get_attr("transpose_b")
2514  if a_stacked and b_stacked:
2515    output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True)
2516    return output
2517  elif a_stacked:
2518    if tr_a:
2519      a = array_ops.transpose(a, [0, 2, 1])
2520    if a.shape.is_fully_defined():
2521      x, y, z = a.shape
2522    else:
2523      x, y, z = [
2524          array_ops.reshape(i, [])
2525          for i in array_ops.split(array_ops.shape(a), 3)
2526      ]
2527    a = array_ops.reshape(a, [x * y, z])
2528    prod = math_ops.matmul(a, b, transpose_b=tr_b)
2529    return wrap(array_ops.reshape(prod, [x, y, -1]), True)
2530  else:
2531    assert b_stacked
2532    if tr_b:
2533      perm = [2, 0, 1]
2534      b = array_ops.transpose(b, perm)
2535    else:
2536      # As an optimization, if one of the first two dimensions is 1, then we can
2537      # reshape instead of transpose.
2538      # TODO(agarwal): This check can be done inside Transpose kernel.
2539      b_shape = array_ops.shape(b)
2540      min_dim = math_ops.minimum(b_shape[0], b_shape[1])
2541      perm = control_flow_ops.cond(
2542          math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2])
2543      new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]])
2544      b = array_ops.transpose(b, perm)
2545      b = array_ops.reshape(b, new_shape)
2546
2547    if b.shape.is_fully_defined():
2548      x, y, z = b.shape
2549    else:
2550      x, y, z = [
2551          array_ops.reshape(i, [])
2552          for i in array_ops.split(array_ops.shape(b), 3)
2553      ]
2554    b = array_ops.reshape(b, [x, y * z])
2555    prod = math_ops.matmul(a, b, transpose_a=tr_a)
2556    prod = array_ops.reshape(prod, [-1, y, z])
2557    prod = array_ops.transpose(prod, [1, 0, 2])
2558    return wrap(prod, True)
2559
2560
2561# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window
2562# is met.
2563@RegisterPFor("BatchMatMul")
2564def _convert_batch_mat_mul(pfor_input):
2565  # TODO(agarwal): There may be a more efficient way to do this instead of
2566  # stacking the inputs.
2567  pfor_input.stack_inputs()
2568  x = pfor_input.stacked_input(0)
2569  y = pfor_input.stacked_input(1)
2570  adj_x = pfor_input.get_attr("adj_x")
2571  adj_y = pfor_input.get_attr("adj_y")
2572
2573  x = _flatten_first_two_dims(x)
2574  y = _flatten_first_two_dims(y)
2575  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2576  output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector)
2577  return wrap(output, True)
2578
2579
2580@RegisterPFor("BatchMatMulV2")
2581def _convert_batch_mat_mul_v2(pfor_input):
2582  pfor_input.expanddim_inputs_for_broadcast()
2583  x = pfor_input.input(0)[0]
2584  y = pfor_input.input(1)[0]
2585  adj_x = pfor_input.get_attr("adj_x")
2586  adj_y = pfor_input.get_attr("adj_y")
2587
2588  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2589  return wrap(output, True)
2590
2591
2592@RegisterPForWithArgs("Sum", math_ops.reduce_sum)
2593@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
2594@RegisterPForWithArgs("Max", math_ops.reduce_max)
2595@RegisterPForWithArgs("Min", math_ops.reduce_min)
2596@RegisterPForWithArgs("Mean", math_ops.reduce_mean)
2597@RegisterPForWithArgs("All", math_ops.reduce_all)
2598@RegisterPForWithArgs("Any", math_ops.reduce_any)
2599def _convert_reduction(pfor_input, _, op_func):
2600  t = pfor_input.stacked_input(0)
2601  indices = pfor_input.unstacked_input(1)
2602  # Shift positive indices by one to account for the extra dimension.
2603  indices += math_ops.cast(indices >= 0, indices.dtype)
2604  keep_dims = pfor_input.get_attr("keep_dims")
2605  return wrap(op_func(t, indices, keepdims=keep_dims), True)
2606
2607
2608@RegisterPForWithArgs("ArgMax", math_ops.argmax)
2609@RegisterPForWithArgs("ArgMin", math_ops.argmin)
2610def _convert_argmax_argmin(pfor_input, _, op_func):
2611  t = pfor_input.stacked_input(0)
2612  dimension = pfor_input.unstacked_input(1)
2613  dimension += math_ops.cast(dimension >= 0, dimension.dtype)
2614  output_type = pfor_input.get_attr("output_type")
2615  return wrap(op_func(t, axis=dimension, output_type=output_type), True)
2616
2617
2618@RegisterPFor("Bucketize")
2619def _convert_bucketize(pfor_input):
2620  t = pfor_input.stacked_input(0)
2621  boundaries = pfor_input.get_attr("boundaries")
2622  return wrap(math_ops.bucketize(t, boundaries), True)
2623
2624
2625@RegisterPFor("ClipByValue")
2626def _convert_clip_by_value(pfor_input):
2627  t = pfor_input.stacked_input(0)
2628  clip_value_min = pfor_input.unstacked_input(1)
2629  clip_value_max = pfor_input.unstacked_input(2)
2630  return wrap(gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max),
2631              True)
2632
2633
2634@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
2635@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
2636def _convert_cumfoo(pfor_input, _, op_func):
2637  t = pfor_input.stacked_input(0)
2638  axis = pfor_input.unstacked_input(1)
2639  # Shift positive indices by one to account for the extra dimension.
2640  axis += math_ops.cast(axis >= 0, axis.dtype)
2641  exclusive = pfor_input.get_attr("exclusive")
2642  reverse = pfor_input.get_attr("reverse")
2643  return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True)
2644
2645
2646@RegisterPFor("BiasAdd")
2647def _convert_biasadd(pfor_input):
2648  t, t_stacked, _ = pfor_input.input(0)
2649  bias, bias_stacked, _ = pfor_input.input(1)
2650  data_format = pfor_input.get_attr("data_format").decode()
2651  if bias_stacked:
2652    # BiasAdd only supports 1-D biases, so cast bias to match value and use Add.
2653    pfor_input.expanddim_inputs_for_broadcast()
2654    t, _, _ = pfor_input.input(0)
2655    bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype)
2656    if compat.as_bytes(data_format) == b"NCHW":
2657      b_shape = array_ops.shape(bias)
2658      new_b_shape = array_ops.concat(
2659          [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0)
2660      bias = array_ops.reshape(bias, new_b_shape)
2661    return wrap(math_ops.add(t, bias), True)
2662  else:
2663    assert t_stacked, "At least one input to BiasAdd should be loop variant."
2664    if compat.as_bytes(data_format) == b"NCHW":
2665      shape = array_ops.shape(t)
2666      flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
2667      t = array_ops.reshape(t, flattened_shape)
2668      t = nn_ops.bias_add(t, bias, data_format="NCHW")
2669      t = array_ops.reshape(t, shape)
2670      return wrap(t, True)
2671    return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
2672
2673
2674@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum)
2675@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max)
2676@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min)
2677@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod)
2678def _convert_unsortedsegmentsum(pfor_input, _, op_func):
2679  pfor_input.stack_inputs([0, 1])
2680  data = pfor_input.stacked_input(0)
2681  segment_ids = pfor_input.stacked_input(1)
2682  # TODO(agarwal): handle stacked?
2683  num_segments = pfor_input.unstacked_input(2)
2684  if segment_ids.dtype != num_segments.dtype:
2685    segment_ids = math_ops.cast(segment_ids, dtypes.int64)
2686    num_segments = math_ops.cast(num_segments, dtypes.int64)
2687  dtype = segment_ids.dtype
2688  segment_shape = array_ops.shape(segment_ids, out_type=dtype)
2689  n = segment_shape[0]
2690  ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:]
2691  segment_offset = num_segments * math_ops.range(n, dtype=dtype)
2692  segment_offset = array_ops.reshape(segment_offset,
2693                                     array_ops.concat([[n], ones], axis=0))
2694  segment_ids += segment_offset
2695  num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast(
2696      n, dtypes.int64)
2697  output = op_func(data, segment_ids, num_segments)
2698  new_output_shape = array_ops.concat(
2699      [[n, -1], array_ops.shape(output)[1:]], axis=0)
2700  output = array_ops.reshape(output, new_output_shape)
2701  return wrap(output, True)
2702
2703
2704def _flatten_array_with_offset(ids, offset_delta, num_rows):
2705  """Flattens a rank 2 tensor, adding an offset to each row."""
2706  # Note that if `ids` is rank 1, it is broadcast to rank 2.
2707  offset_delta = math_ops.cast(offset_delta, ids.dtype)
2708  n = math_ops.cast(num_rows, dtype=ids.dtype)
2709  offsets = math_ops.range(
2710      start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype)
2711  offsets = array_ops.expand_dims(offsets, -1)
2712  ids += offsets
2713  return array_ops.reshape(ids, [-1])
2714
2715
2716@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2)
2717@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2)
2718@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2)
2719@RegisterPForWithArgs("SparseSegmentSumWithNumSegments",
2720                      math_ops.sparse_segment_sum_v2)
2721@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments",
2722                      math_ops.sparse_segment_mean_v2)
2723@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments",
2724                      math_ops.sparse_segment_sqrt_n_v2)
2725def _convert_sparse_segment(pfor_input, _, op_func):
2726  _, segment_ids_stacked, _ = pfor_input.input(2)
2727  if segment_ids_stacked:
2728    pfor_input.stack_inputs([1])
2729  data, data_stacked, _ = pfor_input.input(0)
2730  indices, _, _ = pfor_input.input(1)
2731  num_inputs = len(pfor_input.inputs)
2732  assert num_inputs in (3, 4)
2733  if num_inputs == 3:
2734    # `segment_ids` needs to be unstacked since otherwise output sizes could
2735    # differ across pfor iterations.
2736    segment_ids = pfor_input.unstacked_input(2)
2737    num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2738  else:
2739    segment_ids, _, _ = pfor_input.input(2)
2740    num_segments = pfor_input.unstacked_input(3)
2741
2742  n = pfor_input.pfor.loop_len_vector[0]
2743  if data_stacked:
2744    indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n)
2745    data = _flatten_first_two_dims(data)
2746  else:
2747    indices = array_ops.reshape(indices, [-1])
2748  segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2749
2750  if num_inputs == 3:
2751    num_segments = None
2752  else:
2753    num_segments *= n
2754  output = op_func(data, indices, segment_ids, num_segments=num_segments)
2755  output = _unflatten_first_dim(output, [n])
2756  return wrap(output, True)
2757
2758
2759@RegisterPForWithArgs("SparseSegmentMeanGrad",
2760                      math_ops.sparse_segment_mean_grad)
2761@RegisterPForWithArgs("SparseSegmentSqrtNGrad",
2762                      math_ops.sparse_segment_sqrt_n_grad)
2763def _convert_sparse_segment_grad(pfor_input, _, op_func):
2764  grad = pfor_input.stacked_input(0)
2765  indices = pfor_input.unstacked_input(1)
2766  segment_ids = pfor_input.unstacked_input(2)
2767  dim0 = pfor_input.unstacked_input(3)
2768
2769  n = pfor_input.pfor.loop_len_vector[0]
2770  indices = _flatten_array_with_offset(indices, dim0, n)
2771  num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2772  segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2773  grad = _flatten_first_two_dims(grad)
2774  dim0 *= n
2775  output = op_func(grad, indices, segment_ids, dim0)
2776  output = _unflatten_first_dim(output, [n])
2777  return wrap(output, True)
2778
2779
2780@RegisterPFor("Cast")
2781def _convert_cast(pfor_input):
2782  inp = pfor_input.stacked_input(0)
2783  dtype = pfor_input.get_attr("DstT")
2784  return wrap(math_ops.cast(inp, dtype), True)
2785
2786
2787@RegisterPForWithArgs("Abs", math_ops.abs)
2788@RegisterPForWithArgs("Acos", math_ops.acos)
2789@RegisterPForWithArgs("Acosh", math_ops.acosh)
2790@RegisterPForWithArgs("Add", math_ops.add)
2791@RegisterPForWithArgs("AddV2", math_ops.add_v2)
2792@RegisterPForWithArgs("Angle", math_ops.angle)
2793@RegisterPForWithArgs("Asin", math_ops.asin)
2794@RegisterPForWithArgs("Asinh", math_ops.asinh)
2795@RegisterPForWithArgs("Atan", math_ops.atan)
2796@RegisterPForWithArgs("Atan2", math_ops.atan2)
2797@RegisterPForWithArgs("Atanh", math_ops.atanh)
2798@RegisterPForWithArgs("BesselI0", special_math_ops.bessel_i0)
2799@RegisterPForWithArgs("BesselI1", special_math_ops.bessel_i1)
2800@RegisterPForWithArgs("BesselI0e", special_math_ops.bessel_i0e)
2801@RegisterPForWithArgs("BesselI1e", special_math_ops.bessel_i1e)
2802@RegisterPForWithArgs("BesselK0", special_math_ops.bessel_k0)
2803@RegisterPForWithArgs("BesselK1", special_math_ops.bessel_k1)
2804@RegisterPForWithArgs("BesselK0e", special_math_ops.bessel_k0e)
2805@RegisterPForWithArgs("BesselK1e", special_math_ops.bessel_k1e)
2806@RegisterPForWithArgs("BesselJ0", special_math_ops.bessel_j0)
2807@RegisterPForWithArgs("BesselJ1", special_math_ops.bessel_j1)
2808@RegisterPForWithArgs("BesselY0", special_math_ops.bessel_y0)
2809@RegisterPForWithArgs("BesselY1", special_math_ops.bessel_y1)
2810@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and)
2811@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or)
2812@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor)
2813@RegisterPForWithArgs("Ceil", math_ops.ceil)
2814@RegisterPForWithArgs("Complex", math_ops.complex)
2815@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs)
2816@RegisterPForWithArgs("Conj", math_ops.conj)
2817@RegisterPForWithArgs("Cos", math_ops.cos)
2818@RegisterPForWithArgs("Cosh", math_ops.cosh)
2819@RegisterPForWithArgs("Dawsn", special_math_ops.dawsn)
2820@RegisterPForWithArgs("Digamma", math_ops.digamma)
2821@RegisterPForWithArgs("Div", math_ops.div)
2822@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
2823@RegisterPForWithArgs("Elu", nn_ops.elu)
2824@RegisterPForWithArgs("Erf", math_ops.erf)
2825@RegisterPForWithArgs("Erfc", math_ops.erfc)
2826@RegisterPForWithArgs("Erfinv", math_ops.erfinv)
2827@RegisterPForWithArgs("Exp", math_ops.exp)
2828@RegisterPForWithArgs("Expint", special_math_ops.expint)
2829@RegisterPForWithArgs("Expm1", math_ops.expm1)
2830@RegisterPForWithArgs("Floor", math_ops.floor)
2831@RegisterPForWithArgs("FloorDiv", math_ops.floor_div)
2832@RegisterPForWithArgs("FloorMod", math_ops.floor_mod)
2833@RegisterPForWithArgs("FresnelCos", special_math_ops.fresnel_cos)
2834@RegisterPForWithArgs("FresnelSin", special_math_ops.fresnel_sin)
2835@RegisterPForWithArgs("Greater", math_ops.greater)
2836@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal)
2837@RegisterPForWithArgs("Igamma", math_ops.igamma)
2838@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a)
2839@RegisterPForWithArgs("Igammac", math_ops.igammac)
2840@RegisterPForWithArgs("Imag", math_ops.imag)
2841@RegisterPForWithArgs("Inv", math_ops.inv)
2842@RegisterPForWithArgs("Invert", bitwise_ops.invert)
2843@RegisterPForWithArgs("IsFinite", math_ops.is_finite)
2844@RegisterPForWithArgs("IsInf", math_ops.is_inf)
2845@RegisterPForWithArgs("IsNan", math_ops.is_nan)
2846@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift)
2847@RegisterPForWithArgs("Less", math_ops.less)
2848@RegisterPForWithArgs("LessEqual", math_ops.less_equal)
2849@RegisterPForWithArgs("Lgamma", math_ops.lgamma)
2850@RegisterPForWithArgs("Log", math_ops.log)
2851@RegisterPForWithArgs("Log1p", math_ops.log1p)
2852@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and)
2853@RegisterPForWithArgs("LogicalNot", math_ops.logical_not)
2854@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
2855@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor)
2856@RegisterPForWithArgs("Maximum", math_ops.maximum)
2857@RegisterPForWithArgs("Minimum", math_ops.minimum)
2858@RegisterPForWithArgs("Mod", math_ops.mod)
2859@RegisterPForWithArgs("Mul", math_ops.multiply)
2860@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
2861@RegisterPForWithArgs("Ndtri", math_ops.ndtri)
2862@RegisterPForWithArgs("Neg", math_ops.negative)
2863@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
2864@RegisterPForWithArgs("Pow", math_ops.pow)
2865@RegisterPForWithArgs("Real", math_ops.real)
2866@RegisterPForWithArgs("RealDiv", math_ops.divide)
2867@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
2868@RegisterPForWithArgs("Relu", nn_ops.relu)
2869@RegisterPForWithArgs("Relu6", nn_ops.relu6)
2870@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
2871@RegisterPForWithArgs("Rint", math_ops.rint)
2872@RegisterPForWithArgs("Round", math_ops.round)
2873@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
2874@RegisterPForWithArgs("Selu", nn_ops.selu)
2875@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
2876@RegisterPForWithArgs("Sign", math_ops.sign)
2877@RegisterPForWithArgs("Sin", math_ops.sin)
2878@RegisterPForWithArgs("Sinh", math_ops.sinh)
2879@RegisterPForWithArgs("Softplus", nn_ops.softplus)
2880@RegisterPForWithArgs("Softsign", nn_ops.softsign)
2881@RegisterPForWithArgs("Spence", special_math_ops.spence)
2882@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
2883@RegisterPForWithArgs("Square", math_ops.square)
2884@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
2885@RegisterPForWithArgs("Sub", math_ops.subtract)
2886@RegisterPForWithArgs("Tan", math_ops.tan)
2887@RegisterPForWithArgs("Tanh", math_ops.tanh)
2888@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div)
2889@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod)
2890@RegisterPForWithArgs("Xdivy", math_ops.xdivy)
2891@RegisterPForWithArgs("Xlogy", math_ops.xlogy)
2892@RegisterPForWithArgs("Xlog1py", math_ops.xlog1py)
2893@RegisterPForWithArgs("Zeta", math_ops.zeta)
2894def _convert_cwise(pfor_input, op_type, op_func):
2895  # Note that ops handled here do not have attributes except those listed below
2896  # and hence don't need extra arguments passed to the cwise_op call below.
2897  for attr in pfor_input.op.node_def.attr.keys():
2898    assert attr in [u"T", u"Tout", u"_xla_compile_id"], (op_type, attr)
2899  if pfor_input.num_inputs > 1:
2900    pfor_input.expanddim_inputs_for_broadcast()
2901  return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
2902
2903
2904@RegisterPFor("LeakyRelu")
2905def _convert_leaky_relu(pfor_input):
2906  t = pfor_input.stacked_input(0)
2907  alpha = pfor_input.get_attr("alpha")
2908  return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True)
2909
2910
2911@RegisterPFor("Equal")
2912def _convert_equal(pfor_input):
2913  pfor_input.expanddim_inputs_for_broadcast()
2914  x = pfor_input.input(0)[0]
2915  y = pfor_input.input(1)[0]
2916  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
2917  return wrap(gen_math_ops.equal(
2918      x, y, incompatible_shape_error=incompatible_shape_error), True)
2919
2920
2921@RegisterPFor("NotEqual")
2922def _convert_not_equal(pfor_input):
2923  pfor_input.expanddim_inputs_for_broadcast()
2924  x = pfor_input.input(0)[0]
2925  y = pfor_input.input(1)[0]
2926  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
2927  return wrap(gen_math_ops.not_equal(
2928      x, y, incompatible_shape_error=incompatible_shape_error), True)
2929
2930
2931@RegisterPFor("ApproximateEqual")
2932def _convert_approximate_equal(pfor_input):
2933  pfor_input.expanddim_inputs_for_broadcast()
2934  x = pfor_input.input(0)[0]
2935  y = pfor_input.input(1)[0]
2936  tolerance = pfor_input.get_attr("tolerance")
2937  return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
2938
2939
2940@RegisterPFor("Shape")
2941def _convert_shape(pfor_input):
2942  out_type = pfor_input.get_attr("out_type")
2943  return wrap(
2944      array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:],
2945      False)
2946
2947
2948@RegisterPFor("ShapeN")
2949def _convert_shape_n(pfor_input):
2950  out_type = pfor_input.get_attr("out_type")
2951  shapes = [
2952      array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape(
2953          x, out_type=out_type) for x, stacked, _ in pfor_input.inputs
2954  ]
2955  return [wrap(x, False) for x in shapes]
2956
2957
2958@RegisterPFor("Size")
2959def _convert_size(pfor_input):
2960  out_type = pfor_input.get_attr("out_type")
2961  n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type)
2962  return wrap(
2963      array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n,
2964      False)
2965
2966
2967@RegisterPFor("Rank")
2968def _convert_rank(pfor_input):
2969  return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False)
2970
2971
2972@RegisterPFor("AddN")
2973def _convert_addn(pfor_input):
2974  # AddN does not support broadcasting.
2975  pfor_input.stack_inputs(tile_variants=False)
2976  return _wrap_and_tile_variants(
2977      math_ops.add_n([x.t for x in pfor_input.inputs]),
2978      pfor_input.pfor.loop_len_vector)
2979
2980
2981@RegisterPFor("Cross")
2982def _convert_cross(pfor_input):
2983  pfor_input.stack_inputs()
2984  a = pfor_input.stacked_input(0)
2985  b = pfor_input.stacked_input(1)
2986  return wrap(math_ops.cross(a, b), True)
2987
2988
2989@RegisterPFor("BiasAddGrad")
2990def _convert_biasaddgrad(pfor_input):
2991  grad = pfor_input.stacked_input(0)
2992  fmt = pfor_input.get_attr("data_format")
2993  if fmt == b"NCHW":
2994    output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
2995  else:
2996    grad_shape = array_ops.shape(grad)
2997    last_dim_shape = grad_shape[-1]
2998    first_dim_shape = grad_shape[0]
2999    output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
3000    output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
3001  return wrap(output, True)
3002
3003
3004# Some required ops are not exposed under the tf namespace. Hence relying on
3005# _create_op to create them.
3006@RegisterPForWithArgs("EluGrad")
3007@RegisterPForWithArgs("LeakyReluGrad")
3008@RegisterPForWithArgs("ReciprocalGrad")
3009@RegisterPForWithArgs("Relu6Grad")
3010@RegisterPForWithArgs("ReluGrad")
3011@RegisterPForWithArgs("RsqrtGrad")
3012@RegisterPForWithArgs("SeluGrad")
3013@RegisterPForWithArgs("SigmoidGrad")
3014@RegisterPForWithArgs("SoftplusGrad")
3015@RegisterPForWithArgs("SoftsignGrad")
3016@RegisterPForWithArgs("SqrtGrad")
3017@RegisterPForWithArgs("TanhGrad")
3018def _convert_grads(pfor_input, op_type, *args, **kw_args):
3019  del args
3020  del kw_args
3021  # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
3022  # have to use tiling here.
3023  pfor_input.stack_inputs()
3024  outputs = _create_op(
3025      op_type, [x.t for x in pfor_input.inputs],
3026      [x.dtype for x in pfor_input.outputs],
3027      attrs=pfor_input.op.node_def.attr).outputs
3028  return [wrap(x, True) for x in outputs]
3029
3030
3031@RegisterPFor("Select")
3032def _convert_select(pfor_input):
3033  pfor_input.stack_inputs()
3034  cond = pfor_input.stacked_input(0)
3035  t = pfor_input.stacked_input(1)
3036  e = pfor_input.stacked_input(2)
3037  cond_rank = array_ops.rank(cond)
3038  cond, t, e = control_flow_ops.cond(
3039      cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]),
3040      lambda: [cond, t, e])
3041  outputs = _create_op(
3042      pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs],
3043      attrs=pfor_input.op.node_def.attr).outputs
3044  n = pfor_input.pfor.loop_len_vector
3045  out = control_flow_ops.cond(cond_rank > 1,
3046                              lambda: _unflatten_first_dim(outputs[0], n),
3047                              lambda: outputs[0])
3048  return [wrap(out, True) for x in outputs]
3049
3050
3051@RegisterPFor("SelectV2")
3052def _convert_selectv2(pfor_input):
3053  pfor_input.expanddim_inputs_for_broadcast()
3054  cond = pfor_input.input(0)[0]
3055  t = pfor_input.input(1)[0]
3056  e = pfor_input.input(2)[0]
3057  out = array_ops.where_v2(cond, t, e)
3058  return wrap(out, True)
3059
3060
3061# random_ops
3062
3063
3064def _transpose_dim_to_front(x, dim):
3065  rank = array_ops.rank(x)
3066  return array_ops.transpose(
3067      x,
3068      perm=array_ops.concat(
3069          [[dim], math_ops.range(0, dim),
3070           math_ops.range(dim + 1, rank)],
3071          axis=0))
3072
3073
3074@RegisterPForWithArgs("RandomUniform")
3075@RegisterPForWithArgs("RandomUniformInt")
3076@RegisterPForWithArgs("RandomStandardNormal")
3077@RegisterPForWithArgs("TruncatedNormal")
3078def _convert_random(pfor_input, op_type, *args, **kw_args):
3079  del args
3080  del kw_args
3081  inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)]
3082  # inputs[0] is "shape"
3083  inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]],
3084                               axis=0)
3085  logging.warning(
3086      "Note that %s inside pfor op may not give same output as "
3087      "inside a sequential loop.", op_type)
3088  outputs = _create_op(
3089      op_type,
3090      inputs, [x.dtype for x in pfor_input.outputs],
3091      attrs=pfor_input.op.node_def.attr).outputs
3092  return [wrap(x, True) for x in outputs]
3093
3094
3095@RegisterPFor("RandomGamma")
3096@RegisterPFor("RandomPoissonV2")
3097def _convert_random_with_param(pfor_input):
3098  shape = pfor_input.unstacked_input(0)
3099  # param is lam (Poisson rate) or alpha (Gamma shape).
3100  param, param_stacked, _ = pfor_input.input(1)
3101  logging.warning(
3102      "Note that %s inside pfor op may not give same output as "
3103      "inside a sequential loop.", pfor_input.op_type)
3104
3105  if param_stacked:
3106    samples = _create_op(
3107        pfor_input.op_type,
3108        inputs=[shape, param],
3109        op_dtypes=[x.dtype for x in pfor_input.outputs],
3110        attrs=pfor_input.op.node_def.attr).outputs[0]
3111    loop_dim = array_ops.shape(shape)[0]
3112    stacked_samples = _transpose_dim_to_front(samples, loop_dim)
3113  else:
3114    shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
3115    stacked_samples = _create_op(
3116        pfor_input.op_type,
3117        inputs=[shape, param],
3118        op_dtypes=[x.dtype for x in pfor_input.outputs],
3119        attrs=pfor_input.op.node_def.attr).outputs[0]
3120
3121  return wrap(stacked_samples, True)
3122
3123
3124@RegisterPFor("Multinomial")
3125def _convert_multinomial(pfor_input):
3126  logits, logits_stacked, _ = pfor_input.input(0)
3127  num_samples = pfor_input.unstacked_input(1)
3128  seed = pfor_input.get_attr("seed")
3129  seed2 = pfor_input.get_attr("seed2")
3130  output_dtype = pfor_input.get_attr("output_dtype")
3131  logging.warning(
3132      "Note that Multinomial inside pfor op may not give same output as "
3133      "inside a sequential loop.")
3134
3135  n = pfor_input.pfor.loop_len_vector[0]
3136  if logits_stacked:
3137    flattened_logits = _flatten_first_two_dims(logits)
3138    samples = gen_random_ops.multinomial(
3139        flattened_logits,
3140        num_samples,
3141        seed=seed,
3142        seed2=seed2,
3143        output_dtype=output_dtype)
3144    stacked_samples = _unflatten_first_dim(samples, [n])
3145  else:
3146    samples = gen_random_ops.multinomial(
3147        logits,
3148        num_samples * n,
3149        seed=seed,
3150        seed2=seed2,
3151        output_dtype=output_dtype)
3152    stacked_samples = array_ops.transpose(
3153        array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2])
3154
3155  return wrap(stacked_samples, True)
3156
3157
3158@RegisterPFor("StatelessMultinomial")
3159@RegisterPFor("StatelessParameterizedTruncatedNormal")
3160@RegisterPFor("StatelessRandomBinomial")
3161@RegisterPFor("StatelessRandomGammaV2")
3162@RegisterPFor("StatelessRandomNormal")
3163@RegisterPFor("StatelessRandomPoisson")
3164@RegisterPFor("StatelessRandomUniform")
3165@RegisterPFor("StatelessRandomUniformInt")
3166@RegisterPFor("StatelessRandomUniformFullInt")
3167@RegisterPFor("StatelessTruncatedNormal")
3168def _convert_stateless_multinomial(pfor_input):
3169  # Unlike stateful random ops, for stateless ones we want better
3170  # reproducibility based on seed. Hence we don't want to use a similar strategy
3171  # as used for stateful ones where we generate a possibly different set of
3172  # random numbers under vectorization.
3173  # Unfortunately, the kernels currently are not necessarily setup to do this
3174  # efficiently and hence we fallback to a sequential loop for vectorization.
3175  return _fallback_converter(pfor_input, warn=False)
3176
3177
3178# linalg_ops
3179
3180
3181@RegisterPForWithArgs("XlaEinsum")
3182@RegisterPForWithArgs("Einsum")
3183def _convert_einsum(pfor_input, op_type):
3184  first_input, first_input_stacked, _ = pfor_input.input(0)
3185  second_input, second_input_stacked, _ = pfor_input.input(1)
3186
3187  # Parse the einsum equation.
3188  equation = pfor_input.get_attr("equation").decode("utf-8")
3189  input_expr, output_expr = equation.split("->")
3190  input_a_expr, input_b_expr = input_expr.split(",")
3191
3192  # pick a placeholder symbol to use for the new axis
3193  chosen_symbol = None
3194  for s in string.ascii_letters:
3195    if s in equation:
3196      continue
3197    else:
3198      chosen_symbol = s
3199      break
3200
3201  if chosen_symbol is None:
3202    raise ValueError("Could not figure out what symbol to use for new axis.")
3203
3204  assert first_input_stacked or second_input_stacked
3205  if first_input_stacked:
3206    input_a_expr = "{}{}".format(chosen_symbol, input_a_expr)
3207  if second_input_stacked:
3208    input_b_expr = "{}{}".format(chosen_symbol, input_b_expr)
3209  output_expr = "{}{}".format(chosen_symbol, output_expr)
3210
3211  new_equation = "{},{}->{}".format(input_a_expr, input_b_expr, output_expr)
3212  if op_type == "XlaEinsum":
3213    result = xla.einsum(equation=new_equation, a=first_input, b=second_input)
3214  else:
3215    assert op_type == "Einsum"
3216    result = special_math_ops.einsum(new_equation, first_input, second_input)
3217
3218  return wrap(result, True)
3219
3220
3221@RegisterPFor("Cholesky")
3222def _convert_cholesky(pfor_input):
3223  t = pfor_input.stacked_input(0)
3224  return wrap(linalg_ops.cholesky(t), True)
3225
3226
3227@RegisterPFor("LogMatrixDeterminant")
3228def _convert_log_matrix_determinant(pfor_input):
3229  t = pfor_input.stacked_input(0)
3230  return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)]
3231
3232
3233@RegisterPFor("MatrixInverse")
3234def _convert_matrix_inverse(pfor_input):
3235  t = pfor_input.stacked_input(0)
3236  adjoint = pfor_input.get_attr("adjoint")
3237  return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True)
3238
3239
3240@RegisterPFor("MatrixSolve")
3241def _convert_matrix_solve(pfor_input):
3242  pfor_input.stack_inputs()
3243  matrix = pfor_input.stacked_input(0)
3244  rhs = pfor_input.stacked_input(1)
3245  adjoint = pfor_input.get_attr("adjoint")
3246  output = gen_linalg_ops.matrix_solve(
3247      matrix, rhs, adjoint=adjoint)
3248  return wrap(output, True)
3249
3250
3251@RegisterPFor("MatrixTriangularSolve")
3252def _convert_matrix_triangular_solve(pfor_input):
3253  pfor_input.expanddim_inputs_for_broadcast()
3254  matrix = pfor_input.input(0)[0]
3255  rhs = pfor_input.input(1)[0]
3256  lower = pfor_input.get_attr("lower")
3257  adjoint = pfor_input.get_attr("adjoint")
3258  output = linalg_ops.matrix_triangular_solve(
3259      matrix, rhs, lower=lower, adjoint=adjoint)
3260  return wrap(output, True)
3261
3262
3263@RegisterPFor("SelfAdjointEigV2")
3264def _convert_self_adjoint_eig(pfor_input):
3265  t = pfor_input.stacked_input(0)
3266  compute_v = pfor_input.get_attr("compute_v")
3267  e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v)
3268  # If compute_v is False, v will have shape [0].
3269  return wrap(e, True), wrap(v, compute_v)
3270
3271
3272# logging_ops
3273
3274
3275@RegisterPFor("Assert")
3276def _convert_assert(pfor_input):
3277  cond, cond_stacked, _ = pfor_input.input(0)
3278  if cond_stacked:
3279    cond = math_ops.reduce_all(cond)
3280
3281  data_list = [x.t for x in pfor_input.inputs][1:]
3282  return _create_op(
3283      "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr)
3284
3285
3286@RegisterPFor("Print")
3287def _convert_print(pfor_input):
3288  # Note that we don't stack all the inputs. Hence unstacked values are printed
3289  # once here vs multiple times in a while_loop.
3290  pfor_input.stack_inputs([0])
3291  outputs = _create_op(
3292      "Print", [x.t for x in pfor_input.inputs],
3293      [x.dtype for x in pfor_input.outputs],
3294      attrs=pfor_input.op.node_def.attr).outputs
3295  return [wrap(x, True) for x in outputs]
3296
3297
3298# data_flow_ops
3299
3300# TensorArray conversion is tricky since we don't support arrays of
3301# TensorArrays. For converting them, we consider two distinct cases:
3302#
3303# 1. The array is constructed outside the pfor call, and read/written inside the
3304# loop.
3305# This is an easier case since we don't need to make an array of TensorArrays.
3306# A correctness requirement is that these parallel iterations shouldn't attempt
3307# to write to the same location. Hence at conversion time we disallow indices to
3308# be loop-invariant as that would guarantee a collision. Even if the indices are
3309# not loop-invariant, they could conflict and that shall trigger runtime errors.
3310#
3311# 2. The array is constructed and used entirely inside each pfor iteration.
3312# For simplicity, here we require that the indices used for write/scatter are
3313# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
3314# different pfor iterations. We consider two sub_cases:
3315#
3316# 2a Elements written to the array are "stacked"
3317# To simulate multiple TensorArrays, we may increase the dimension of each
3318# element of the array. i.e. the i_th row of the j_th entry of the converted
3319# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
3320# pfor iteration.
3321#
3322# 2b Elements written to the array are "unstacked"
3323# In this case we don't increase the dimensions to avoid redundant tiling. Each
3324# iteration is trying to write the same value. So we convert that to a single
3325# write.
3326#
3327# Here are some tricks used to implement the above:
3328# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of
3329# trying to trace whether future writes are stacked or unstacked in order to set
3330# this attr, we set it to correspond to unknown shape.
3331# - We use the "flow" output of the different ops to track whether the array
3332# elements are stacked or unstacked. If a stacked write/scatter is done, we make
3333# the flow stacked as well.
3334# - We use some heuristic traversal of the graph to track whether the
3335# TensorArray handle was created inside or outside the pfor loop.
3336
3337
3338@RegisterPFor("TensorArrayV3")
3339def _convert_tensor_array_v3(pfor_input):
3340  size = pfor_input.unstacked_input(0)
3341  dtype = pfor_input.get_attr("dtype")
3342  dynamic_size = pfor_input.get_attr("dynamic_size")
3343  clear_after_read = pfor_input.get_attr("clear_after_read")
3344  identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
3345  tensor_array_name = pfor_input.get_attr("tensor_array_name")
3346  handle, flow = data_flow_ops.tensor_array_v3(
3347      size,
3348      dtype=dtype,
3349      # We don't set element shape since we don't know if writes are stacked or
3350      # not yet.
3351      element_shape=None,
3352      dynamic_size=dynamic_size,
3353      clear_after_read=clear_after_read,
3354      identical_element_shapes=identical_element_shapes,
3355      tensor_array_name=tensor_array_name)
3356  # Note we keep flow unstacked for now since we don't know if writes will be
3357  # stacked or not.
3358  return wrap(handle, False), wrap(flow, False)
3359
3360
3361@RegisterPFor("TensorArraySizeV3")
3362def _convert_tensor_array_size_v3(pfor_input):
3363  handle = pfor_input.unstacked_input(0)
3364  flow, flow_stacked, _ = pfor_input.input(1)
3365  if flow_stacked:
3366    flow = _unstack_flow(flow)
3367  size = data_flow_ops.tensor_array_size_v3(handle, flow)
3368  return wrap(size, False)
3369
3370
3371def _handle_inside_pfor(pfor_input, handle):
3372  """Returns True if handle was created inside the pfor loop."""
3373  # We use some heuristic to find the original TensorArray creation op.
3374  # The logic should handle the common cases (except cond based subgraphs).
3375  # In theory the user could perform different operations on the handle (like
3376  # Reshape, stack multiple handles, etc) which could break this logic.
3377  # TODO(agarwal): handle Switch/Merge.
3378  while handle.op.type in ("Enter", "Identity"):
3379    handle = handle.op.inputs[0]
3380  if handle.op.type not in [
3381      "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"
3382  ]:
3383    raise ValueError("Unable to find source for handle %s" % handle)
3384  else:
3385    return pfor_input.pfor.op_is_inside_loop(handle.op)
3386
3387
3388def _unstack_flow(value):
3389  # TODO(agarwal): consider looking if this is a Tile op then get its input.
3390  # This may avoid running the Tile operations.
3391  return array_ops.gather(value, 0)
3392
3393
3394@RegisterPFor("TensorArrayReadV3")
3395def _convert_tensor_array_read_v3(pfor_input):
3396  handle = pfor_input.unstacked_input(0)
3397  index, index_stacked, _ = pfor_input.input(1)
3398  dtype = pfor_input.get_attr("dtype")
3399  flow, flow_stacked, _ = pfor_input.input(2)
3400  if flow_stacked:
3401    flow = _unstack_flow(flow)
3402
3403  is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3404  if is_inside_pfor:
3405    # Note that if we are inside a control flow construct inside the pfor, and
3406    # only some of the iterations are doing the read (i.e.
3407    # `all_indices_partitioned` is True), then the read operation should only
3408    # return values for the currently active pfor iterations (`all_indices`
3409    # below). Hence, whenever the returned value is stacked (i.e. `flow` is
3410    # stacked), we may need to do an extra gather after reading the values. Also
3411    # note that if `is_inside` is false, then values in the tensor array are
3412    # unstacked. So the check is only needed in this branch.
3413    all_indices = pfor_input.pfor.all_indices
3414    all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
3415    # Note: flow_stacked indicates if values in the TensorArray are stacked or
3416    # not.
3417    if index_stacked:
3418      if flow_stacked:
3419        raise ValueError(
3420            "It looks like TensorArrayReadV3 was called on a TensorArray whose"
3421            " values are not loop-invariant, and the read indices were also"
3422            " not loop invariant. This is currently unsupported.")
3423      value = data_flow_ops.tensor_array_gather_v3(
3424          handle, index, flow, dtype=dtype)
3425      return wrap(value, True)
3426    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3427    if flow_stacked and all_indices_partitioned:
3428      value = array_ops.gather(value, all_indices)
3429    return wrap(value, flow_stacked)
3430  # Values in the TensorArray should be unstacked (since different iterations
3431  # couldn't write to the same location). So whether output is stacked or not
3432  # depends on index_stacked.
3433  if index_stacked:
3434    value = data_flow_ops.tensor_array_gather_v3(
3435        handle, index, flow, dtype=dtype)
3436  else:
3437    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3438  return wrap(value, index_stacked)
3439
3440
3441@RegisterPFor("TensorArrayWriteV3")
3442def _convert_tensor_array_write_v3(pfor_input):
3443  handle = pfor_input.unstacked_input(0)
3444  index, index_stacked, _ = pfor_input.input(1)
3445  value, value_stacked, _ = pfor_input.input(2)
3446  flow, flow_stacked, _ = pfor_input.input(3)
3447  if value_stacked and pfor_input.pfor.all_indices_partitioned:
3448    # Looks like we are in a control flow in a pfor where not all iterations are
3449    # active now. We don't allow that since that could lead to different indices
3450    # having different shapes which will be hard to merge later.
3451    raise ValueError("Writing non loop invariant values to TensorArray from "
3452                     "inside a while_loop/cond not supported.")
3453  if flow_stacked:
3454    flow = _unstack_flow(flow)
3455  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3456  if is_inside:
3457    if index_stacked:
3458      raise ValueError("Need indices for %s to be loop invariant" % handle)
3459    if not flow_stacked and not value_stacked:
3460      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3461      return wrap(flow_out, False)
3462    else:
3463      if not value_stacked:
3464        value = _stack(value, pfor_input.pfor.loop_len_vector).t
3465      # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
3466      # this may or may not be a safe situation. flow is unstacked both for a
3467      # freshly created TensorArray, as well as after unstacked values are
3468      # written to it. If it is the latter, then we cannot write a stacked value
3469      # now since that may cause runtime errors due to different shapes in the
3470      # array. At the moment we are not able to handle this gracefully and
3471      # distinguish between the two cases. That would require some heuristic
3472      # traversal of the graph to figure out whether all the writes are
3473      # unstacked or not.
3474      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3475      return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3476  else:
3477    if not index_stacked:
3478      raise ValueError("Need indices for %s to be not loop invariant" % handle)
3479    # Note that even when index_stacked is true, actual values in index may
3480    # still not be unique. However that will cause runtime error when executing
3481    # the scatter operation below.
3482    if not value_stacked:
3483      value = _stack(value, pfor_input.pfor.loop_len_vector).t
3484    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
3485    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3486
3487
3488def _transpose_first_two_dims(value):
3489  # TODO(agarwal): optimize if one of the dims == 1.
3490  value_shape = array_ops.shape(value)
3491  v0 = value_shape[0]
3492  v1 = value_shape[1]
3493  value = array_ops.reshape(value, [v0, v1, -1])
3494  value = array_ops.transpose(value, [1, 0, 2])
3495  new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
3496  return array_ops.reshape(value, new_shape)
3497
3498
3499@RegisterPFor("TensorArrayGatherV3")
3500def _convert_tensor_array_gather_v3(pfor_input):
3501  handle = pfor_input.unstacked_input(0)
3502  indices, indices_stacked, _ = pfor_input.input(1)
3503  indices = array_ops.reshape(indices, [-1])
3504  flow, flow_stacked, _ = pfor_input.input(2)
3505  if flow_stacked:
3506    flow = _unstack_flow(flow)
3507  dtype = pfor_input.get_attr("dtype")
3508  # TODO(agarwal): support element_shape attr?
3509
3510  n = pfor_input.pfor.loop_len_vector
3511  value = data_flow_ops.tensor_array_gather_v3(
3512      handle, indices, flow, dtype=dtype)
3513  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3514  if is_inside:
3515    # flow_stacked indicates if values in the TensorArray are stacked or not.
3516    if indices_stacked:
3517      if flow_stacked:
3518        raise ValueError(
3519            "It looks like TensorArrayGatherV3 was called on a TensorArray "
3520            "whose values are not loop-invariant, and the indices were also "
3521            "not loop invariant. This is currently unsupported.")
3522      else:
3523        value = _unflatten_first_dim(value, n)
3524        return wrap(value, True)
3525    else:
3526      if flow_stacked:
3527        # Since elements in this array are stacked and `value` was produced by
3528        # gather, its first two dims are "gathered elements" and "stack
3529        # dimension". Our semantics require these two to be flipped.
3530        value = _transpose_first_two_dims(value)
3531      return wrap(value, flow_stacked)
3532  else:
3533    # Values in the TensorArray should be unstacked (since different iterations
3534    # couldn't write to the same location). So whether output is stacked or not
3535    # depends on indices_stacked.
3536    if indices_stacked:
3537      value = _unflatten_first_dim(value, n)
3538    return wrap(value, indices_stacked)
3539
3540
3541@RegisterPFor("TensorArrayScatterV3")
3542def _convert_tensor_array_scatter_v3(pfor_input):
3543  handle = pfor_input.unstacked_input(0)
3544  indices, indices_stacked, _ = pfor_input.input(1)
3545  indices = array_ops.reshape(indices, [-1])
3546  value, value_stacked, _ = pfor_input.input(2)
3547  flow, flow_stacked, _ = pfor_input.input(3)
3548
3549  if flow_stacked:
3550    flow = _unstack_flow(flow)
3551
3552  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3553  if is_inside:
3554    if indices_stacked:
3555      raise ValueError("Need indices for %s to be loop invariant" % handle)
3556    # Note that flow_stacked indicates if existing values in the array are
3557    # stacked or not.
3558    if not flow_stacked and not value_stacked:
3559      flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3560                                                       flow)
3561      return wrap(flow_out, False)
3562    if not value_stacked:
3563      # TODO(agarwal): tile in the second dimension directly instead of
3564      # transposing below.
3565      value = _stack(value, pfor_input.pfor.loop_len_vector).t
3566
3567    value = _transpose_first_two_dims(value)
3568    # TODO(agarwal): Note that if a previous write was unstacked, flow will be
3569    # unstacked, and a stacked value may be written here which may cause
3570    # runtime error due to different elements having different shape. We do
3571    # not try to prevent that.
3572    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3573                                                     flow)
3574    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3575  if not indices_stacked:
3576    raise ValueError("Need indices for %s to be not loop invariant" % handle)
3577  if not value_stacked:
3578    value = _stack(value, pfor_input.pfor.loop_len_vector).t
3579  value = _flatten_first_two_dims(value)
3580  flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow)
3581  return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3582
3583
3584@RegisterPFor("TensorArrayGradV3")
3585def _convert_tensor_array_grad_v3(pfor_input):
3586  handle = pfor_input.unstacked_input(0)
3587  flow, flow_stacked, _ = pfor_input.input(1)
3588  if flow_stacked:
3589    flow = _unstack_flow(flow)
3590  source = pfor_input.get_attr("source")
3591  # TODO(agarwal): For now, we assume that gradients are stacked if the
3592  # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
3593  # will give runtime error due to incorrect shape being written to the
3594  # accumulator. It is difficult to know in advance if gradients written will be
3595  # stacked or not. Note that flow being stacked is not indicative of the
3596  # gradient being stacked or not. Revisit this later.
3597  shape_to_prepend = pfor_input.pfor.loop_len_vector
3598  grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
3599      handle=handle,
3600      flow_in=flow,
3601      shape_to_prepend=shape_to_prepend,
3602      source=source)
3603  flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
3604  return [wrap(grad_handle, False), wrap(flow_out, True)]
3605
3606
3607def _stack_tensor_list_shape(shape, first_dim):
3608  shape_value = tensor_util.constant_value(shape)
3609  # Note that negative values in the shape are used to signify unknown shapes
3610  # and are handled in a special way.
3611  if shape_value is not None:
3612    shape_value = np.asarray(shape_value)
3613    if -1 in shape_value:
3614      return constant_op.constant(-1)
3615    elif not shape_value.size:
3616      return first_dim
3617  else:
3618    shape = array_ops.reshape(shape, [-1])
3619    return control_flow_ops.cond(
3620        math_ops.reduce_any(shape < 0),
3621        lambda: constant_op.constant(-1),
3622        lambda: array_ops.concat([first_dim, shape], axis=0))
3623
3624
3625def _tile_variant_with_length(t, length):
3626  """stacks `t` `length` times."""
3627  if _is_variant_with_internal_stacking(t):
3628    # The content of TensorLists is vectorized, not the variant itself.
3629    return t
3630  original_tensor = t
3631  t.set_shape([])
3632  t = array_ops.reshape(t, [-1])
3633  with ops.device("CPU:0"):
3634    result = array_ops.tile(t, length)
3635    # TODO(b/169968286): Should regular shape functions do handle data
3636    # propagation here?
3637    custom_gradient.copy_handle_data(original_tensor, result)
3638    return result
3639
3640
3641def _tile_variant(t, pfor_input):
3642  """stacks `t` according to its loop context."""
3643  return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector)
3644
3645
3646def _untile_variant(t):
3647  if _is_variant_with_internal_stacking(t):
3648    # The content of TensorLists is vectorized, not the variant itself.
3649    if not t.shape.is_compatible_with([]):
3650      raise AssertionError(
3651          ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
3652           "non-scalar shape: {!r}").format(t))
3653    return t
3654  return array_ops.gather(t, 0)
3655
3656
3657@RegisterPFor("OptionalFromValue")
3658def _convert_optional_from_value(pfor_input):
3659  pfor_input.stack_inputs()
3660  return wrap(
3661      gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]),
3662      True)
3663
3664
3665@RegisterPFor("OptionalGetValue")
3666def _convert_optional_get_value(pfor_input):
3667  handle = pfor_input.stacked_input(0)
3668  output_types = pfor_input.get_attr("output_types")
3669  original_output_shapes = pfor_input.get_attr("output_shapes")
3670  output_shapes = []
3671  for shape in original_output_shapes:
3672    shape = tensor_shape.TensorShape(shape)
3673    loop_len_shape = tensor_shape.TensorShape(
3674        [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)])
3675    shape = loop_len_shape.concatenate(shape)
3676    output_shapes.append(shape.as_proto())
3677  results = gen_dataset_ops.optional_get_value(handle, output_types,
3678                                               output_shapes)
3679  return [wrap(t, True) for t in results]
3680
3681
3682@RegisterPFor("TensorListReserve")
3683def _convert_tensor_list_reserve(pfor_input):
3684  element_shape = pfor_input.unstacked_input(0)
3685  num_elements = pfor_input.unstacked_input(1)
3686  element_dtype = pfor_input.get_attr("element_dtype")
3687
3688  # Prepend a dimension to element_shape.
3689  element_shape = _stack_tensor_list_shape(element_shape,
3690                                           pfor_input.pfor.loop_len_vector)
3691  handle = list_ops.tensor_list_reserve(
3692      element_shape, num_elements, element_dtype=element_dtype)
3693
3694  return wrap(_tile_variant(handle, pfor_input), True)
3695
3696
3697@RegisterPFor("TensorListElementShape")
3698def _convert_tensor_list_element_shape(pfor_input):
3699  handle = _untile_variant(pfor_input.stacked_input(0))
3700  shape_type = pfor_input.get_attr("shape_type")
3701  shape = list_ops.tensor_list_element_shape(handle, shape_type)
3702  shape = array_ops.reshape(shape, [-1])
3703  shape = shape[1:]
3704  return wrap(shape, False)
3705
3706
3707@RegisterPFor("TensorListLength")
3708def _convert_tensor_list_length(pfor_input):
3709  handle = _untile_variant(pfor_input.stacked_input(0))
3710  return wrap(list_ops.tensor_list_length(handle), False)
3711
3712
3713def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None):
3714  if element_shape is None:
3715    element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32)
3716  length = list_ops.tensor_list_length(handle)
3717  new_handle = list_ops.tensor_list_reserve(
3718      _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype)
3719
3720  def _body_fn(i, h):
3721    elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape)
3722    elem = _stack(elem, loop_len_vector).t
3723    return i + 1, list_ops.tensor_list_set_item(h, i, elem)
3724
3725  return control_flow_ops.while_loop(lambda i, _: i < length, _body_fn,
3726                                     [0, new_handle])[1]
3727
3728
3729@RegisterPFor("TensorListGetItem")
3730def _convert_tensor_list_get_item(pfor_input):
3731  handle, handle_stacked, _ = pfor_input.input(0)
3732  index, index_stacked, _ = pfor_input.input(1)
3733  element_shape = pfor_input.unstacked_input(2)
3734  element_dtype = pfor_input.get_attr("element_dtype")
3735
3736  if handle_stacked:
3737    handle = _untile_variant(handle)
3738    element_shape = _stack_tensor_list_shape(element_shape,
3739                                             pfor_input.pfor.loop_len_vector)
3740    if index_stacked:
3741      # We use a sequential loop since that may be more efficient than first
3742      # gathering and concatenating all the element corresponding to `index`,
3743      # and then doing a gather on it.
3744      def _map_fn(i):
3745        item_i = list_ops.tensor_list_get_item(
3746            handle,
3747            index[i],
3748            element_dtype=element_dtype)
3749        return array_ops.gather(item_i, i)
3750
3751      output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices)
3752      return wrap(output, True)
3753    else:
3754      output = list_ops.tensor_list_get_item(
3755          handle,
3756          index,
3757          element_shape=element_shape,
3758          element_dtype=element_dtype)
3759      return wrap(output, True)
3760  else:
3761    assert index_stacked
3762    return wrap(
3763        list_ops.tensor_list_gather(
3764            handle,
3765            index,
3766            element_shape=element_shape,
3767            element_dtype=element_dtype), True)
3768
3769
3770@RegisterPFor("TensorListSetItem")
3771def _convert_tensor_array_set_item(pfor_input):
3772  handle, handle_stacked, _ = pfor_input.input(0)
3773  index, index_stacked, _ = pfor_input.input(1)
3774  item, item_stacked, _ = pfor_input.input(2)
3775
3776  if not handle_stacked:
3777    # Special case where we can statically guarantee that the indices are
3778    # disjoint.
3779    if index is pfor_input.pfor.all_indices:
3780      if not item_stacked:
3781        item = _stack(item, pfor_input.pfor.loop_len_vector).t
3782      return wrap(
3783          list_ops.tensor_list_scatter(item, index, input_handle=handle), False)
3784    else:
3785      handle = _stack_tensor_list(handle, item.dtype,
3786                                  pfor_input.pfor.loop_len_vector)
3787  else:
3788    handle = _untile_variant(handle)
3789
3790  if index_stacked:
3791    # TODO(agarwal): handle this.
3792    raise ValueError("Vectorizing writes to a TensorList with loop "
3793                     "variant indices is currently unsupported.")
3794
3795  else:
3796    if not item_stacked:
3797      item = _stack(item, pfor_input.pfor.loop_len_vector).t
3798    handle = list_ops.tensor_list_set_item(handle, index, item)
3799    return wrap(_tile_variant(handle, pfor_input), True)
3800
3801
3802@RegisterPFor("TensorListPushBack")
3803def _convert_tensor_list_push_back(pfor_input):
3804  handle, handle_stacked, _ = pfor_input.input(0)
3805  tensor, tensor_stacked, _ = pfor_input.input(1)
3806  if handle_stacked:
3807    handle = _untile_variant(handle)
3808  else:
3809    handle = _stack_tensor_list(handle, tensor.dtype,
3810                                pfor_input.pfor.loop_len_vector)
3811  if not tensor_stacked:
3812    tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t
3813  handle = list_ops.tensor_list_push_back(handle, tensor)
3814  return wrap(_tile_variant(handle, pfor_input), True)
3815
3816
3817@RegisterPFor("TensorListPopBack")
3818def _convert_tensor_array_push_back(pfor_input):
3819  handle = pfor_input.stacked_input(0)
3820  element_shape = pfor_input.unstacked_input(1)
3821  handle = _untile_variant(handle)
3822
3823  if element_shape.shape.ndims == 0:
3824    # Default / unspecified
3825    vectorized_shape = -1
3826  else:
3827    # PopBack has an element shape set when it's the gradient of PushBack, only
3828    # used when the list is uninitialized.
3829    vectorized_shape = array_ops.concat(
3830        [pfor_input.pfor.loop_len_vector, element_shape], axis=0)
3831
3832  output_handle, tensor = gen_list_ops.tensor_list_pop_back(
3833      input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"),
3834      element_shape=vectorized_shape)
3835  return wrap(output_handle, True), wrap(tensor, True)
3836
3837
3838@RegisterPFor("TensorListConcatV2")
3839def _convert_tensor_list_concat_v2(pfor_input):
3840  input_handle = pfor_input.stacked_input(0)
3841  element_shape = pfor_input.unstacked_input(1)
3842  leading_dims = pfor_input.unstacked_input(2)
3843  element_dtype = pfor_input.get_attr("element_dtype")
3844
3845  handle = _untile_variant(input_handle)
3846  length = list_ops.tensor_list_length(handle)
3847  # Note that element_shape attribute can have incomplete shapes. This doesn't
3848  # seem to work well when creating another list and then doing a concat on it.
3849  # Hence we try to find the dynamic shape here.
3850  element_shape = control_flow_ops.cond(
3851      length > 0, lambda: array_ops.shape(
3852          list_ops.tensor_list_get_item(handle, 0, element_dtype, None)),
3853      lambda: constant_op.constant([0, 0], dtype=dtypes.int32))
3854  # The code below creates a copy of the list with each elements' first two
3855  # dimensions transposed.
3856  new_element_shape = array_ops.concat(
3857      [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0)
3858
3859  # Create a new TensorList with elements transposed.
3860  def _transpose_elem(i, h):
3861    elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None)
3862    elem = _transpose_first_two_dims(elem)
3863    return i + 1, list_ops.tensor_list_set_item(h, i, elem)
3864
3865  new_handle = list_ops.tensor_list_reserve(new_element_shape, length,
3866                                            element_dtype)
3867  new_handle = control_flow_ops.while_loop(lambda i, _: i < length,
3868                                           _transpose_elem, [0, new_handle])[1]
3869  output, lengths = gen_list_ops.tensor_list_concat_v2(
3870      input_handle=new_handle,
3871      element_dtype=element_dtype,
3872      element_shape=new_element_shape,
3873      leading_dims=leading_dims)
3874  output = _transpose_first_two_dims(output)
3875  return wrap(output, True), wrap(lengths, False)
3876
3877
3878@RegisterPFor("TensorListStack")
3879def _convert_tensor_list_stack(pfor_input):
3880  handle = pfor_input.stacked_input(0)
3881  input_shape = pfor_input.unstacked_input(1)
3882  element_dtype = pfor_input.get_attr("element_dtype")
3883  num_elements = pfor_input.get_attr("num_elements")
3884
3885  handle = _untile_variant(handle)
3886  input_shape = _stack_tensor_list_shape(input_shape,
3887                                         pfor_input.pfor.loop_len_vector)
3888  output = list_ops.tensor_list_stack(
3889      handle,
3890      element_dtype,
3891      element_shape=input_shape,
3892      num_elements=num_elements)
3893  output = _transpose_first_two_dims(output)
3894  return wrap(output, True)
3895
3896
3897@RegisterPFor("TensorListGather")
3898def _convert_tensor_list_gather(pfor_input):
3899  handle, handle_stacked, _ = pfor_input.input(0)
3900  index, index_stacked, _ = pfor_input.input(1)
3901  element_shape = pfor_input.unstacked_input(2)
3902  element_dtype = pfor_input.get_attr("element_dtype")
3903
3904  if handle_stacked:
3905    handle = _untile_variant(handle)
3906    element_shape = _stack_tensor_list_shape(element_shape,
3907                                             pfor_input.pfor.loop_len_vector)
3908    if index_stacked:
3909      # We use a sequential loop since that may be more efficient than first
3910      # gathering and concatenating all the element corresponding to `index`,
3911      # and then doing a gather on it.
3912      def _map_fn(i):
3913        item_i = list_ops.tensor_list_gather(
3914            handle,
3915            index[i],
3916            element_dtype=element_dtype)
3917        axis = array_ops.rank(index) - 1
3918        return array_ops.gather(item_i, i, axis=axis)
3919
3920      output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices)
3921      return wrap(output, True)
3922    else:
3923      output = list_ops.tensor_list_gather(
3924          handle,
3925          index,
3926          element_shape=element_shape,
3927          element_dtype=element_dtype)
3928      return wrap(output, True)
3929  else:
3930    assert index_stacked
3931    index_shape = array_ops.shape(index)
3932    index = array_ops.reshape(index, [-1])
3933    values = list_ops.tensor_list_gather(
3934        handle, index, element_shape=element_shape, element_dtype=element_dtype)
3935    final_shape = array_ops.concat(
3936        [index_shape, array_ops.shape(values)[1:]], axis=0)
3937    return wrap(array_ops.reshape(values, final_shape), True)
3938
3939
3940@RegisterPFor("TensorListScatterIntoExistingList")
3941def _convert_tensor_list_scatter(pfor_input):
3942  pfor_input.stack_inputs([1])
3943  handle, handle_stacked, _ = pfor_input.input(0)
3944  item = pfor_input.stacked_input(1)
3945  # TODO(agarwal): handle stacked indices.
3946  indices = pfor_input.unstacked_input(2)
3947  if handle_stacked:
3948    handle = _untile_variant(handle)
3949  else:
3950    handle = _stack_tensor_list(handle, item.dtype,
3951                                pfor_input.pfor.loop_len_vector)
3952
3953  item = _transpose_first_two_dims(item)
3954  handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle)
3955  return wrap(_tile_variant(handle, pfor_input), True)
3956
3957
3958@RegisterPFor("TensorListFromTensor")
3959def _convert_tensor_list_from_tensor(pfor_input):
3960  tensor = pfor_input.stacked_input(0)
3961  element_shape = pfor_input.unstacked_input(1)
3962  tensor = _transpose_first_two_dims(tensor)
3963  element_shape = _stack_tensor_list_shape(element_shape,
3964                                           pfor_input.pfor.loop_len_vector)
3965  handle = list_ops.tensor_list_from_tensor(tensor, element_shape)
3966  return wrap(_tile_variant(handle, pfor_input), True)
3967
3968
3969# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar
3970# to TensorArrays, we convert them by changing the dimension of the elements
3971# inside the stack.
3972#
3973# We consider two cases:
3974#
3975# 1. StackV2 is constructed and used entirely inside the pfor loop.
3976# We keep a single Stack and perform the push/pop operations of all the
3977# iterations in lock-step. We also assume that all the iterations perform these
3978# operations. In case of dynamic control flow, if only some of the iterations
3979# try to perform a push/pop, then the conversion may not work correctly and may
3980# cause undefined behavior.
3981# TODO(agarwal): test StackV2 with dynamic control flow.
3982#
3983# 2. StackV2 is constructed outside the pfor loop.
3984# Performing stack push/pop in a parallel fashion is ill-defined. However given
3985# that reading stacks created externally is a common operation when computing
3986# jacobians, we provide some special semantics here as follows.
3987#  - disallow push operations to the stack
3988#  - pop operations are performed in lock step by all iterations, similar to the
3989#  case when the stack is created inside. A single value is popped during the
3990#  lock-step operation and broadcast to all the iterations. Values in the stack
3991#  are assumed to be loop-invariant.
3992#
3993# Some other implementation details:
3994# We use an ugly logic to find whether values in Stack data structure are
3995# loop invariant or not. When converting push/pop operations, we keep track of
3996# whether the last conversion used a stacked value or not (see _stack_cache
3997# below). As a result if an unstacked value is written first, subsequent stacked
3998# writes are disallowed when they could have been allowed in theory.
3999
4000# Map from cache key based on StackV2 handle to a bool indicating whether values
4001# are stacked or not.
4002# TODO(agarwal): move _stack_cache inside pfor?
4003_stack_cache = {}
4004
4005
4006def _stack_cache_key(pfor_input):
4007  """Create cache key corresponding to a stack handle."""
4008  op_type = pfor_input.op_type
4009  assert op_type in ["StackPushV2", "StackPopV2"], op_type
4010  orig_handle = pfor_input.op.inputs[0]
4011  while orig_handle.op.type in ["Identity", "Enter"]:
4012    orig_handle = orig_handle.op.inputs[0]
4013  assert orig_handle.op.type == "StackV2", orig_handle.op
4014  return ops.get_default_graph(), pfor_input.pfor, orig_handle
4015
4016
4017def _stack_handle_inside_pfor(handle, pfor_input):
4018  while handle.op.type in ["Identity", "Enter"]:
4019    handle = handle.op.inputs[0]
4020  assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" %
4021                                       handle.op)
4022  return pfor_input.pfor.op_is_inside_loop(handle.op)
4023
4024
4025@RegisterPFor("StackPushV2")
4026def _convert_stack_push_v2(pfor_input):
4027  handle = pfor_input.unstacked_input(0)
4028  elem, elem_stacked, _ = pfor_input.input(1)
4029  swap_memory = pfor_input.get_attr("swap_memory")
4030
4031  if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
4032    raise ValueError("StackPushV2 not allowed on stacks created outside pfor")
4033  stack_cache_key = _stack_cache_key(pfor_input)
4034  stacked = _stack_cache.get(stack_cache_key, None)
4035  if stacked is None:
4036    stacked = elem_stacked
4037    _stack_cache[stack_cache_key] = stacked
4038  else:
4039    # If we previously made it unstacked then we can't revert to being stacked.
4040    if not stacked and elem_stacked:
4041      raise ValueError(
4042          "It looks like the stack was previously determined to be loop"
4043          " invariant, but we are now trying to push a loop dependent value"
4044          " to it. This is currently unsupported.")
4045    if stacked and not elem_stacked:
4046      elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
4047  out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
4048  return wrap(out, stacked)
4049
4050
4051# Note that inputs to this convertor will be unstacked. However it should get
4052# called since it is a stateful op.
4053@RegisterPFor("StackPopV2")
4054def _convert_stack_pop_v2(pfor_input):
4055  handle = pfor_input.unstacked_input(0)
4056  stack_cache_key = _stack_cache_key(pfor_input)
4057  stacked = _stack_cache.get(stack_cache_key, None)
4058  # If a StackPushV2 has not been converted yet, we default to unstacked since
4059  # the push could be outside of pfor, or the convertor may not be called if the
4060  # inputs are unconverted.
4061  if stacked is None:
4062    stacked = False
4063    _stack_cache[stack_cache_key] = False
4064  elem_type = pfor_input.get_attr("elem_type")
4065  out = data_flow_ops.stack_pop_v2(handle, elem_type)
4066  return wrap(out, stacked)
4067
4068
4069# parsing_ops
4070
4071
4072@RegisterPFor("DecodeCSV")
4073def _convert_decode_csv(pfor_input):
4074  lines = pfor_input.stacked_input(0)
4075  record_defaults = [
4076      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
4077  ]
4078  field_delim = pfor_input.get_attr("field_delim")
4079  use_quote_delim = pfor_input.get_attr("use_quote_delim")
4080  select_cols = pfor_input.get_attr("select_cols")
4081  if not select_cols:
4082    select_cols = None
4083  return [
4084      wrap(t, True) for t in parsing_ops.decode_csv(
4085          lines,
4086          record_defaults,
4087          field_delim=field_delim,
4088          use_quote_delim=use_quote_delim,
4089          select_cols=select_cols)
4090  ]
4091
4092
4093@RegisterPFor("ParseSingleExample")
4094def _convert_parse_single_example(pfor_input):
4095  serialized = pfor_input.stacked_input(0)
4096  dense_defaults = [
4097      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
4098  ]
4099  sparse_keys = pfor_input.get_attr("sparse_keys")
4100  dense_keys = pfor_input.get_attr("dense_keys")
4101  sparse_types = pfor_input.get_attr("sparse_types")
4102  dense_shapes = pfor_input.get_attr("dense_shapes")
4103  output = gen_parsing_ops.parse_example(
4104      serialized=serialized,
4105      names=[],
4106      dense_defaults=dense_defaults,
4107      sparse_keys=sparse_keys,
4108      dense_keys=dense_keys,
4109      sparse_types=sparse_types,
4110      dense_shapes=dense_shapes)
4111  return [wrap(t, True, True) for t in nest.flatten(output)]
4112
4113
4114@RegisterPFor("ParseExampleV2")
4115def _convert_parse_example_v2(pfor_input):
4116  serialized = pfor_input.stacked_input(0)
4117  sparse_keys = pfor_input.unstacked_input(2)
4118  dense_keys = pfor_input.unstacked_input(3)
4119  ragged_keys = pfor_input.unstacked_input(4)
4120  dense_defaults = [
4121      pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs)
4122  ]
4123  num_sparse = pfor_input.get_attr("num_sparse")
4124  sparse_types = pfor_input.get_attr("sparse_types")
4125  ragged_value_types = pfor_input.get_attr("ragged_value_types")
4126  ragged_split_types = pfor_input.get_attr("ragged_split_types")
4127  dense_shapes = pfor_input.get_attr("dense_shapes")
4128  if serialized.shape.ndims not in (None, 1):
4129    raise ValueError("ParseExampleV2 can only be converted if `serialized` "
4130                     "is scalar.")
4131  output = gen_parsing_ops.parse_example_v2(
4132      serialized=serialized,
4133      names=[],
4134      sparse_keys=sparse_keys,
4135      dense_keys=dense_keys,
4136      ragged_keys=ragged_keys,
4137      dense_defaults=dense_defaults,
4138      num_sparse=num_sparse,
4139      sparse_types=sparse_types,
4140      ragged_value_types=ragged_value_types,
4141      ragged_split_types=ragged_split_types,
4142      dense_shapes=dense_shapes)
4143  return [wrap(t, True, True) for t in nest.flatten(output)]
4144
4145
4146# functional_ops
4147
4148
4149def _convert_function_call(func, converter, inputs):
4150  assert isinstance(func.graph, func_graph.FuncGraph), func
4151  assert isinstance(converter, PFor)
4152
4153  # TODO(agarwal): consider caching this function definition.
4154  @def_function.function
4155  def f(*args):
4156    assert all(isinstance(arg, WrappedTensor) for arg in args), args
4157    assert len(args) == len(func.graph.inputs), (args, func.graph.inputs)
4158    #  Map inputs to function arguments.
4159    for inp, arg in zip(func.graph.inputs, args):
4160      converter._add_conversion(inp, arg)
4161    # Convert output tensors.
4162    return tuple(
4163        [converter._convert_helper(x).t for x in func._func_graph_outputs])
4164
4165  call_outputs = f(*inputs)
4166  assert len(call_outputs) == len(func._func_graph_outputs)
4167  outputs = []
4168  for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs):
4169    func_output = converter._convert_helper(output_tensor)
4170    outputs.append(
4171        wrap(call_output, func_output.is_stacked,
4172             func_output.is_sparse_stacked))
4173  return outputs
4174
4175
4176@RegisterPFor("StatefulPartitionedCall")
4177@RegisterPFor("PartitionedCall")
4178def _convert_partitioned_call(pfor_input):
4179  func_name = pfor_input.get_attr("f").name
4180  func = pfor_input.op.graph._get_function(compat.as_bytes(func_name))
4181  assert isinstance(func.graph, func_graph.FuncGraph), (
4182      "Could not find FuncGraph object for %s. Got func %s" % (func_name, func))
4183  pfor = pfor_input.pfor
4184  converter = PFor(
4185      loop_var=pfor.loop_var,
4186      loop_len=pfor.loop_len_vector[0],
4187      pfor_ops=func.graph.get_operations(),
4188      fallback_to_while_loop=pfor.fallback_to_while_loop,
4189      all_indices=pfor.all_indices,
4190      all_indices_partitioned=pfor.all_indices_partitioned,
4191      pfor_config=pfor.pfor_config)
4192  return _convert_function_call(func, converter, pfor_input.inputs)
4193
4194
4195def _partition_inputs_for_indices(inputs, indices):
4196  new_inputs = []
4197  for inp in inputs:
4198    if inp.is_stacked:
4199      new_inputs.append(wrap(array_ops.gather(inp.t, indices), True))
4200    else:
4201      new_inputs.append(inp)
4202  return new_inputs
4203
4204
4205def _outputs_for_branch(func_name, indices, pfor_input, inputs):
4206  if indices is None:
4207    indices = pfor_input.pfor.all_indices
4208    partitioned = pfor_input.pfor.all_indices_partitioned
4209  else:
4210    partitioned = True
4211  func = pfor_input.op.graph._get_function(func_name)
4212  converter = PFor(
4213      loop_var=pfor_input.pfor.loop_var,
4214      loop_len=array_ops.size(indices),
4215      pfor_ops=func.graph.get_operations(),
4216      fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop,
4217      all_indices=indices,
4218      all_indices_partitioned=partitioned,
4219      pfor_config=pfor_input.pfor.pfor_config)
4220  outputs = _convert_function_call(func, converter, inputs)
4221  stacked_outputs = []
4222  for out in outputs:
4223    if not out.is_stacked:
4224      stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t)
4225    else:
4226      stacked_outputs.append(out.t)
4227  return stacked_outputs
4228
4229
4230# TODO(agarwal): Currently the converted code aggressively tiles loop variant
4231# outputs from the then/else branches. Instead, it could do so only if at least
4232# one of the branch outputs is loop variant.
4233@RegisterPFor("StatelessIf")
4234@RegisterPFor("If")
4235def _convert_if(pfor_input):
4236  cond, cond_stacked, _ = pfor_input.input(0)
4237  inputs = pfor_input.inputs[1:]
4238  then_branch = pfor_input.get_attr("then_branch")
4239  else_branch = pfor_input.get_attr("else_branch")
4240
4241  if cond_stacked:
4242    cond_int = math_ops.cast(cond, dtypes.int32)
4243    # Compute loop indices for the different branches
4244    false_indices, true_indices = data_flow_ops.dynamic_partition(
4245        pfor_input.pfor.all_indices, cond_int, 2)
4246    # Compute indices for cond being True or False.
4247    if pfor_input.pfor.all_indices_partitioned:
4248      else_indices, then_indices = data_flow_ops.dynamic_partition(
4249          math_ops.range(pfor_input.pfor.loop_len_vector[0]),
4250          cond_int, 2)
4251    else:
4252      else_indices, then_indices = false_indices, true_indices
4253    # Partition inputs
4254    then_inputs = _partition_inputs_for_indices(inputs, then_indices)
4255    else_inputs = _partition_inputs_for_indices(inputs, else_indices)
4256
4257    # Convert "then" branch.
4258    then_outputs = _outputs_for_branch(then_branch.name, true_indices,
4259                                       pfor_input, then_inputs)
4260
4261    # Convert "else" branch.
4262    else_outputs = _outputs_for_branch(else_branch.name, false_indices,
4263                                       pfor_input, else_inputs)
4264
4265    assert len(then_outputs) == len(else_outputs)
4266    # Note that if the "then" and "else" branches are updating the same state,
4267    # and possibly reading them as well, it could lead to undefined behavior
4268    # since the ordering of those operations is not well defined.
4269    # One possibility is to order all the "then" branches to execute before all
4270    # the "else" branches so that the side-effects in the former are visible to
4271    # the latter. For now, we leave that as undefined behavior.
4272    outputs = []
4273    # Merge outputs
4274    for then_output, else_output in zip(then_outputs, else_outputs):
4275      out = data_flow_ops.dynamic_stitch([then_indices, else_indices],
4276                                         [then_output, else_output])
4277      outputs.append(wrap(out, True))
4278    return outputs
4279  else:
4280    outputs = control_flow_ops.cond(
4281        cond,
4282        lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs),
4283        lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs))
4284    return [wrap(t, True) for t in outputs]
4285
4286
4287class WhileV2(object):
4288  """Object for vectorizing V2 while_loop op."""
4289
4290  def __init__(self, pfor_input):
4291    self._pfor_input = pfor_input
4292    self._pfor = pfor_input.pfor
4293    cond_func_name = pfor_input.get_attr("cond").name
4294    self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes(
4295        cond_func_name))
4296    body_func_name = pfor_input.get_attr("body").name
4297    self._body_func = pfor_input.op.graph._get_function(compat.as_bytes(
4298        body_func_name))
4299    if self._cond_func is None or self._body_func is None:
4300      raise ValueError("Error extracting cond and body functions for op %s." % (
4301          self._pfor_input.op))
4302    # Indices of inputs that are passed unchanged through the while loop body.
4303    # Typically these are tensors captured from outside the body context.
4304    self._body_pass_through_indices = set()
4305    for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs,
4306                                       self._body_func.graph.outputs)):
4307      if id(inp) == id(out):
4308        self._body_pass_through_indices.add(i)
4309    self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations")
4310
4311  def _output_shapes(self):
4312    # Calculate output shape for vectorized loop. This will be used as
4313    # shape_invariant. Merges shape inference outputs with the `output_shapes`
4314    # attribute of the op.
4315    output_shapes = [out.shape for out in self._pfor_input.op.outputs]
4316    shapes = self._pfor_input.get_attr("output_shapes")
4317    if not shapes:
4318      shapes = [tensor_shape.TensorShape(None) for _ in output_shapes]
4319    else:
4320      shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
4321    for i, shape in enumerate(shapes):
4322      shape = shape.merge_with(output_shapes[i])
4323      pfor_input = self._pfor_input.input(i)
4324      if pfor_input.is_stacked:
4325        if _is_variant_with_internal_stacking(pfor_input.t):
4326          shape = tensor_shape.TensorShape([]).concatenate(shape)
4327        else:
4328          shape = tensor_shape.TensorShape([None]).concatenate(shape)
4329      output_shapes[i] = shape
4330    assert len(output_shapes) == self._pfor_input.num_inputs
4331    return output_shapes
4332
4333  def _init_values(self):
4334    """Create arguments passed to converted while_loop."""
4335    loop_len = self._pfor.loop_len_vector[0]
4336    inputs = []
4337    # TensorArrays for outputs of converted while loop
4338    output_tas = []
4339
4340    with ops.name_scope("while_init"):
4341      for inp in self._pfor_input.inputs:
4342        inputs.append(inp.t)
4343        output_tas.append(tensor_array_ops.TensorArray(
4344            inp.t.dtype,
4345            size=loop_len,
4346            dynamic_size=False,
4347            infer_shape=True))
4348    # See documentation for __call__ for the structure of init_values.
4349    indices = (
4350        math_ops.range(self._pfor.loop_len_vector[0])
4351        if self._pfor.all_indices_partitioned else self._pfor.all_indices)
4352    return [True, indices] + inputs + output_tas
4353
4354  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
4355    """Handles case when condition is pfor loop invariant."""
4356    # Note that all iterations end together. So we don't need to partition the
4357    # inputs.
4358    not_all_done = array_ops.reshape(conditions, [])
4359    return not_all_done, indices, inputs, output_tas
4360
4361  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
4362                            output_tas):
4363    """Handles case when condition is pfor loop dependent."""
4364    # Compute if all iterations are done.
4365    not_all_done = math_ops.reduce_any(conditions)
4366    conditions_int = math_ops.cast(conditions, dtypes.int32)
4367    # Partition the indices.
4368    done_indices, new_indices = data_flow_ops.dynamic_partition(
4369        indices, conditions_int, 2)
4370
4371    new_inputs = []
4372    new_output_tas = []
4373    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
4374      pass_through = i in self._body_pass_through_indices
4375      # Partition the inputs.
4376      if stacked:
4377        done_inp, new_inp = data_flow_ops.dynamic_partition(
4378            inp, conditions_int, 2)
4379      else:
4380        if not pass_through:
4381          done_inp = _stack(inp, [array_ops.size(done_indices)]).t
4382        new_inp = inp
4383
4384      new_inputs.append(new_inp)
4385      out_ta = output_tas[i]
4386      if not pass_through:
4387        # Note that done_indices can be empty. done_inp should also be empty
4388        # in that case.
4389        out_ta = out_ta.scatter(done_indices, done_inp)
4390      new_output_tas.append(out_ta)
4391
4392    assert len(new_output_tas) == len(output_tas)
4393    assert len(new_inputs) == len(inputs)
4394    return not_all_done, new_indices, new_inputs, new_output_tas
4395
4396  def _process_body(self, inputs_stacked, new_indices, cond_stacked,
4397                    new_inputs, not_all_done):
4398    """Convert the body function."""
4399    # This is used to store the indices of inputs to the while op that need to
4400    # be stacked. This stacking may be needed in cases where the input to the
4401    # while_loop is loop_invariant but the corresponding output is not.
4402    mismatching_stacked_indices = []
4403
4404    def true_fn():
4405      """Converts the body function for all but last iteration."""
4406      wrapped_inputs = [wrap(inp, stacked) for inp, stacked in
4407                        zip(new_inputs, inputs_stacked)]
4408      # Note the iterative process below to figure out loop invariance.
4409      # Here we iterate on vectorization process till a fixed point. The issue
4410      # is that the while body can take pfor loop invariant inputs but return
4411      # loop variant outputs. For any loop variant output, the corresponding
4412      # input has to be then made loop variant (since subsequent while
4413      # iterations will need to see loop variant values).
4414      # However once we make a new input loop variant, we might make other
4415      # outputs loop variant. Hence we need to iterate till we get fixed point.
4416      while True:
4417        if self._pfor.all_indices_partitioned:
4418          indices = array_ops.gather(self._pfor.all_indices, new_indices)
4419        else:
4420          indices = new_indices
4421        body_pfor = PFor(
4422            loop_var=self._pfor.loop_var,
4423            loop_len=array_ops.size(new_indices),
4424            pfor_ops=self._body_func.graph.get_operations(),
4425            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
4426            all_indices=indices,
4427            all_indices_partitioned=(self._pfor.all_indices_partitioned or
4428                                     cond_stacked),
4429            pfor_config=self._pfor.pfor_config)
4430        stacking_mismatch = False
4431        outputs = _convert_function_call(self._body_func,
4432                                         body_pfor,
4433                                         wrapped_inputs)
4434        for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)):
4435          if out.is_stacked != inp.is_stacked:
4436            stacking_mismatch = True
4437            mismatching_stacked_indices.append(i)
4438            stacked = _stack(inp.t, [array_ops.size(new_indices)])
4439            if inp.t.dtype == dtypes.variant:
4440              stacked = wrap(
4441                  _tile_variant_with_length(stacked.t,
4442                                            [array_ops.size(new_indices)]))
4443            wrapped_inputs[i] = stacked
4444        if not stacking_mismatch:
4445          if mismatching_stacked_indices:
4446            # We needed to stack some inputs. This code will be abandoned and
4447            # should not get executed. Hence we simply return `new_inputs` to
4448            # make sure the graph construction code completes.
4449            with ops.control_dependencies([
4450                control_flow_ops.Assert(
4451                    False, ["pfor ERROR: this branch should never execute"])]):
4452              return [array_ops.identity(x) for x in new_inputs]
4453          else:
4454            return [out.t for out in outputs]
4455
4456    # If all are done, we simply return `new_inputs`. Else we need to run the
4457    # body function.
4458    return control_flow_ops.cond(
4459        not_all_done,
4460        true_fn,
4461        lambda: list(new_inputs)), mismatching_stacked_indices
4462
4463  def __call__(self):
4464    """Converter for the V2 while_loop.
4465
4466    The conversion of a while_loop is another while_loop.
4467
4468    The arguments to this converted while_loop are as follows:
4469    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
4470      are done.
4471    indices: int32 1-D Tensor storing the id of the pfor iterations that are not
4472      done.
4473    args: Remaining arguments. These can be divided into 2 categories:
4474      - The first set of arguments correspond one-to-one to the inputs to the
4475        unvectorized while_loop.
4476      - The second set are TensorArrays, corresponding one-to-one to each output
4477        of the unvectorized while_loop. Each TensorArray has `PFor.loop_len`
4478        elements, i.e. the number of pfor iterations. At the end, the i'th
4479        element of each TensorArray will contain the output computed by the i'th
4480        iteration of pfor. Note that elements can be written into these tensors
4481        arrays in any order, depending on when the corresponding pfor iteration
4482        is done.
4483    In each iteration, the while_loop body recomputes the condition for all
4484    active pfor iterations to see which of them are now done. It then partitions
4485    all the inputs and passes them along to the converted body. Values for all
4486    the iterations that are done are written to TensorArrays indexed by the pfor
4487    iteration number. When all iterations are done, the TensorArrays are stacked
4488    to get the final value.
4489
4490    Returns:
4491      List of converted outputs.
4492    """
4493    output_shapes = self._output_shapes()
4494    # Note that we use these lists as a hack since we need the `body` to compute
4495    # these values during construction of the while_loop graph.
4496    cond_is_stacked = [None]
4497    indices_to_stack = []
4498
4499    def cond(not_all_done, *_):
4500      return not_all_done
4501
4502    def body(not_all_done, indices, *args):
4503      # See documentation for __call__ for the structure of *args.
4504      num_inputs = self._pfor_input.num_inputs
4505      inputs = args[:num_inputs]
4506      output_tas = args[num_inputs:]
4507      inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs]
4508      assert len(inputs) >= len(output_tas)
4509      assert len(inputs) == len(inputs_stacked)
4510      # Convert condition
4511      with ops.name_scope("while_cond"):
4512        # Note that we set all_indices_partitioned to True here. At this point
4513        # we don't know if indices will be partitioned. Hence we use the
4514        # conservative value.
4515        cond_pfor = PFor(
4516            loop_var=self._pfor.loop_var,
4517            loop_len=array_ops.size(indices),
4518            pfor_ops=self._cond_func.graph.get_operations(),
4519            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
4520            all_indices=indices,
4521            all_indices_partitioned=True,
4522            pfor_config=self._pfor.pfor_config)
4523
4524        wrapped_inputs = [wrap(inp, stacked) for inp, stacked
4525                          in zip(inputs, inputs_stacked)]
4526        conditions, cond_stacked, _ = _convert_function_call(
4527            self._cond_func,
4528            cond_pfor,
4529            wrapped_inputs)[0]
4530        cond_is_stacked[0] = cond_stacked
4531
4532      # Recompute the new condition, write outputs of done iterations, and
4533      # partition the inputs if needed.
4534      if not cond_stacked:
4535        (not_all_done, new_indices, new_inputs,
4536         new_output_tas) = self._process_cond_unstacked(conditions, indices,
4537                                                        inputs, output_tas)
4538      else:
4539        (not_all_done, new_indices, new_inputs,
4540         new_output_tas) = self._process_cond_stacked(conditions, indices,
4541                                                      inputs, inputs_stacked,
4542                                                      output_tas)
4543      # Convert body
4544      with ops.name_scope("while_body"):
4545        #  Compute the outputs from the body.
4546        new_outputs, mismatching_stacked_indices = self._process_body(
4547            inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done)
4548
4549      indices_to_stack[:] = mismatching_stacked_indices
4550      for i, new_output in enumerate(new_outputs):
4551        new_output.set_shape(output_shapes[i])
4552      new_args = ([not_all_done, new_indices] + new_outputs +
4553                  list(new_output_tas))
4554      return tuple(new_args)
4555
4556    # Note that we run the code below in a function since we might abandon the
4557    # generated code in cases where the conversion dictates that some inputs be
4558    # further stacked. Hence we run the graph construction using
4559    # `get_concrete_function` and avoid calling the constructed function if not
4560    # needed.
4561    @def_function.function
4562    def while_fn():
4563      # Create init_values that will be passed to the while_loop.
4564      init_values = self._init_values()
4565      ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in
4566                             self._pfor_input.outputs]
4567      shape_invariants = (
4568          [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])]
4569          + output_shapes + ta_shape_invariants)
4570
4571      while_outputs = control_flow_ops.while_loop(
4572          cond, body, init_values,
4573          shape_invariants=shape_invariants,
4574          parallel_iterations=self._parallel_iterations)
4575      if indices_to_stack:
4576        # This function will be abandoned.
4577        return while_outputs
4578      else:
4579        num_inputs = self._pfor_input.num_inputs
4580        new_inputs = while_outputs[2:num_inputs+2]
4581        output_tas = while_outputs[num_inputs+2:]
4582        assert cond_is_stacked[0] is not None
4583        outputs = []
4584        for i, inp in enumerate(new_inputs):
4585          if cond_is_stacked[0]:
4586            if i in self._body_pass_through_indices:
4587              outputs.append(init_values[i + 2])
4588            else:
4589              ta = output_tas[i]
4590              outputs.append(ta.stack())
4591          else:
4592            outputs.append(inp)
4593        return outputs
4594
4595    _ = while_fn.get_concrete_function()
4596    if indices_to_stack:
4597      # Need to abandon the current conversion, stack some inputs and restart.
4598      self._pfor_input.stack_inputs(
4599          stack_indices=indices_to_stack, tile_variants=True)
4600      # Note that this call will recurse at most one time. The first call will
4601      # do the required stacking, based on the iterative procedure in
4602      # _process_body, and the next invocation to __call__ should not need to do
4603      # any more stacking.
4604      # We invoke `self()` here as a way to discard any corrupted state.
4605      return self()
4606    else:
4607      outputs = while_fn()
4608      wrapped_outputs = []
4609      for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)):
4610        if i not in self._body_pass_through_indices and cond_is_stacked[0]:
4611          wrapped_outputs.append(wrap(out, True))
4612        else:
4613          wrapped_outputs.append(wrap(out, inp.is_stacked))
4614      return wrapped_outputs
4615
4616
4617@RegisterPFor("StatelessWhile")
4618@RegisterPFor("While")
4619def _convert_while(pfor_input):
4620  converter = WhileV2(pfor_input)
4621  return converter()
4622
4623
4624# spectral_ops
4625
4626
4627@RegisterPForWithArgs("FFT", gen_spectral_ops.fft)
4628@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d)
4629@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d)
4630@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft)
4631@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d)
4632@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d)
4633def _convert_fft(pfor_input, _, op_func):
4634  return wrap(op_func(pfor_input.stacked_input(0)), True)
4635
4636
4637@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex")
4638@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex")
4639@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex")
4640@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal")
4641@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal")
4642@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal")
4643def _convert_rfft(pfor_input, _, op_func, attr_name):
4644  inp = pfor_input.stacked_input(0)
4645  fft_length = pfor_input.unstacked_input(1)
4646  attr = pfor_input.get_attr(attr_name)
4647  return wrap(op_func(inp, fft_length, attr), True)
4648