# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib

from six.moves import xrange, zip  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_state
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


def _MarkReachedOps(from_ops, reached_ops, func_graphs):
  """Mark all ops reached from "from_ops".

  Args:
    from_ops: list of Operations.
    reached_ops: set of Operations.
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops.
  """
  queue = collections.deque()
  queue.extend(from_ops)
  while queue:
    op = queue.popleft()
    if op not in reached_ops:
      reached_ops.add(op)
      for output in op.outputs:
        if _IsBackpropagatable(output):
          queue.extend(_Consumers(output, func_graphs))


def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
                  xs_set):
  """Initialize the pending count for ops between two lists of Operations.

  'pending_count[op]' indicates the number of backprop inputs
  to this operation.

  Args:
    to_ops: list of Operations.
    from_ops: list of Operations.
    colocate_gradients_with_ops: Python bool.  See docstring of gradients().
    func_graphs: list of FuncGraphs. This method will traverse through
      these functions if they capture from_ops or any reachable ops. This is
      useful if to_ops occur in a function and from_ops are in an outer function
      or graph.
    xs_set: ObjectIdentitySet of Tensors.

  Returns:
    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
    path of zero or more backpropagatable tensors, (2) a mapping from operation
    to the number of backprop inputs to that op, and (3) a ControlFlowState
    object which is not None if the ops between from_ops and to_ops contain
    control flow loops.
  """
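  # The computation below does two breadth-first traversals: a forward pass
  # from from_ops that marks every op reachable through backpropagatable
  # tensors, and a backward pass from to_ops that keeps only the marked ops,
  # i.e. the ops lying on some backpropagatable path between from_ops and
  # to_ops.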
  # Mark reachable ops from from_ops.
  reached_ops = set()
  _MarkReachedOps(from_ops, reached_ops, func_graphs)
  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
  # backpropagatable tensors.

  reachable_to_ops = set(op for op in to_ops if op in reached_ops)

  # Mark between ops.
  between_ops = set()
  between_op_list = []
  queue = collections.deque()
  queue.extend(to_ops)
  while queue:
    op = queue.popleft()
    # We are interested in this op.
    if op in reached_ops:
      between_ops.add(op)
      between_op_list.append(op)
      # Remove the op from reached_ops so we won't add its inputs again.
      reached_ops.remove(op)
      for inp in _NonEagerInputs(op, xs_set):
        queue.append(inp.op)
  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
  # between from_ops and to_ops

  # 'loop_state' is None if there are no while loops.
  loop_state = control_flow_state.MaybeCreateControlFlowState(
      between_op_list, between_ops, colocate_gradients_with_ops)

  # Initialize pending count for between ops.
  pending_count = collections.defaultdict(int)
  for op in between_op_list:
    for x in _NonEagerInputs(op, xs_set):
      if x.op in between_ops:
        pending_count[x.op] += 1

  return reachable_to_ops, pending_count, loop_state


def _AsList(x):
  return x if isinstance(x, (list, tuple)) else [x]


def _DefaultGradYs(grad_ys,
                   ys,
                   colocate_gradients_with_ops,
                   gradient_uid="__unsupported__"):
  """Fill in default values for grad_ys.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If sizes of gradients and inputs don't match
    TypeError: If type of any gradient is not valid for its input.
  """
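  # Note: when grad_ys[i] is None, the default gradient created below is a
  # tensor of ones with the shape and dtype of ys[i]; with all-ones grad_ys
  # this corresponds to differentiating sum(ys) with respect to the xs.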
  if len(grad_ys) != len(ys):
    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  new_grad_ys = []
  for i, (y, grad_y) in enumerate(zip(ys, grad_ys)):
    with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
      if grad_y is None:
        if y.dtype.is_complex:
          raise TypeError(
              "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
              y.dtype)
        new_grad_ys.append(
            array_ops.fill(
                array_ops.shape(y),
                constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
        continue
      if y.dtype.is_floating or y.dtype.is_integer:
        if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
          raise TypeError(
              "Gradient type %s generated for real or "
              "integer-valued tensor %s with type %s must be "
              "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
                                   dtypes.as_dtype(y.dtype).name))
      elif y.dtype.is_complex:
        if not grad_y.dtype.is_complex:
          raise TypeError(
              "Gradient type %s generated for complex-valued "
              "tensor %s with type %s must be complex" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.variant:
        if grad_y.dtype != dtypes.variant:
          raise TypeError(
              "Gradient type %s generated for variant "
              "tensor %s with type %s must be variant" % (dtypes.as_dtype(
                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
      elif y.dtype == dtypes.resource:
        # We assume y is the handle of a ResourceVariable. The gradient of a
        # ResourceVariable should be a numeric value, not another resource.
        if grad_y.dtype == dtypes.resource:
          raise TypeError("Input gradient %s for resource tensor %s should not "
                          "be a resource" % (grad_y, y))
      else:
        raise TypeError(
            "Tensor %s with type %s must be numeric "
            "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
      # Create a grad_y tensor in the name scope of the gradient.
      # Required for TensorArrays to identify which gradient call a
      # grad_y value is coming from.
      if isinstance(grad_y, ops.IndexedSlices):
        new_grad_ys.append(
            ops.IndexedSlices(
                indices=(array_ops.identity(
                    grad_y.indices, name="grad_ys_%d_indices" % i)
                         if isinstance(grad_y.indices, ops.Tensor) else
                         grad_y.indices),
                values=(array_ops.identity(
                    grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
                        grad_y.values, ops.Tensor) else grad_y.values),
                dense_shape=(array_ops.identity(
                    grad_y.dense_shape, name="grad_ys_%d_shape" % i)
                             if isinstance(grad_y.dense_shape, ops.Tensor) else
                             grad_y.dense_shape)))
      else:
        new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))

  return new_grad_ys


def _IsBackpropagatable(tensor):
  if backprop_util.IsTrainable(tensor):
    return True
  dtype = dtypes.as_dtype(tensor.dtype)
  return dtype.base_dtype == dtypes.bfloat16


def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if sizes of gradients and inputs don't match.
    TypeError: if type of any gradient is not valid for its input.
  """
  # While ops have inputs added to them during the gradient computation, so we
  # skip the below check. See while_v2 for details.
  if op.type == "While" or op.type == "StatelessWhile":
    return

  if len(grads) != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))


def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
  """The set of ops that terminate the gradient computation.

  This computes the frontier of the forward graph *before* which backprop
  should stop. Operations in the returned set will not be differentiated.
  This set is defined as the subset of `from_ops` containing ops that have
  no predecessor in `from_ops`. `pending_count` is the result of
  `_PendingCount(to_ops, from_ops, ...)`. An 'op' has predecessors in `from_ops`
  iff pending_count[op] > 0.

  In addition, none of `stop_gradient_ops` will be differentiated.

  Args:
    from_ops: list of Operations.
    stop_gradient_ops: list of Operations never to backprop through.
    pending_count: mapping from operation to number of backprop inputs.
    xs_set: ObjectIdentitySet of Tensors.

  Returns:
    The set of operations.
  """
  stop_ops = set()
  for op in from_ops:
    is_stop_op = True
    for inp in _NonEagerInputs(op, xs_set):
      if pending_count[inp.op] > 0:
        is_stop_op = False
        break
    if is_stop_op:
      stop_ops.add(op)
  stop_ops.update(op for op in stop_gradient_ops)
  return stop_ops


@contextlib.contextmanager
def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):  # pylint: disable=invalid-name
  """Context to colocate with `op` if `colocate_gradients_with_ops`."""
  if colocate_gradients_with_ops:
    with ops._colocate_with_for_gradient(op, gradient_uid):  # pylint: disable=protected-access
      yield
  else:
    yield


def _IsPartitionedCall(op):
  return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"


def _SymGrad(op, out_grads):
  """Backprop through a function call node op given its outputs' gradients."""
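  # The emitted SymbolicGradient node takes the forward op's inputs followed
  # by the gradients of its outputs, and produces one gradient per forward
  # input (typed with that input's default-gradient dtype).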
  f_in = [x for x in op.inputs] + out_grads
  f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs]
  f = attr_value_pb2.NameAttrList()
  if _IsPartitionedCall(op):
    f.name = op.get_attr("f").name
  else:
    f.name = op.type
  for k in op.node_def.attr:
    f.attr[k].CopyFrom(op.node_def.attr[k])
  in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
  return in_grads


def _MaybeCompile(scope, op, func, grad_fn):
  """Compile the calculation in grad_fn if op was marked as compiled."""
  scope = scope.rstrip("/").replace("/", "_")
  if func is not None:
    xla_compile = func.definition.attr["_XlaCompile"].b
    xla_separate_compiled_gradients = func.definition.attr[
        "_XlaSeparateCompiledGradients"].b
    xla_scope = func.definition.attr["_XlaScope"].s.decode()
  else:
    try:
      xla_compile = op.get_attr("_XlaCompile")
      xla_separate_compiled_gradients = op.get_attr(
          "_XlaSeparateCompiledGradients")
      xla_scope = op.get_attr("_XlaScope").decode()
    except ValueError:
      xla_compile = False

  if not xla_compile:
    return grad_fn()  # Exit early

  # If the gradients are supposed to be compiled separately, we give them a
  # _XlaScope name that is based on the name_scope of the gradients.  Otherwise
  # they just inherit the existing _XlaScope name, which lets them be merged
  # together with the non-gradient computation.
  if xla_separate_compiled_gradients:
    xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
  else:
    xla_grad_scope = xla_scope

  attrs = {
      "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
      "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
  }
  with ops.get_default_graph()._attr_scope(attrs):  # pylint: disable=protected-access
    return grad_fn()


def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
  """Raises an error if we backprop through a loop var."""
  # Find the nearest op in 'from_ops' reachable from 'op' to provide a more
  # helpful error message.
  target_op = None
  queue = collections.deque([op])
  visited = set()
  while queue:
    curr_op = queue.popleft()
    if curr_op in visited: continue
    visited.add(curr_op)
    if curr_op in from_ops:
      target_op = curr_op
      break
    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
  assert target_op
  raise ValueError(
      "Cannot compute gradient inside while loop with respect to op '%s'. "
      "We do not support taking the gradient wrt or through the initial value "
      "of a loop variable. Gradients can be computed through loop invariants "
      "or wrt the input parameters to the loop body."
      % target_op.name)


def _IsFunction(graph):
  return (isinstance(graph, FuncGraph) or
          isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access


def _Captures(func_graph):
  if isinstance(func_graph, FuncGraph):
    return func_graph.captures
  else:
    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
    return func_graph.captures


def _MaybeCaptured(t):
  """If t is a captured value placeholder, returns the original captured value.

  Args:
    t: Tensor

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
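  # For example, a value captured by a FuncGraph appears inside the function
  # body as a Placeholder; given that placeholder, this returns the externally
  # captured tensor, recursing in case the capture is itself captured from an
  # outer function.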
  # pylint: disable=protected-access
  if (not isinstance(t, ops.EagerTensor) and
      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
    for input_t, placeholder_t in _Captures(t.op.graph):
      if t is placeholder_t:
        return _MaybeCaptured(input_t)
  # pylint: enable=protected-access
  return t


def _NonEagerInputs(op, xs_set):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Does not return any captured EagerTensors, i.e., the number of tensors
  returned may be less than the actual number of inputs.

  Args:
    op: Operation
    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]


# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
def _Inputs(op, xs_set):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Args:
    op: Operation
    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  if _IsFunction(op.graph):  # pylint: disable=protected-access
    inputs = []
    for t in op.inputs:
      # If we're differentiating w.r.t. `t`, do not attempt to traverse through
      # it to a captured value. The algorithm needs to "see" `t` in this case,
      # even if it's a function input for a captured value, whereas usually we'd
      # like to traverse through these closures as if the captured value was the
      # direct input to op.
      if t not in xs_set:
        t = _MaybeCaptured(t)
      inputs.append(t)
    return inputs
  else:
    return op.inputs


def _Consumers(t, func_graphs):
  """Returns the consumers of t, crossing closure boundaries where necessary.

  Args:
    t: Tensor
    func_graphs: a list of FuncGraphs that may have captured t.

  Returns:
    A list of Operations that consume t. The operations may be from the
    current graph and/or func_graphs.
  """
  consumers = t.consumers()
  for func in func_graphs:
    for input_t, placeholder in _Captures(func):
      if input_t is t:
        consumers.extend(_Consumers(placeholder, func_graphs))
  return consumers


def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
  """Implementation of gradients()."""
  if context.executing_eagerly():
    raise RuntimeError("tf.gradients is not supported when eager execution "
                       "is enabled. Use tf.GradientTape instead.")
  if src_graph is None:
    src_graph = ops.get_default_graph()
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
  # ancestor graphs. This is necessary for correctly handling captured values.
  func_graphs = []
  curr_graph = src_graph
  while _IsFunction(curr_graph):
    func_graphs.append(curr_graph)
    if isinstance(curr_graph, FuncGraph):
      curr_graph = curr_graph.outer_graph
    else:
      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

  ys = _AsList(ys)
  xs = _AsList(xs)
  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    # Get a uid for this call to gradients that can be used to help
    # cluster ops for compilation.
    gradient_uid = ops.get_default_graph().unique_name("uid")
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    xs = [
        x.handle if resource_variable_ops.is_resource_variable(x) else x
        for x in xs
    ]
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    xs_set = object_identity.ObjectIdentitySet(xs)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                             gradient_uid)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t its outputs have
    # been collected.  Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    reachable_to_ops, pending_count, loop_state = _PendingCount(
        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into an Add() Operation if there
    # is more than one.
    grads = {}
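    # For example, grads[some_op] == [[g0_a, g0_b], [g1]] would mean that
    # output 0 of some_op has received two partial gradients (summed later by
    # _AggregatedGrads) and output 1 has received one.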

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      ready = (pending_count[op] == 0)
      if ready and op not in to_ops_set and op in reachable_to_ops:
        to_ops_set.add(op)
        queue.append(op)

    if loop_state:
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if backprop_util.IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
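    # Each iteration below pops an op whose output gradients are all
    # available, aggregates them, calls the op's registered gradient function
    # (or emits a SymbolicGradient node for function-call ops), records the
    # resulting input gradients, and then enqueues any ops that become ready.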
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
                                     aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        func_call = None
        is_partitioned_call = _IsPartitionedCall(op)
        # pylint: disable=protected-access
        is_func_call = (
            src_graph._is_function(op.type) or is_partitioned_call)
        # pylint: enable=protected-access
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op not in stop_ops):
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            if is_func_call:
              if is_partitioned_call:
                func_name = compat.as_bytes(op.get_attr("f").name)
                func_call = src_graph._get_function(  # pylint: disable=protected-access
                    func_name)
                # When a graph is imported, the FunctionDefs are not copied over
                # to each sub-graph so we recursively search the outer graphs
                # for the FunctionDef.
                if not func_call and hasattr(src_graph, "outer_graph"):
                  graph = src_graph.outer_graph
                  while graph is not None:
                    func_call = graph._get_function(func_name)  # pylint: disable=protected-access
                    if func_call is not None:
                      break
                    if hasattr(graph, "outer_graph"):
                      graph = graph.outer_graph
                    else:
                      break
              else:
                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
              # Note that __defun is not set if the graph is
              # imported. If it's set, we prefer to access the original
              # defun.
              func_call = getattr(op, "__defun", func_call)
              grad_fn = func_call.python_grad_func
            else:
              raise LookupError(
                  "No gradient defined for operation '%s' (op type: %s)" %
                  (op.name, op.type))
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)

        # NOTE(skyewm): We don't support computing gradients wrt a loop variable
        # unless it's within the context of a single iteration (i.e. the
        # gradient is wrt the loop parameter in the body function, not wrt or
        # through the initial value). This means if we're in a while loop
        # context, we should never see a switch node from this context.
        # pylint: disable=protected-access
        if (control_flow_util.IsSwitch(op) and
            op._control_flow_context is not None and
            op._control_flow_context.IsWhileContext() and
            op._control_flow_context ==
            ops.get_default_graph()._get_control_flow_context()):
          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
        # pylint: enable=protected-access

        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call)
                or backprop_util.IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLikeV1WhileLoop(op, i)
              elif default_gradient.supports_default_grad(op.outputs[i]):
                # TODO(b/143286622): The supports_default_grad check is needed
                # because While op emits non-differentiable resource tensors
                # as outputs. Remove this check when that is not the case.
                out_grads[i] = control_flow_state.ZerosLike(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with src_graph._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                      None,
                      gradient_uid,
                      ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(_Inputs(op, xs_set))
        # Note: we don't filter out eager inputs here because the inputs need to
        # line up with in_grads.
        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient.  Forward operation: %s.  Input index: %d. "
                    "Original input shape: %s.  "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
            if not isinstance(t_in, ops.EagerTensor):
              _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                    xs_set)

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]


def _HasAnyNotNoneGrads(grads, op):
  """Return true iff op has a real gradient."""
  out_grads = _GetGrads(grads, op)
  for out_grad in out_grads:
    if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
      return True
    if out_grad and isinstance(out_grad, collections_abc.Sequence):
      if any(g is not None for g in out_grad):
        return True
  return False


def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                  xs_set):
  """Update pending count for the inputs of op and enqueue ready ops."""
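  # Exits of a while loop are deferred until all of that loop's exits have
  # been reached. If at least one exit received a real gradient, the remaining
  # unused exits get zero gradients (when trainable); otherwise every exit
  # simply propagates None.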
  for x in _NonEagerInputs(op, xs_set):
    pending_count[x.op] -= 1
    ready = (pending_count[x.op] == 0)
    if loop_state and not ready:
      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
    if ready:
      if control_flow_util.IsLoopExit(x.op):
        # If x is a loop exit without a real gradient, defer processing it.
        grad_state = loop_state.GetGradState(x.op, before=False)
        grad_state.deferred_exits.append(x)
        grad_state.pending_exits_count -= 1
        if grad_state.pending_exits_count == 0:
          # We now have all the exits so process them.
          has_not_none_grad = False
          for y in grad_state.deferred_exits:
            if _HasAnyNotNoneGrads(grads, y.op):
              has_not_none_grad = True
              queue.append(y.op)
            else:
              grad_state.unused_exits.append(y)
          if has_not_none_grad:
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
              if backprop_util.IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
            # All exits are "unused" so use None as gradient.
            for y in grad_state.unused_exits:
              queue.append(y.op)
      else:
        queue.append(x.op)


def _SetGrad(grads, t, grad):
  """Sets gradient "grad" in "grads" for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    op_grads = [[] for _ in xrange(len(op.outputs))]
    grads[op] = op_grads
  t_grads = op_grads[t.value_index]
  if isinstance(t_grads, list):
    t_grads.append(grad)
  else:
    assert control_flow_util.IsLoopSwitch(op)
    op_grads[t.value_index] = grad


def _ZerosLike(t):
  t_dtype = default_gradient.get_zeros_dtype(t)
  if t.dtype == dtypes.resource:
    return array_ops.zeros(
        resource_variable_ops.variable_shape(t), dtype=t_dtype)
  else:
    return array_ops.zeros_like(t, dtype=t_dtype)


def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    if unconnected_gradients == UnconnectedGradients.ZERO:
      return _ZerosLike(t)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  t_grad = op_grads[t.value_index]
  # This can happen if some other output of `t.op` has non-None grad.
  if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None:
    return _ZerosLike(t)

  assert not isinstance(
      t_grad, list), ("gradients list should have been aggregated by now.")
  return t_grad


def _GetGrads(grads, op):
  """Gets all gradients for op."""
  if op in grads:
    return grads[op]
  else:
    return [[] for _ in xrange(len(op.outputs))]


def _AccumulatorShape(inputs):
  shape = tensor_shape.unknown_shape()
  for i in inputs:
    if isinstance(i, ops.Tensor):
      shape = shape.merge_with(i.get_shape())
  return shape


def _LogOpGradients(op, out_grads, in_grads):
  """Log the in and out grads of an op."""
  logging.vlog(1, "Gradient for '" + op.name + "'")

  def _FilterGrad(x):
    if x is None:
      return False
    if isinstance(x, (list, tuple)):
      return bool(x)
    else:
      return True

  logging.vlog(1, "  in  --> %s",
               ", ".join(x.name for x in out_grads if _FilterGrad(x)))
  logging.vlog(1, "  out --> %s",
               ", ".join(x.name for x in in_grads if _FilterGrad(x)))


def _MultiDeviceAddN(tensor_list, gradient_uid):
  """Adds tensors from potentially multiple devices."""
  # Basic function structure comes from control_flow_ops.group().
  # Sort tensors according to their devices.
  tensors_on_device = collections.defaultdict(lambda: [])
  for tensor in tensor_list:
    tensors_on_device[tensor.device].append(tensor)

  # For each device, add the tensors on that device first.
  # Then gather the partial sums from multiple devices.
  # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
  # E.g., aggregate per GPU, then per task, and so on.
  summands = []

  def DeviceKey(dev):
    return "" if dev is None else dev

  for dev in sorted(tensors_on_device, key=DeviceKey):
    tensors = tensors_on_device[dev]
    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
        tensors[0].op,
        gradient_uid,
        ignore_existing=True):
      summands.append(math_ops.add_n(tensors))

  return math_ops.add_n(summands)


@tf_export("AggregationMethod")
class AggregationMethod(object):
  """A class listing aggregation methods used to combine gradients.

  Computing partial derivatives can require aggregating gradient
  contributions. This class lists the various methods that can
  be used to combine gradients in the graph.

  The following aggregation methods are part of the stable API for
  aggregating gradients:

  *  `ADD_N`: All of the gradient terms are summed as part of one
     operation using the "AddN" op (see `tf.add_n`). This
     method has the property that all gradients must be ready and
     buffered separately in memory before any aggregation is performed.
  *  `DEFAULT`: The system-chosen default aggregation method.

  The following aggregation methods are experimental and may not
  be supported in future releases:

  * `EXPERIMENTAL_TREE`: Gradient terms are summed in pairs using
    the "AddN" op. This method of summing gradients may reduce
    performance, but it can improve memory utilization because the
    gradients can be released earlier.

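  For example (an illustrative sketch; `tf.gradients` requires a graph, so the
  call is wrapped in a `tf.function` here), a method is selected via the
  `aggregation_method` argument:

    @tf.function
    def example():
      x = tf.constant(1.0)
      y = x * 2.0
      z = y + y + y + y  # z == 8 * x == 4 * y
      # The multiple gradient terms flowing into y are combined using the
      # requested aggregation method.
      return tf.gradients(
          z, [x, y],
          aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

    example()  # Evaluates to gradients of [8.0, 4.0].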
  """
  ADD_N = 0
  DEFAULT = ADD_N
  # The following are experimental and may not be supported in future releases.
  EXPERIMENTAL_TREE = 1
  EXPERIMENTAL_ACCUMULATE_N = 2  # Treated the same as EXPERIMENTAL_TREE.


def _AggregatedGrads(grads,
                     op,
                     gradient_uid,
                     loop_state,
                     aggregation_method=None):
  """Get the aggregated gradients for op.

  Args:
    grads: The map of memoized gradients.
    op: The op to get gradients for.
    gradient_uid: A unique identifier within the graph indicating
      which invocation of gradients is being executed. Used to cluster
      ops for compilation.
    loop_state: An object for maintaining the state of the while loops in the
                graph. It is of type ControlFlowState. None if the graph
                contains no while loops.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of gradients, one per output of `op`. If the gradients
      for a particular output are a list, this function aggregates them
      before returning.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if the arguments are invalid.

  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT
  if aggregation_method not in [
      AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
  ]:
    raise ValueError(
        "Invalid aggregation_method specified %s." % aggregation_method)
  out_grads = _GetGrads(grads, op)
  for i, out_grad in enumerate(out_grads):
    if loop_state:
      if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
        assert control_flow_util.IsLoopSwitch(op)
        continue
    # Grads have to be Tensors or IndexedSlices
    if (isinstance(out_grad, collections_abc.Sequence) and not all(
        isinstance(g, (ops.Tensor, ops.IndexedSlices))
        for g in out_grad
        if g is not None)):
      raise TypeError("gradients have to be either all Tensors "
                      "or all IndexedSlices")
    # Aggregate multiple gradients, and convert [] to None.
    if out_grad:
      if len(out_grad) < 2:
        used = "nop"
        out_grads[i] = out_grad[0]
      elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
        tensor_shape = _AccumulatorShape(out_grad)
        if aggregation_method in [
            AggregationMethod.EXPERIMENTAL_TREE,
            AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        ]:
          # Aggregate all gradients by doing pairwise sums: this may
          # reduce performance, but it can improve memory because the
          # gradients can be released earlier.
          #
          # TODO(vrv): Consider replacing this with a version of
          # tf.AddN() that eagerly frees its inputs as soon as they are
          # ready, so the order of this tree does not become a problem.
          used = "tree"
          with ops.name_scope(op.name + "_gradient_sum"):
            running_sum = out_grad[0]
            for grad in out_grad[1:]:
              running_sum = math_ops.add_n([running_sum, grad])
            out_grads[i] = running_sum
        else:
          used = "add_n"
          out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
        logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                     tensor_shape, used)
      else:
        out_grads[i] = backprop.aggregate_indexed_slices_gradients(out_grad)  # pylint: disable=protected-access
    else:  # not out_grad
      # out_grads[i] is [], thus its aggregation is simply None.
      out_grads[i] = None
  return out_grads


# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
# unfortunately too slow to use here.
POSSIBLE_GRADIENT_TYPES_NONE = 0
POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2


def PossibleTapeGradientTypes(tensors):
  """Determines whether and how `tensors` may require tape gradients."""
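  # The result is one of the POSSIBLE_GRADIENT_TYPES_* constants defined above.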
  return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)