1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for managing state of v1 control flow for computing gradients."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import control_flow_util as util
28from tensorflow.python.ops import control_flow_v2_func_graphs
29from tensorflow.python.ops import default_gradient
30from tensorflow.python.ops import gen_data_flow_ops
31from tensorflow.python.ops import gen_resource_variable_ops
32from tensorflow.python.ops import resource_variable_ops
33
34# pylint: disable=protected-access
35
36
37def _GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
38  """Calculate a max_size for use by stack ops inside an XLA while_loop.
39
40  Args:
41    value: The value inside the while_loop forward context.  Used for printing
42      error messages.
43    while_ctxt: The forward context inside which value resides.  This does not
44      always match the value's immediate context, as `value` may be inside e.g.
45      a cond context inside the while_loop.
46
47  Returns:
48    A tensor containing the `max_size` to feed to a Stack initializer.
49
50  Raises:
51    ValueError: If `value` is nested inside a `while_loop` that either
52      lacks a `maximum_iterations` parameter, or the `maximum_iterations`
53      parameter:
54
55        - is inside a `while_loop` that is a parent of the calling context, and
56        - cannot be evaluated at graph build time to a constant.
57  """
58  value_name = value.name
59  # curr_ctxt is the context that tf.gradients was called in.
60  curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
61
62  curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
63  max_size = constant_op.constant(1)
64
65  # Loop through all containing while contexts between value and the
66  # current context, multiplying together each context's
67  # max_iterations to get the maximum stack size.
68  while while_ctxt not in (None, curr_ctxt):
69    max_iter = while_ctxt.maximum_iterations
70    if max_iter is None:
71      raise ValueError(
72          "Cannot create a gradient accumulator for tensor '%s' inside "
73          "XLA while_loop because maximum_iterations was not passed to "
74          "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))
75
76    # pylint: disable=protected-access
77    max_iter_ctxt = max_iter.op._get_control_flow_context()
78    # pylint: enable=protected-access
79
80    # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
81    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
82      max_size *= max_iter
83    else:
84      # We cannot use max_iter because it's defined in a nested while
85      # or cond context, so will fail if we try to use it as input to
86      # any ops in curr_ctxt (e.g. max_size or the final accumulator
87      # stack). Attempt to get a constant value out to use instead.
88      const_max_iter = tensor_util.constant_value(max_iter)
89      if const_max_iter is None:
90        raise ValueError(
91            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
92            "while_loop. maximum_iterations tensor '%s' for while_loop context "
93            "'%s' must be statically known (e.g. a constant value or known "
94            "shape dimension), or be defined at or outside the while loop "
95            "context '%s' (currently defined in '%s')." %
96            (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
97             max_iter_ctxt.name))
98      max_size *= const_max_iter
99
100    # Find the next outer WhileContext (or stop if we reach the
101    # tf.gradient's context).
102    while_ctxt = util.GetContainingWhileContext(
103        while_ctxt.outer_context, stop_ctxt=curr_ctxt)
104
105  return max_size
106
107
108class _GradLoopState(object):
109  """The state used for constructing the gradient graph for a while loop.
110
111  We create a _GradLoopState for each while loop in forward and its
112  corresponding while loop in backprop. This gives us access to both
113  the forward and the backprop WhileContexts.
114
115  During the construction of gradient graph, any time when we detect
116  a forward value that is needed for backprop, we create a history
117  accumulator and add it to `history_map`. Any time when we backprop
118  a loop switch op (in _SwitchGrad), we add the grad merge op in
119  `switch_map`.
120  """
121
122  def __init__(self, forward_ctxt, outer_grad_state):
123    # The grad loop state for the outer while loop.
124    self._outer_grad_state = None
125
126    # The while loop context for forward.
127    self._forward_context = None
128
129    # The loop counter added by AddForwardLoopCounter. It is the value
130    # of the loop counter for the next iteration.
131    self._forward_index = None
132
133    # A sync op for forward.
134    self._forward_sync = None
135
136    # The while loop context for backprop.
137    self._grad_context = None
138
139    # The loop counter added by AddBackpropLoopCounter. It is the value
140    # of the loop counter for the current iteration.
141    self._grad_index = None
142
143    # A sync op for backprop.
144    self._grad_sync = None
145
146    # Information needed by backprop.
147    self._history_map = {}
148    self._switch_map = {}
149    self._unused_exits = []
150    self._deferred_exits = []
151    self._forward_loop_exits = list(forward_ctxt.loop_exits)
152    self._pending_exits_count = len(forward_ctxt.loop_exits)
153
154    self._outer_grad_state = outer_grad_state
155    if outer_grad_state:
156      outer_forward_ctxt = outer_grad_state.forward_context
157    else:
158      if not hasattr(forward_ctxt, "outer_context"):
159        raise ValueError("Failed to call gradients on a while loop without"
160                         "properly serializing graph via MetaGraphDef")
161      outer_forward_ctxt = forward_ctxt.outer_context
162
163    # Add the forward loop counter.
164    with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
165      if outer_forward_ctxt:
166        outer_forward_ctxt.Enter()
167      cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
168      if outer_forward_ctxt:
169        outer_forward_ctxt.Exit()
170    self._forward_context = forward_ctxt
171    self._forward_index = forward_index
172
173    # Add the backprop WhileContext, and the backprop loop counter.
174    if outer_grad_state:
175      # This is a nested loop. Remember the iteration counts for each
176      # execution of this inner loop.
177      outer_forward_ctxt.AddName(cnt.name)
178      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
179
180      outer_grad_ctxt = outer_grad_state.grad_context
181      outer_grad_ctxt.Enter()
182      self._grad_context = control_flow_ops.WhileContext(
183          maximum_iterations=forward_ctxt.maximum_iterations,
184          parallel_iterations=forward_ctxt.parallel_iterations,
185          back_prop=forward_ctxt.back_prop,
186          swap_memory=forward_ctxt.swap_memory,
187          name=forward_ctxt.name,
188          grad_state=self)
189      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
190      self._grad_index = self._grad_context.AddBackpropLoopCounter(
191          real_cnt, outer_grad_state)
192      outer_grad_ctxt.Exit()
193    else:
194      if outer_forward_ctxt:
195        outer_forward_ctxt.Enter()
196      self._grad_context = control_flow_ops.WhileContext(
197          maximum_iterations=forward_ctxt.maximum_iterations,
198          parallel_iterations=forward_ctxt.parallel_iterations,
199          back_prop=forward_ctxt.back_prop,
200          swap_memory=forward_ctxt.swap_memory,
201          name=forward_ctxt.name,
202          grad_state=self)
203      self._grad_index = self._grad_context.AddBackpropLoopCounter(
204          cnt, outer_grad_state)
205      if outer_forward_ctxt:
206        outer_forward_ctxt.Exit()
207
208  @property
209  def outer_grad_state(self):
210    """The grad loop state for outer loop."""
211    return self._outer_grad_state
212
213  @property
214  def forward_context(self):
215    """The while loop context for forward."""
216    return self._forward_context
217
218  @property
219  def forward_index(self):
220    """The loop index of forward loop."""
221    return self._forward_index
222
223  @property
224  def forward_sync(self):
225    """A control trigger node for synchronization in the forward loop.
226
227    One main use is to keep the push ops of a stack executed in the
228    iteration order.
229    """
230    if self._forward_sync is None:
231      with ops.control_dependencies(None):
232        self._forward_sync = control_flow_ops.control_trigger(name="f_sync")
233      self._forward_sync._set_control_flow_context(self._forward_context)
234      self._forward_index.op._add_control_input(self._forward_sync)
235    return self._forward_sync
236
237  @property
238  def grad_context(self):
239    """The corresponding WhileContext for gradient."""
240    return self._grad_context
241
242  @property
243  def grad_index(self):
244    """The loop index of backprop loop."""
245    return self._grad_index
246
247  @property
248  def grad_sync(self):
249    """A control trigger node for synchronization in the grad loop.
250
251    One main use is to keep the pop ops of a stack executed in the
252    iteration order.
253    """
254    if self._grad_sync is None:
255      with ops.control_dependencies(None):
256        self._grad_sync = control_flow_ops.control_trigger(name="b_sync")
257      self._grad_sync._set_control_flow_context(self._grad_context)
258      self._grad_index.op._add_control_input(self._grad_sync)
259      if self._grad_context.outer_context:
260        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
261    return self._grad_sync
262
263  @property
264  def history_map(self):
265    """The map that records all the tensors needed for backprop."""
266    return self._history_map
267
268  @property
269  def switch_map(self):
270    """The map that records all the Switch ops for the while loop."""
271    return self._switch_map
272
273  @property
274  def unused_exits(self):
275    """The list of "unused" exits."""
276    return self._unused_exits
277
278  @property
279  def deferred_exits(self):
280    """The list of "deferred" exits."""
281    return self._deferred_exits
282
283  @property
284  def forward_loop_exits(self):
285    """The list of exits of the forward loop."""
286    return self._forward_loop_exits
287
288  @property
289  def pending_exits_count(self):
290    """The number of exits we expect to see but haven't."""
291    return self._pending_exits_count
292
293  @pending_exits_count.setter
294  def pending_exits_count(self, cnt):
295    """Set the pending count to cnt."""
296    self._pending_exits_count = cnt
297
298  def AddForwardAccumulator(self, value, dead_branch=False):
299    """Add an accumulator for each forward tensor that is needed in backprop.
300
301    This is added to the forward loop at the first time when a tensor
302    in the forward loop is used by backprop gradient computation loop.
303    We create an accumulator that accumulates the value of tensor at each
304    iteration. Called in the control flow context where gradients() is called.
305
306    The pseudocode is:
307    ```
308      acc = stack();
309      while (_pivot) {
310        acc = stack_push(acc, value);
311      }
312    ```
313
314    We make sure that the stack push op in one iteration is executed before
315    next iteration. This is achieved by adding a control edge from
316    `forward_index.op.inputs[0].op` to the push op, and another control
317    edge from the push op to either `forward_index.op` or `forward_sync`.
318
319    Args:
320      value: The source tensor in forward that is to be accumulated.
321      dead_branch: True iff the tensor is on a dead branch of a cond.
322
323    Returns:
324      The stack that contains the accumulated history of the tensor.
325
326    Raises:
327      TypeError: For internal errors involving the value condition context.
328      ValueError: If `value` is inside a XLA scope and a valid max size
329        for the stack can't be found.
330    """
331    # curr_ctxt is the context that tf.gradients was called in.
332    with self._forward_index.graph.as_default():
333      curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
334      with ops.control_dependencies(None):
335        if curr_ctxt:
336          curr_ctxt.Enter()
337        with ops.colocate_with(value):
338          # We only need to pass maximum_iterations to the stack if
339          # we're inside an XLA context.
340          if not util.IsInXLAContext(value.op):
341            max_size = constant_op.constant(-1, dtypes.int32)
342          else:
343            max_size = _GetMaxSizeFromNestedMaximumIterations(
344                value, self.forward_context)
345          acc = gen_data_flow_ops.stack_v2(
346              max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
347        if curr_ctxt:
348          curr_ctxt.Exit()
349
350        # Make acc available in the forward context.
351        enter_acc = self.forward_context.AddValue(acc)
352
353        # Add the stack_push op in the context of value.op.
354        swap_enabled = self.forward_context.swap_memory
355        value_ctxt = util.GetOutputContext(value.op)
356        if value_ctxt == self.forward_context:
357          # value is not nested in the forward context.
358          self.forward_context.Enter()
359          push = gen_data_flow_ops.stack_push_v2(
360              enter_acc, value, swap_memory=swap_enabled)
361          self.forward_context.Exit()
362          # Protect stack push and order it before forward_index.
363          self.forward_index.op._add_control_input(push.op)
364        else:
365          # value is in a cond context within the forward context.
366          if not isinstance(value_ctxt, control_flow_ops.CondContext):
367            raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
368          if dead_branch:
369            # The special case for creating a zero tensor for a dead
370            # branch of a switch. See _ControlFlowState.ZerosLikeV1WhileLoop().
371            value_ctxt.outer_context.Enter()
372            push = gen_data_flow_ops.stack_push_v2(
373                enter_acc, value, swap_memory=swap_enabled)
374            value_ctxt.outer_context.Exit()
375            push.op._set_control_flow_context(value_ctxt)
376          else:
377            value_ctxt.Enter()
378            push = gen_data_flow_ops.stack_push_v2(
379                enter_acc, value, swap_memory=swap_enabled)
380            value_ctxt.Exit()
381          # Protect stack push and order it before forward_sync.
382          self.forward_sync._add_control_input(push.op)
383        # Order stack push after the successor of forward_index
384        add_op = self.forward_index.op.inputs[0].op
385        push.op._add_control_input(add_op)
386        return acc
387
388  def AddBackpropAccumulatedValue(self, history_value, value,
389                                  dead_branch=False):
390    """Add the getter for an accumulated value in the grad context.
391
392    This is added to the backprop loop. Called in the grad context to
393    get the value of an accumulated value. The stack pop op must be guarded
394    by the pred of the controlling cond.
395
396    Args:
397      history_value: The history (a stack) of a value.
398      value: The value that is pushed onto the stack.
399      dead_branch: True iff the tensor is on a dead branch of a cond.
400
401    Returns:
402      The current value (the top of the stack).
403    """
404    history_ctxt = history_value.op._get_control_flow_context()
405    # Find the cond context that controls history_value if any.
406    cond_ctxt = None
407    value_ctxt = value.op._get_control_flow_context()
408    while value_ctxt and value_ctxt != history_ctxt:
409      if isinstance(value_ctxt, control_flow_ops.CondContext):
410        cond_ctxt = value_ctxt
411        break
412      value_ctxt = value_ctxt.outer_context
413    with ops.control_dependencies(None):
414      self.grad_context.Enter()
415      if cond_ctxt:
416        # Guard stack pop with a switch if it is controlled by a cond.
417        grad_state = self
418        pred = None
419        while pred is None and grad_state:
420          pred = grad_state.history_map.get(cond_ctxt.pred.name)
421          grad_state = grad_state.outer_grad_state
422        if pred is None:
423          pred = cond_ctxt.pred
424        branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
425        history_value = control_flow_ops._SwitchRefOrTensor(
426            history_value, pred)[branch]
427      pop = gen_data_flow_ops.stack_pop_v2(history_value,
428                                           value.dtype.base_dtype)
429      pop.set_shape(value.get_shape())
430      self.grad_context.Exit()
431    parallel_iterations = self.grad_context.parallel_iterations
432    if parallel_iterations > 1:
433      # All pops are ordered after pivot_for_body and before grad_sync.
434      self.grad_sync._add_control_input(pop.op)
435    return pop
436
437  def GetRealValue(self, value):
438    """Get the real value of `value`.
439
440    If backprop "uses" a value produced by forward inference, an accumulator
441    is added in the forward loop to accumulate its values.  We use the
442    accumulated value. This method must be called in the grad loop context.
443    `value` must be in forward and needed for backprop.
444
445    Args:
446      value: A tensor to be captured.
447
448    Returns:
449      The same tensor obtained from the saved history.
450    """
451    assert value.op.type not in ["Variable", "VariableV2"]
452    real_value = self._history_map.get(value.name)
453    if real_value is None:
454      cur_value = value
455      cur_grad_state = self
456      while True:
457        enter_op = util.GetLoopConstantEnter(cur_value)
458        if enter_op:
459          # Special case: cur_value comes from a constant Enter node.
460          cur_value = enter_op.inputs[0]
461          cur_grad_state = cur_grad_state.outer_grad_state
462          if cur_grad_state is None:
463            # We are now outside all nested loops for this gradient(),
464            # so `value` is a loop invariant and there is no need to
465            # save the history of value. Just make cur_value to enter
466            # the right control flow context.
467            real_value = self._grad_context.AddValue(cur_value)
468            break
469        elif constant_op.is_constant(cur_value):
470          # If the value to be forwarded is a constant, clone the constant in
471          # the gradient loop rather than using a stack.
472          # TODO(phawkins): consider hoisting the constant out of the loop
473          # instead.
474          real_value = constant_op.constant(
475              tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
476          break
477        else:
478          # Record the history of this value in forward_ctxt.
479          self._grad_context.Exit()
480          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
481          self._grad_context.Enter()
482          break
483
484      if real_value is None:
485        # Add the stack pop op in the grad context.
486        real_value = cur_grad_state.AddBackpropAccumulatedValue(
487            history_value, cur_value)
488        if cur_grad_state != self:
489          real_value = self._grad_context.AddValue(real_value)
490      self._history_map[value.name] = real_value
491    return real_value
492
493
494class _ControlFlowState(object):
495  """Maintain the mapping from the loops to their grad states."""
496
497  def __init__(self):
498    self._map = {}  # maps forward loop context to _GradLoopState
499
500  def GetGradState(self, op, before):
501    """Return the grad state for this op if it's in a forward loop context."""
502    if before and util.IsLoopExit(op):
503      forward_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
504      forward_ctxt = forward_ctxt.outer_context
505      if forward_ctxt:
506        forward_ctxt = forward_ctxt.GetWhileContext()
507    else:
508      forward_ctxt = util.GetWhileContext(op)
509    if forward_ctxt:
510      return self._map.get(forward_ctxt)
511    return None
512
513  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
514    """Process all the "unused" loop exits.
515
516    The "unused" exits of the loops are added to `unused_exits`. An exit is
517    unused if its pending_count is 0. If there is an exit with real gradient,
518    all these deferred exits will enter the backprop loop with zero gradient.
519    Otherwise, they will enter the backprop loop with None. As an example,
520    people often write:
521
522    ```python
523    v1, _ = tf.while_loop(p, b, [x1, x2])
524    result = gradients(v1, x1)
525    ```
526
527    The exit node for x2 is not included by the betweenness analysis. But we
528    need to backprop x2 if x2 is involved in computing v1.
529
530    Args:
531      pending_count: The number of backprop inputs for every op.
532      to_ops_set: The set of ops for ys in gradients(ys, xs)
533
534    Returns:
535      The set of unused loop exits that we know at this point we need
536      to backprop.
537    """
538    loop_exits = []
539    for grad_state in self._map.values():
540      for y in grad_state.forward_loop_exits:
541        if pending_count[y.op] == 0:
542          grad_state.pending_exits_count -= 1
543          if y.op not in to_ops_set:
544            grad_state.unused_exits.append(y)
545          if grad_state.pending_exits_count == 0:
546            loop_exits.extend(grad_state.unused_exits)
547      # Need to include Enters in backprop for higher-order gradients.
548      for y in grad_state.forward_context.loop_enters:
549        if pending_count[y.op] == 0:
550          pending_count[y.op] = 1
551    return loop_exits
552
553  def EnterGradWhileContext(self, op, before):
554    """Enter the WhileContext for gradient computation."""
555    grad_state = self.GetGradState(op, before)
556    if grad_state:
557      grad_state.grad_context.Enter()
558
559  def ExitGradWhileContext(self, op, before):
560    """Exit the WhileContext for gradient computation."""
561    grad_state = self.GetGradState(op, before)
562    if grad_state:
563      grad_state.grad_context.Exit()
564
565  def AddWhileContext(self, op, between_op_list, between_ops):
566    """Add the grad state for the while loop that op belongs to.
567
568    Note that op is an Exit, and this method must be called in
569    the control flow context where gradients() is called.
570
571    Note that this method modifies `between_op_list` and `between_ops`.
572    """
573    forward_ctxt = util.GetWhileContext(op)
574    grad_state = self._map.get(forward_ctxt)
575    if grad_state is None:
576      # This is a new while loop so create a grad state for it.
577      outer_forward_ctxt = forward_ctxt.outer_context
578      if outer_forward_ctxt:
579        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
580      outer_grad_state = None
581      if outer_forward_ctxt:
582        outer_grad_state = self._map.get(outer_forward_ctxt)
583      grad_state = _GradLoopState(forward_ctxt, outer_grad_state)
584      self._map[forward_ctxt] = grad_state
585
586      # We need to include all exits of a loop for backprop.
587      for loop_exit in grad_state.forward_loop_exits:
588        if loop_exit.op not in between_ops:
589          between_ops.add(loop_exit.op)
590          between_op_list.append(loop_exit.op)
591
592  def ZerosLikeForExit(self, val):
593    """Create zeros_like gradient for a loop exit.
594
595    If the result of a loop variable is not used but is involved in
596    computing the result of some needed loop variable, we create a
597    zero-valued tensor that is fed as gradient for the Exit node of that
598    loop variable. Note that val.op is an Exit, and this method must be
599    called in the control flow context where gradients() is called.
600
601    Args:
602      val: The output tensor of an Exit op.
603
604    Returns:
605      A zero tensor of the same shape of val.
606    """
607    val_shape = val.get_shape()
608    forward_ctxt = val.op._get_control_flow_context()
609    outer_forward_ctxt = forward_ctxt.outer_context
610    if outer_forward_ctxt:
611      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
612    outer_grad_state = None
613    if outer_forward_ctxt:
614      outer_grad_state = self._map.get(outer_forward_ctxt)
615    if outer_grad_state:
616      # This is a nested loop.
617      if val_shape.is_fully_defined():
618        # If the shape is known statically, just create a zero tensor
619        # with the right shape in the right context.
620        outer_grad_state.grad_context.Enter()
621        result = array_ops.zeros(val_shape.dims, val.dtype)
622        outer_grad_state.grad_context.Exit()
623      else:
624        # Only the shape of value is needed for backprop.
625        forward_ctxt.outer_context.Enter()
626        shape = array_ops.shape_internal(val, optimize=False)
627        forward_ctxt.outer_context.Exit()
628        # Save the shape to a stack.
629        history_shape = outer_grad_state.AddForwardAccumulator(shape)
630        # Get the shape back from the stack.
631        outer_grad_ctxt = outer_grad_state.grad_context
632        outer_grad_ctxt.Enter()
633        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
634            history_shape, shape)
635        result = array_ops.zeros(real_shape, val.dtype)
636        outer_grad_ctxt.Exit()
637    else:
638      # This is not a nested loop.
639      if val_shape.is_fully_defined():
640        # If the shape is known statically, just create a zero tensor
641        # with the right shape.
642        result = array_ops.zeros(val_shape.dims, val.dtype)
643      else:
644        result = array_ops.zeros_like(val, optimize=False)
645    return result
646
647  def ZerosLikeV1WhileLoop(self, op, index):
648    """Create zeros_like for the specified output of an op.
649
650    If op is in a while loop that is part of gradients(), this method
651    must be called in its grad loop context.
652
653    Args:
654      op: A tensorflow operation.
655      index: the index for a specific output of the op.
656
657    Returns:
658      A zero tensor of the same shape of op.outputs[index].
659    """
660    if util.IsLoopSwitch(op):
661      return None
662    if op.graph.building_function:
663      # The optimization here is tricky to apply to functions
664      return array_ops.zeros_like(op.outputs[index])
665    dead_branch = util.IsSwitch(op)
666    forward_ctxt = util.GetWhileContext(op)
667    grad_state = self._map.get(forward_ctxt)
668    if grad_state is None:
669      # op is not in a while loop that is part of gradients().
670      return ZerosLike(op, index)
671    op_ctxt = op._get_control_flow_context()
672    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
673    shape = val.get_shape()
674    if shape.is_fully_defined():
675      # If the shape is known statically, just create a zero tensor with
676      # the right shape in the grad loop context.
677      if val.dtype == dtypes.resource:
678        result = array_ops.zeros(
679            resource_variable_ops.variable_shape(val),
680            dtype=default_gradient.get_zeros_dtype(val))
681      else:
682        result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
683      if dead_branch:
684        # op is a cond switch. Guard the zero tensor with a switch.
685        pred = grad_state.history_map.get(op_ctxt.pred.name)
686        branch = op_ctxt.branch
687        result = control_flow_ops._SwitchRefOrTensor(result, pred)[1 - branch]
688    else:
689      # Unknown shape so keep a history of the shape at runtime.
690      if dead_branch:
691        # Need to add a special switch to guard the value.
692        pred = op_ctxt.pred
693        branch = op_ctxt.branch
694        op_ctxt.outer_context.Enter()
695        val = control_flow_ops._SwitchRefOrTensor(op.inputs[0],
696                                                  pred)[1 - branch]
697        zeros_shape = array_ops.shape_internal(val, optimize=False)
698        op_ctxt.outer_context.Exit()
699        val.op._set_control_flow_context(op_ctxt)
700        zeros_shape.op._set_control_flow_context(op_ctxt)
701      else:
702        op_ctxt.Enter()
703        zeros_shape = array_ops.shape_internal(val, optimize=False)
704        op_ctxt.Exit()
705
706      # Add forward accumulator for shape.
707      grad_state.grad_context.Exit()
708      history_zeros_shape = grad_state.AddForwardAccumulator(
709          zeros_shape, dead_branch=dead_branch)
710      grad_state.grad_context.Enter()
711
712      # Create a zero tensor with the right shape.
713      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
714                                                     zeros_shape, dead_branch)
715      result = array_ops.zeros(shape, val.dtype)
716    return result
717
718  def PostProcessing(self):
719    """Perform postprocessing at the end of gradients().
720
721    We have created the gradient graph at this point. So this function
722    can be used to perform any postprocessing on the gradient graph.
723    We currently perform the following postprocessing:
724      1. Patch the gradient graph if the output of a loop variable
725         doesn't depend on its input.
726    """
727    for _, grad_state in self._map.items():
728      for _, b_merge in grad_state.switch_map.items():
729        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
730          # The value of this loop variable at iteration i+1 doesn't
731          # depend on its value at iteration i. So use zeros as the
732          # gradients for all iterations > 0.
733          dtype = b_merge.op.inputs[0].dtype
734          shape = b_merge.op.inputs[0].get_shape()
735          # pylint: disable=protected-access
736          if shape.is_fully_defined():
737            grad_state.grad_context.Enter()
738            # Create a zeros and use it for iterations > 0.
739            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
740            next_grad_val = control_flow_ops._NextIteration(grad_val)
741            grad_state.grad_context.Exit()
742          else:
743            # Create a zeros in the outer grad context.
744            outer_grad_ctxt = grad_state.grad_context.outer_context
745            if outer_grad_ctxt:
746              outer_grad_ctxt.Enter()
747            enter_grad_op = b_merge.op.inputs[0].op
748            enter_grad = enter_grad_op.inputs[0]
749            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
750            grad_val = array_ops.zeros(grad_shape)
751            if outer_grad_ctxt:
752              outer_grad_ctxt.Exit()
753            # Use the zeros for iterations > 0.
754            grad_state.grad_context.Enter()
755            next_grad_val = control_flow_ops._NextIteration(grad_val)
756            grad_state.grad_context.Exit()
757          b_merge.op._update_input(1, next_grad_val)
758          # pylint: enable=protected-access
759
760
761def MaybeCreateControlFlowState(between_op_list, between_ops,
762                                colocate_gradients_with_ops):
763  """Create the state for all the while loops involved in one gradients().
764
765  We create a _ControlFlowState when there are while loops involved in
766  gradients(). In gradients(), control flow logic is only invoked when
767  the _ControlFlowState is not None.
768
769  Note that this method modifies `between_op_list` and `between_ops`.
770  """
771  loop_state = None
772  for op in between_op_list:
773    if util.IsLoopExit(op):
774      if loop_state is None:
775        loop_state = _ControlFlowState()
776      if colocate_gradients_with_ops:
777        with ops.colocate_with(op):
778          loop_state.AddWhileContext(op, between_op_list, between_ops)
779      else:
780        loop_state.AddWhileContext(op, between_op_list, between_ops)
781  return loop_state
782
783
784def _ZerosLikeV1(op, index):
785  """Branch of ZerosLike for TF1."""
786  val = op.outputs[index]
787  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
788  if op_ctxt:
789    # We are in a cond context. Use a switch to create zeros only when needed.
790    pred = op_ctxt.pred
791    branch = op_ctxt.branch
792    switch_val = control_flow_ops.switch(op.inputs[0], pred)[1 - branch]
793    # A op is created along the branch taken as control dependencies are on
794    # the whole op and not on the tensor output.
795    pivot = array_ops.identity(switch_val)
796    if val.dtype == dtypes.resource:
797      with ops.control_dependencies([pivot]):
798        return array_ops.zeros(
799            gen_resource_variable_ops.variable_shape(switch_val),
800            dtype=default_gradient.get_zeros_dtype(val))
801    zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
802    # Ensure ops created within array_ops.zeros are dominated by switch in
803    # cond context.
804    with ops.control_dependencies([pivot]):
805      return array_ops.zeros(zeros_shape, dtype=val.dtype)
806  else:
807    return array_ops.zeros_like(val, optimize=False)
808
809
810def _ZerosLikeV2(op, index):
811  """Branch of ZerosLike for TF2."""
812  val = op.outputs[index]
813  if val.dtype == dtypes.resource:
814    return array_ops.zeros(
815        gen_resource_variable_ops.variable_shape(val),
816        dtype=default_gradient.get_zeros_dtype(val))
817  if (isinstance(val.op.graph, control_flow_v2_func_graphs.WhileBodyFuncGraph)
818      and val.dtype != dtypes.variant):
819    # In while_v2 we do not want to add a `ZerosLike` op because that will
820    # trigger accumulation of `val`. Normally `ZerosLike` is preferred because
821    # it helps avoid creating extra nodes(possibly Consts) for the shape.
822    # For variants, we must use ZerosLike.
823    if val.shape.is_fully_defined():
824      return constant_op.constant(0, shape=val.shape.dims, dtype=val.dtype)
825    else:
826      # Note: Even though we add `Shape` in the default graph, while_v2 is smart
827      # enough to place it in the forward graph i.e. `val.graph`.
828      zeros_shape = array_ops.shape_internal(val, optimize=False)
829      return array_ops.zeros(zeros_shape, val.dtype)
830  else:
831    return array_ops.zeros_like(val, optimize=False)
832
833
834def ZerosLike(op, index):
835  """Create zeros_like for the specified output of an op."""
836  if not util.IsSwitch(op):
837    return _ZerosLikeV2(op, index)
838  else:
839    return _ZerosLikeV1(op, index)
840