1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Code for backpropagation using the tape utilities."""
16
17# TODO(b/159343581): Properly support CompositeTensor in all functions in this
18# file.
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import functools
25import operator
26import sys
27
28import six
29
30from tensorflow.python import pywrap_tfe
31from tensorflow.python.eager import backprop_util
32from tensorflow.python.eager import context
33from tensorflow.python.eager import execute
34from tensorflow.python.eager import imperative_grad
35from tensorflow.python.eager import tape
36from tensorflow.python.framework import constant_op
37from tensorflow.python.framework import dtypes
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.framework import tensor_util
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import check_ops
43from tensorflow.python.ops import control_flow_util
44from tensorflow.python.ops import default_gradient
45from tensorflow.python.ops import gen_array_ops
46from tensorflow.python.ops import gen_math_ops
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import resource_variable_ops
49from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.util import _pywrap_utils
52from tensorflow.python.util import nest
53from tensorflow.python.util import tf_contextlib
54from tensorflow.python.util import tf_inspect
55from tensorflow.python.util.lazy_loader import LazyLoader
56from tensorflow.python.util.tf_export import tf_export
57
58
# The following two modules are loaded lazily (via LazyLoader) because
# importing them eagerly here would create circular dependencies.
# TODO(b/119775953): fix the circular dependencies.
pfor_ops = LazyLoader(
    "pfor_ops", globals(),
    "tensorflow.python.ops.parallel_for.control_flow_ops")

function = LazyLoader("function", globals(),
                      "tensorflow.python.eager.function")

# Memoizes `op_attr_type` results, keyed by the (op_type, attr_name) tuple.
_op_attr_type_cache = {}
70
71
def op_attr_type(op_type, attr_name):
  """Returns the attribute type code of `attr_name` on ops of type `op_type`.

  Results are memoized in the module-level `_op_attr_type_cache`, so the eager
  context is only consulted the first time a given (op_type, attr_name) pair
  is requested.

  Args:
    op_type: name of the op type, e.g. "MatMul".
    attr_name: name of the attribute on that op type.

  Returns:
    The value reported by `pywrap_tfe.TFE_OpNameGetAttrType` for the pair.
  """
  cache_key = (op_type, attr_name)
  try:
    return _op_attr_type_cache[cache_key]
  except KeyError:
    pass  # Cache miss: fall through to the eager-context lookup below.
  context.ensure_initialized()
  ctx_handle = context.context()._handle  # pylint: disable=protected-access
  attr_type = pywrap_tfe.TFE_OpNameGetAttrType(ctx_handle, op_type, attr_name)
  _op_attr_type_cache[cache_key] = attr_type
  return attr_type
81
82
def make_attr(attr_type, value):
  """Converts a raw attribute value into the form gradient functions expect.

  Type attrs become `DType`s, shape attrs become `TensorShapeProto`s (list
  variants are converted element-wise), Python `str` values are encoded to
  bytes, and anything else is returned unchanged.

  Args:
    attr_type: the attribute type code, or a one-element list holding the code
      for list-valued attributes.
    value: the raw attribute value.

  Returns:
    The converted value.
  """
  # pybind11 enums, unlike SWIG enums, do not compare equal to raw integers,
  # so the enum members are converted with int() before comparing.
  # https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
  # TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
  # from integer value to class.
  type_code = int(pywrap_tfe.TF_ATTR_TYPE)
  shape_code = int(pywrap_tfe.TF_ATTR_SHAPE)
  if attr_type == type_code:
    return dtypes.as_dtype(value)
  if attr_type == [type_code]:
    return [dtypes.as_dtype(item) for item in value]
  if attr_type == shape_code:
    return tensor_shape.as_shape(value).as_proto()
  if attr_type == [shape_code]:
    return [tensor_shape.as_shape(item).as_proto() for item in value]
  return value.encode() if isinstance(value, str) else value
101
102
class _MockOp(object):
  """Minimal stand-in for a `tf.Operation`, handed to gradient functions.

  It carries just enough of the forward op (type, attrs, inputs, outputs and
  the inputs whose gradients may be skipped) for registered gradient functions
  to run without a real graph operation.
  """

  def __init__(self, attrs, inputs, outputs, typ, skip_input_indices):
    # `attrs` is a flat (name, value, name, value, ...) sequence.
    self.type = typ
    self.attrs = attrs
    self.inputs = inputs
    self.outputs = outputs
    self.skip_input_indices = skip_input_indices

  def get_attr(self, attr):
    """Returns the converted value of attribute `attr`.

    Raises:
      KeyError: if `attr` is not among this op's attribute names.
    """
    attr_type = op_attr_type(self.type, attr)
    flat_attrs = self.attrs
    for name_pos in range(0, len(flat_attrs), 2):
      if flat_attrs[name_pos] == attr:
        # The value immediately follows its name in the flat sequence.
        return make_attr(attr_type, flat_attrs[name_pos + 1])
    raise KeyError(attr)

  def _get_control_flow_context(self):
    """Always raises: graph control flow is unsupported on a mock op."""
    message = (
        "tf.GradientTape.gradients() does not support graph control flow "
        "operations like tf.cond or tf.while at this time. Use tf.gradients() "
        "instead. If you need this feature, please file a feature request at "
        "https://github.com/tensorflow/tensorflow/issues/new")
    raise NotImplementedError(message)
127
128
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
                       out_grads, skip_input_indices, forward_pass_name_scope):
  """Invokes the registered Python gradient function for `op_name`.

  Args:
    op_name: type name of the op being differentiated.
    attr_tuple: the op's attributes as a flat (name, value, ...) tuple.
    num_inputs: number of inputs of the op; sizes the all-None result when
      the registry maps `op_name` to None.
    inputs: the forward op's inputs.
    outputs: the forward op's outputs.
    out_grads: gradients flowing back into each of the op's outputs.
    skip_input_indices: forwarded to the gradient function (through the mock
      op) to indicate inputs whose gradients need not be computed.
    forward_pass_name_scope: the name scope of the op in the forward pass, or a
      falsy value if there was none.

  Returns:
    The gradient function's result (gradients w.r.t. the inputs), or
    `[None] * num_inputs` when the registry yields None for `op_name`.
  """
  grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
  if grad_fn is None:
    return [None] * num_inputs
  mock_op = _MockOp(attr_tuple, inputs, outputs, op_name, skip_input_indices)

  # This does not work with v1 TensorArrays.
  scoped = (ops.executing_eagerly_outside_functions() or
            control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
  if not scoped:
    return grad_fn(mock_op, *out_grads)

  scope_name = "gradient_tape/"
  if forward_pass_name_scope:
    scope_name += forward_pass_name_scope + "/"
  with ops.name_scope(scope_name):
    return grad_fn(mock_op, *out_grads)
162
163
# Register _gradient_function with the pywrap_tfe layer at import time;
# presumably the C++ tape calls back into it to differentiate recorded ops.
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
165
166
def _must_record_gradient():
  """Returns True unless pywrap_tfe reports that the tape set is empty."""
  tape_set_is_empty = pywrap_tfe.TFE_Py_TapeSetIsEmpty()
  return not tape_set_is_empty
169
170
def _record_gradient(op_name, inputs, attrs, results):
  """Records an executed op on the active tapes under the current name scope.

  Returns:
    Whatever `pywrap_tfe.TFE_Py_RecordGradient` returns.
  """
  current_scope = ops.get_name_scope()
  return pywrap_tfe.TFE_Py_RecordGradient(
      op_name, inputs, attrs, results, current_scope)
174
175
# Install this module's tape-aware hooks into `execute` at import time.
# NOTE(review): presumably injected here rather than imported by `execute` to
# avoid an import cycle; confirm.
execute.must_record_gradient = _must_record_gradient
execute.record_gradient = _record_gradient
178
179
def implicit_val_and_grad(f):
  """Returns a function which differentiates f with respect to variables.

  The wrapped function returns the value and the gradient of f when called with
  the same arguments. The gradient is with respect to all trainable TFE
  variables accessed by `f`.

  This function is useful when the exact set of variables to differentiate with
  is not known ahead of time.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  val_grad_fn = tfe.implicit_value_and_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  value, grads_and_vars = val_grad_fn(x, y)
  print('Value of loss: %s' % value)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar.

  Returns:
    A function which, when called, returns a tuple pair.
    Its first element is the value to which the function evaluates.
    Its second element is list of (gradient, variable) pairs.

  Raises:
    ValueError: (raised by the returned function) if `f` returns None, if no
      trainable variables were accessed while computing `f`, or if a watched
      variable's handle is a packed EagerTensor.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    this_tape = tape.push_new_tape()
    try:
      end_node = f(*args, **kwds)
      if end_node is None:
        # Not every callable has `__name__` (e.g. functools.partial objects);
        # fall back to `f` itself so the intended ValueError is still raised
        # instead of an AttributeError.
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             getattr(f, "__name__", f)))
    finally:
      tape.pop_tape(this_tape)
    # Note: variables are returned in construction order. This ensures unique
    # order across executions.
    variables = this_tape.watched_variables()
    if not variables:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")

    sources = [v.handle for v in variables]
    for s in sources:
      if getattr(s, "is_packed", False):
        raise ValueError(
            "GradientTape.gradient is not supported on packed EagerTensors yet."
        )
    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))

  return grad_fn
257
258
def implicit_grad(f):
  """Returns a function which differentiates f with respect to variables.

  Calling the returned function with some arguments evaluates `f` on them and
  yields the gradient of the result with respect to every trainable TFE
  variable that `f` accessed. Only the gradients are returned, not the value.

  This is handy when the set of variables to differentiate against is not
  known in advance.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  grad_fn = tfe.implicit_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  grads_and_vars = grad_fn(x, y)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. A scalar result is differentiated
      directly; for a tensor or list of tensors, the values are by default
      summed into a single scalar first.

  Returns:
    A function which, when called, returns a list of (gradient, variable) pairs.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Returns only the (gradient, variable) pairs for the wrapped function."""
    value_and_grads_fn = implicit_val_and_grad(f)
    value_and_grads = value_and_grads_fn(*args, **kwds)
    return value_and_grads[1]

  return grad_fn
306
307
def _get_arg_spec(f, params, param_args):
  """Returns the positions in `param_args` of the parameters to differentiate.

  Args:
    f: the function (or callable) being differentiated.
    params: None to differentiate w.r.t. every positional parameter, a list of
      parameter names of `f`, or a list of integer positions.
    param_args: the positional arguments `f` will be called with.

  Returns:
    A range or list of integer indices into `param_args`.

  Raises:
    ValueError: if `params` mixes strings and integers, names a parameter that
      `f` does not have, or uses names when `f`'s signature cannot be
      inspected.
  """
  try:
    args = tf_inspect.getfullargspec(f).args
  except TypeError as e:
    # TypeError can happen when f is a callable object.
    if params is None:
      return range(len(param_args))
    elif all(isinstance(x, int) for x in params):
      return params
    raise ValueError("Either callable provided is not a function or could not "
                     "inspect its arguments by name: %s. Original error: %s"
                     % (f, e))
  # A leading `self` (e.g. from a bound method) is not part of `param_args`,
  # so it is dropped before mapping parameter names to positions. Without
  # this, named params on methods would be off by one.
  if args and args[0] == "self":
    arg_names = args[1:]
  else:
    arg_names = args
  if params is None:
    if not args:
      return range(len(param_args))
    return range(len(arg_names))
  elif all(isinstance(x, six.string_types) for x in params):
    positions = []
    for name in params:
      if name not in arg_names:
        raise ValueError(
            "%r is not a parameter of %s; its parameters are %s." %
            (name, f, arg_names))
      positions.append(arg_names.index(name))
    return positions
  elif all(isinstance(x, int) for x in params):
    return params
  else:
    raise ValueError(
        "params must be all strings or all integers; got %s." % params)
335
336
def gradients_function(f, params=None):
  """Returns a function which differentiates f with respect to params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  # The 2nd order derivatives with respect to x is:
  #   d^2 f / (dx)^2 = 6 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns 1st order gradients.
  grad_fn = tfe.gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the 1st order gradient function.
  x_grad, y_grad = grad_fn(x, y)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # Obtain a function that returns the 2nd order gradient with respect to x.
  gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0])

  # Invoke the 2nd order gradient function.
  x_gradgrad = gradgrad_fn(x, y)[0]
  assert x_gradgrad.numpy() == 6 * 2 * 3

  # To obtain a callable that returns the gradient(s) of `f` with respect to a
  # subset of its inputs, use the `params` keyword argument with
  # `gradients_function()`.
  ygrad_fn = tfe.gradients_function(f, params=[1])

  (y_grad,) = ygrad_fn(x, y)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Only tensors with real or complex dtypes are differentiable.

  Args:
    f: function to be differentiated. A scalar result is differentiated
      directly; a tensor or list of tensors is by default summed into a single
      scalar. Pass tensors as the `dy` keyword argument of the returned
      function to weight the outputs element-wise instead.
    params: list of parameter names of f, or list of integers indexing its
      parameters, with respect to which to differentiate. None means all
      parameters.

  Returns:
    A function which, when called, returns the gradient of `f` with respect to
    all of `params` (the value of `f` is discarded). It accepts an extra
    optional keyword argument `dy` for computing vector-Jacobian products with
    vectors other than all-ones.

  Raises:
    ValueError: if the params are not all strings or all integers.
  """

  def decorated(*args, **kwds):
    """Returns only the gradients computed by `val_and_grad_function`."""
    value_and_grad_fn = val_and_grad_function(f, params=params)
    unused_value, gradients = value_and_grad_fn(*args, **kwds)
    return gradients

  return decorated
408
409
def _ensure_unique_tensor_objects(parameter_positions, args):
  """Makes every differentiated argument in `args` a distinct tensor object.

  When the same tensor object is passed for several differentiated parameters,
  each repeat after the first is replaced by an identity copy, so that every
  parameter gets its own gradient. For example, with

  def f(x, y): return x * y
  g = gradients_function(f)
  one = tf.constant(1.)

  g(one, one) should return [1., 1.] even though both arguments are the same
  Tensor object.

  Args:
    parameter_positions: indices into `args` of the arguments being
      differentiated.
    args: list of arguments to the function being differentiated.

  Returns:
    `args` itself, modified in place.
  """
  seen_ids = set()
  for position, arg in enumerate(args):
    if position not in parameter_positions:
      continue
    arg_id = ops.tensor_id(arg)
    if arg_id not in seen_ids:
      seen_ids.add(arg_id)
      continue
    args[position] = gen_array_ops.identity(arg)
  return args
440
441
def val_and_grad_function(f, params=None):
  """Returns a function that computes f and its derivative w.r.t. params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns the function value and the 1st order
  # gradients.
  val_grads_fn = tfe.value_and_gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the value-and-gradients function.
  f_val, (x_grad, y_grad) = val_grads_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # To obtain a callable that returns the value of `f` and the gradient(s) of
  # `f` with respect to a subset of its inputs, use the `params` keyword
  # argument with `value_and_gradients_function()`.
  val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1])

  f_val, (y_grad,) = val_ygrad_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Args:
    f: function to be differentiated. A scalar result is differentiated
      directly; a tensor or list of tensors is by default summed into a single
      scalar. Pass tensors as the `dy` keyword argument of the returned
      function to weight the outputs element-wise instead.
    params: list of parameter names of f, or list of integers indexing its
      parameters, with respect to which to differentiate. `None` means all
      parameters.

  Returns:
    A function which, when called, returns a pair (value of f, gradient of f
    with respect to all of `params`). It accepts one optional keyword argument,
    "dy", for computing vector-Jacobian products with vectors other than
    all-ones; any other keyword argument raises ValueError.

  Raises:
    ValueError: if the params are not all strings or all integers.
  """

  def decorated(*args, **kwds):
    """Returns (f(*args), gradients), honoring an optional `dy` keyword."""
    output_gradients = kwds.pop("dy", None)
    if kwds:
      raise ValueError("Functions to be differentiated cannot "
                       "receive keyword arguments.")
    value_and_vjp_fn = make_vjp(f, params)
    value, vjp = value_and_vjp_fn(*args)
    return value, vjp(dy=output_gradients)

  return decorated
508
509
def make_vjp(f, params=None, persistent=True):
  """Returns a function that computes f and its vjp w.r.t. params.

  The term "vjp" here is an abbreviation for vector-jacobian product.

  Args:
    f: the function to be differentiated.
    params: the parameters (numbers or names) to differentiate with respect to.
      A value of None will differentiate with respect to all parameters.
    persistent: Boolean controlling whether the VJP function can be re-used.
      Must be True or False.

  Returns:
    A function, which when called, returns a tuple (value, vjp), where:
    - value is the result of calling f.
    - vjp is a function, which takes a vector as an argument and
      returns the product of that vector with the Jacobian of f.
      Providing no argument to vjp is equivalent to providing a
      vector of ones.

    For example,
    ```python
    def f(x):
      return x * x

    wrapped_fn = tfe.make_vjp(f)
    result, vjp = wrapped_fn(tf.constant(3.0))
    # result is 9.0
    vjp()  # the vjp function returns 6.0
    ```

  Raises:
    ValueError: (raised by the returned function) if `f` returns None, or if a
      differentiated argument is a packed EagerTensor.
  """

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    parameter_positions = _get_arg_spec(f, params, args)
    assert not kwds, "The gradient function can't take keyword arguments."
    this_tape = tape.push_new_tape(persistent=persistent)
    try:
      sources = []
      args = [
          ops.convert_to_tensor(arg) if i in parameter_positions else arg
          for i, arg in enumerate(args)
      ]
      # Copy repeated tensor objects so each parameter gets its own gradient.
      args = _ensure_unique_tensor_objects(parameter_positions, args)
      for i in parameter_positions:
        if getattr(args[i], "is_packed", False):
          raise ValueError(
              "GradientTape.gradient is not supported on packed EagerTensors "
              "yet.")
        sources.append(args[i])
        tape.watch(this_tape, args[i])
      result = f(*args)
      if result is None:
        # Not every callable has `__name__` (e.g. functools.partial objects);
        # fall back to `f` itself so the intended ValueError is still raised.
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             getattr(f, "__name__", f)))
      # Route every output through identity while the tape is still active.
      # NOTE(review): presumably so the targets are distinct tensors recorded
      # on this tape; confirm.
      flat_result = nest.flatten(result)
      flat_result = [gen_array_ops.identity(x) for x in flat_result]
      result = nest.pack_sequence_as(result, flat_result)
    finally:
      tape.pop_tape(this_tape)

    def vjp(dy=None):
      """Returns the product of `dy` (all-ones if None) with f's Jacobian."""
      if dy is not None:
        dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
      return imperative_grad.imperative_grad(
          this_tape, nest.flatten(result), sources, output_gradients=dy)

    return result, vjp

  return decorated
584
585
586def flatten_nested_indexed_slices(grad):
587  assert isinstance(grad, ops.IndexedSlices)
588  if isinstance(grad.values, ops.Tensor):
589    return grad
590  else:
591    assert isinstance(grad.values, ops.IndexedSlices)
592    g = flatten_nested_indexed_slices(grad.values)
593    return ops.IndexedSlices(g.values, array_ops.gather(grad.indices,
594                                                        g.indices),
595                             g.dense_shape)
596
597
def aggregate_indexed_slices_gradients(grads):
  """Aggregates gradients containing `IndexedSlices`s.

  Args:
    grads: a list of `Tensor`, `IndexedSlices`, or None gradients.

  Returns:
    None if `grads` is empty or every entry is None; the sole element if
    `grads` has length one; a dense `add_n` sum if any non-None gradient is a
    `Tensor`; otherwise an `IndexedSlices` whose values and indices are the
    concatenations of the (flattened) inputs' values and indices.
  """
  if not grads:
    return None
  if len(grads) == 1:
    return grads[0]
  grads = [g for g in grads if g is not None]
  # Every entry may have been None; there is then nothing to aggregate, and
  # the `grads[0]` access below would raise IndexError.
  if not grads:
    return None
  # If any gradient is a `Tensor`, sum them up and return a dense tensor
  # object.
  if any(isinstance(g, ops.Tensor) for g in grads):
    return math_ops.add_n(grads)

  # The following `_as_indexed_slices_list` casts ids of IndexedSlices into
  # int64. It is to make sure the inputs of `concat` all have the same data
  # type.
  grads = math_ops._as_indexed_slices_list(grads)  # pylint: disable=protected-access

  grads = [flatten_nested_indexed_slices(x) for x in grads]
  # Form IndexedSlices out of the concatenated values and indices.
  concat_grad = ops.IndexedSlices(
      array_ops.concat([x.values for x in grads], axis=0),
      array_ops.concat([x.indices for x in grads], axis=0),
      grads[0].dense_shape)

  return concat_grad
623
624
def _aggregate_grads(gradients):
  """Combines the gradients that several consumers sent to one tensor.

  Args:
    gradients: a non-empty list of `Tensor` and/or `IndexedSlices` gradients.

  Returns:
    The single gradient if only one is given. Otherwise, an `add_n` sum when
    all gradients are `Tensor`s, or the result of
    `aggregate_indexed_slices_gradients` when any is an `IndexedSlices`.
  """
  assert gradients, "No gradients to aggregate"

  if len(gradients) == 1:
    return gradients[0]
  all_dense = all(isinstance(g, ops.Tensor) for g in gradients)
  if all_dense:
    return gen_math_ops.add_n(gradients)
  assert all(isinstance(g, (ops.Tensor, ops.IndexedSlices))
             for g in gradients)
  return aggregate_indexed_slices_gradients(gradients)
645
646
def _num_elements(grad):
  """Counts the elements of a gradient, treating unknown shapes as empty.

  Args:
    grad: a `Tensor`, or an `IndexedSlices` (whose `values` shape is used).

  Returns:
    The product of the static dimensions, or 0 if the rank or any dimension
    is unknown (None).

  Raises:
    ValueError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """
  if isinstance(grad, ops.Tensor):
    dims = grad._shape_tuple()  # pylint: disable=protected-access
  elif isinstance(grad, ops.IndexedSlices):
    dims = grad.values._shape_tuple()  # pylint: disable=protected-access
  else:
    raise ValueError("`grad` not a Tensor or IndexedSlices.")
  if dims is None or None in dims:
    return 0
  count = 1
  for dim in dims:
    count *= dim
  return count
658
659
def _fast_fill(value, shape, dtype):
  """Returns `array_ops.fill` of `shape` with `value` cast to `dtype`.

  The shape is materialized as an int32 constant and the fill value as a
  constant of `dtype`.
  """
  dims = constant_op.constant(shape, dtype=dtypes.int32)
  fill_value = constant_op.constant(value, dtype=dtype)
  return array_ops.fill(dims, fill_value)
664
665
def _zeros(shape, dtype):
  """Returns a zero-filled tensor, reusing a per-context cache when eager.

  Args:
    shape: the desired shape (a Python shape or a tensor).
    dtype: the desired dtype.

  Returns:
    None for string and resource dtypes. Outside eager execution, a fresh
    `array_ops.zeros`. In eager mode, a tensor cached under
    (shape, dtype, device) in the context's zeros cache; it is created with
    `_fast_fill` on a cache miss (filled with False for bool dtypes, 0
    otherwise).
  """
  # Variant dtypes are handled by _zeros_like rather than here.
  if dtype == dtypes.string or dtype == dtypes.resource:
    return None

  ctx = context.context()
  if not ctx.executing_eagerly():
    return array_ops.zeros(shape, dtype)

  # Tensor-valued shapes are keyed through `shape.ref()`.
  shape_key = shape.ref() if tensor_util.is_tf_type(shape) else shape
  cache_key = (shape_key, dtype, ctx.device_name)
  zeros_cache = ctx.zeros_cache()
  zeros = zeros_cache.get(cache_key)
  if zeros is None:
    fill_value = False if dtypes.as_dtype(dtype).is_bool else 0
    zeros = _fast_fill(fill_value, shape, dtype)
    zeros_cache.put(cache_key, zeros)
  return zeros
692
693
def _ones(shape, dtype):
  """Returns a tensor of ones (True for bool dtypes) of `shape` and `dtype`.

  Returns None for string dtypes. Outside eager execution this is
  `array_ops.ones`; in eager mode a scalar shape `()` yields a constant and
  any other shape goes through `_fast_fill`.
  """
  resolved_dtype = dtypes.as_dtype(dtype)
  if resolved_dtype == dtypes.string:
    return None

  if not context.executing_eagerly():
    return array_ops.ones(shape, dtype)

  fill_value = True if resolved_dtype.is_bool else 1
  if shape == ():  # pylint: disable=g-explicit-bool-comparison
    return constant_op.constant(fill_value, dtype=dtype)
  return _fast_fill(fill_value, shape, dtype)
710
711
# Default vector-space callbacks (element counting, gradient aggregation,
# zeros/ones construction and graph shape lookup), registered with the
# pywrap_tfe layer at import time for use by the gradient tape machinery.
_default_vspace = imperative_grad.VSpace(
    num_elements_fn=_num_elements,
    aggregate_fn=_aggregate_grads,
    zeros_fn=_zeros,
    ones_fn=_ones,
    zeros_like_fn=default_gradient.zeros_like,
    ones_like_fn=default_gradient.ones_like,
    graph_shape_fn=gen_array_ops.shape)
pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
721
722
def _handle_or_self(x):
  """Returns `x.handle` if `x` is a resource variable, else `x` unchanged."""
  return x.handle if resource_variable_ops.is_resource_variable(x) else x
728
729
730@tf_export("GradientTape", "autodiff.GradientTape", v1=["GradientTape"])
731class GradientTape(object):
732  """Record operations for automatic differentiation.
733
734  Operations are recorded if they are executed within this context manager and
735  at least one of their inputs is being "watched".
736
737  Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`,
738  where `trainable=True` is default in both cases) are automatically watched.
739  Tensors can be manually watched by invoking the `watch` method on this context
740  manager.
741
742  For example, consider the function `y = x * x`. The gradient at `x = 3.0` can
743  be computed as:
744
745  >>> x = tf.constant(3.0)
746  >>> with tf.GradientTape() as g:
747  ...   g.watch(x)
748  ...   y = x * x
749  >>> dy_dx = g.gradient(y, x)
750  >>> print(dy_dx)
751  tf.Tensor(6.0, shape=(), dtype=float32)
752
753  GradientTapes can be nested to compute higher-order derivatives. For example,
754
755  >>> x = tf.constant(5.0)
756  >>> with tf.GradientTape() as g:
757  ...   g.watch(x)
758  ...   with tf.GradientTape() as gg:
759  ...     gg.watch(x)
760  ...     y = x * x
761  ...   dy_dx = gg.gradient(y, x)  # dy_dx = 2 * x
762  >>> d2y_dx2 = g.gradient(dy_dx, x)  # d2y_dx2 = 2
763  >>> print(dy_dx)
764  tf.Tensor(10.0, shape=(), dtype=float32)
765  >>> print(d2y_dx2)
766  tf.Tensor(2.0, shape=(), dtype=float32)
767
768  By default, the resources held by a GradientTape are released as soon as
769  GradientTape.gradient() method is called. To compute multiple gradients over
770  the same computation, create a persistent gradient tape. This allows multiple
771  calls to the gradient() method as resources are released when the tape object
772  is garbage collected. For example:
773
774  >>> x = tf.constant(3.0)
775  >>> with tf.GradientTape(persistent=True) as g:
776  ...   g.watch(x)
777  ...   y = x * x
778  ...   z = y * y
779  >>> dz_dx = g.gradient(z, x)  # (4*x^3 at x = 3)
780  >>> print(dz_dx)
781  tf.Tensor(108.0, shape=(), dtype=float32)
782  >>> dy_dx = g.gradient(y, x)
783  >>> print(dy_dx)
784  tf.Tensor(6.0, shape=(), dtype=float32)
785
786  By default GradientTape will automatically watch any trainable variables that
787  are accessed inside the context. If you want fine grained control over which
788  variables are watched you can disable automatic tracking by passing
789  `watch_accessed_variables=False` to the tape constructor:
790
791  >>> x = tf.Variable(2.0)
792  >>> w = tf.Variable(5.0)
793  >>> with tf.GradientTape(
794  ...     watch_accessed_variables=False, persistent=True) as tape:
795  ...   tape.watch(x)
796  ...   y = x ** 2  # Gradients will be available for `x`.
797  ...   z = w ** 3  # No gradients will be available as `w` isn't being watched.
798  >>> dy_dx = tape.gradient(y, x)
799  >>> print(dy_dx)
800  tf.Tensor(4.0, shape=(), dtype=float32)
801  >>> # No gradients will be available as `w` isn't being watched.
802  >>> dz_dy = tape.gradient(z, w)
803  >>> print(dz_dy)
804  None
805
806  Note that when using models you should ensure that your variables exist when
807  using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
808  first iteration not have any gradients:
809
810  ```python
811  a = tf.keras.layers.Dense(32)
812  b = tf.keras.layers.Dense(32)
813
814  with tf.GradientTape(watch_accessed_variables=False) as tape:
815    tape.watch(a.variables)  # Since `a.build` has not been called at this point
816                             # `a.variables` will return an empty list and the
817                             # tape will not be watching anything.
818    result = b(a(inputs))
819    tape.gradient(result, a.variables)  # The result of this computation will be
820                                        # a list of `None`s since a's variables
821                                        # are not being watched.
822  ```
823
824  Note that only tensors with real or complex dtypes are differentiable.
825  """
826
827  def __init__(self, persistent=False, watch_accessed_variables=True):
828    """Creates a new GradientTape.
829
830    Args:
831      persistent: Boolean controlling whether a persistent gradient tape
832        is created. False by default, which means at most one call can
833        be made to the gradient() method on this object.
834      watch_accessed_variables: Boolean controlling whether the tape will
835        automatically `watch` any (trainable) variables accessed while the tape
836        is active. Defaults to True meaning gradients can be requested from any
837        result computed in the tape derived from reading a trainable `Variable`.
838        If False users must explicitly `watch` any `Variable`s they want to
839        request gradients from.
840    """
841    self._tape = None
842    self._persistent = persistent
843    self._watch_accessed_variables = watch_accessed_variables
844    self._watched_variables = ()
845    self._recording = False
846
847  def __enter__(self):
848    """Enters a context inside which operations are recorded on this tape."""
849    self._push_tape()
850    return self
851
852  def __exit__(self, typ, value, traceback):
853    """Exits the recording context, no further operations are traced."""
854    if self._recording:
855      self._pop_tape()
856
857  def _push_tape(self):
858    """Pushes a new tape onto the tape stack."""
859    if self._recording:
860      raise ValueError("Tape is still recording, This can happen if you try to "
861                       "re-enter an already-active tape.")
862    if self._tape is None:
863      self._tape = tape.push_new_tape(
864          persistent=self._persistent,
865          watch_accessed_variables=self._watch_accessed_variables)
866    else:
867      tape.push_tape(self._tape)
868    self._recording = True
869
870  def _pop_tape(self):
871    if not self._recording:
872      raise ValueError("Tape is not recording.")
873    tape.pop_tape(self._tape)
874    self._recording = False
875
876  @tf_contextlib.contextmanager
877  def _ensure_recording(self):
878    """Ensures that this tape is recording."""
879    if not self._recording:
880      try:
881        self._push_tape()
882        yield
883      finally:
884        self._pop_tape()
885    else:
886      yield
887
888  def watch(self, tensor):
889    """Ensures that `tensor` is being traced by this tape.
890
891    Args:
892      tensor: a Tensor or list of Tensors.
893
894    Raises:
895      ValueError: if it encounters something that is not a tensor.
896    """
897    for t in nest.flatten(tensor, expand_composites=True):
898      if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
899        raise ValueError("Passed in object of type {}, not tf.Tensor".format(
900            type(t)))
901      if not backprop_util.IsTrainable(t):
902        logging.log_first_n(
903            logging.WARN, "The dtype of the watched tensor must be "
904            "floating (e.g. tf.float32), got %r", 5, t.dtype)
905      if hasattr(t, "handle"):
906        # There are many variable-like objects, all of them currently have
907        # `handle` attribute that points to a tensor. If this changes, internals
908        # of watch_variable need to change as well.
909        tape.watch_variable(self._tape, t)
910      else:
911        tape.watch(self._tape, t)
912
913  @tf_contextlib.contextmanager
914  def stop_recording(self):
915    """Temporarily stops recording operations on this tape.
916
917    Operations executed while this context manager is active will not be
918    recorded on the tape. This is useful for reducing the memory used by tracing
919    all computations.
920
921    For example:
922
923    >>> x = tf.constant(4.0)
924    >>> with tf.GradientTape() as tape:
925    ...   with tape.stop_recording():
926    ...     y = x ** 2
927    >>> dy_dx = tape.gradient(y, x)
928    >>> print(dy_dx)
929    None
930
931    Yields:
932      None
933    Raises:
934      RuntimeError: if the tape is not currently recording.
935    """
936    if self._tape is None:
937      raise RuntimeError(
938          "Trying to stop recording a tape which is not recording.")
939    self._pop_tape()
940    try:
941      yield
942    finally:
943      self._push_tape()
944
945  def reset(self):
946    """Clears all information stored in this tape.
947
948    Equivalent to exiting and reentering the tape context manager with a new
949    tape. For example, the two following code blocks are equivalent:
950
951    ```
952    with tf.GradientTape() as t:
953      loss = loss_fn()
954    with tf.GradientTape() as t:
955      loss += other_loss_fn()
956    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
957
958
959    # The following is equivalent to the above
960    with tf.GradientTape() as t:
961      loss = loss_fn()
962      t.reset()
963      loss += other_loss_fn()
964    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
965    ```
966
967    This is useful if you don't want to exit the context manager for the tape,
968    or can't because the desired reset point is inside a control flow construct:
969
970    ```
971    with tf.GradientTape() as t:
972      loss = ...
973      if loss > k:
974        t.reset()
975    ```
976    """
977    self._pop_tape()
978    self._tape = None
979    self._push_tape()
980
981  def watched_variables(self):
982    """Returns variables watched by this tape in order of construction."""
983    if self._tape is not None:
984      self._watched_variables = self._tape.watched_variables()
985    return self._watched_variables
986
987  def gradient(self,
988               target,
989               sources,
990               output_gradients=None,
991               unconnected_gradients=UnconnectedGradients.NONE):
992    """Computes the gradient using operations recorded in context of this tape.
993
994    Note: Unless you set `persistent=True` a GradientTape can only be used to
995    compute one set of gradients (or jacobians).
996
997    Args:
998      target: a list or nested structure of Tensors or Variables to be
999        differentiated.
1000      sources: a list or nested structure of Tensors or Variables. `target`
1001        will be differentiated against elements in `sources`.
1002      output_gradients: a list of gradients, one for each element of
1003        target. Defaults to None.
1004      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1005        alters the value which will be returned if the target and sources are
1006        unconnected. The possible values and effects are detailed in
1007        'UnconnectedGradients' and it defaults to 'none'.
1008
1009    Returns:
1010      a list or nested structure of Tensors (or IndexedSlices, or None),
1011      one for each element in `sources`. Returned structure is the same as
1012      the structure of `sources`.
1013
1014    Raises:
1015      RuntimeError: If called on a used, non-persistent tape.
1016      RuntimeError: If called inside the context of the tape.
1017      TypeError: If the target is a None object.
1018      ValueError: If the target is a variable or if unconnected gradients is
1019       called with an unknown value.
1020    """
1021    if self._tape is None:
1022      raise RuntimeError("A non-persistent GradientTape can only be used to "
1023                         "compute one set of gradients (or jacobians)")
1024    if self._recording:
1025      if not self._persistent:
1026        self._pop_tape()
1027      else:
1028        logging.log_first_n(
1029            logging.WARN, "Calling GradientTape.gradient on a persistent "
1030            "tape inside its context is significantly less "
1031            "efficient than calling it outside the context (it "
1032            "causes the gradient ops to be recorded on the "
1033            "tape, leading to increased CPU and memory usage). "
1034            "Only call GradientTape.gradient inside the "
1035            "context if you actually want to trace the "
1036            "gradient in order to compute higher order "
1037            "derivatives.", 1)
1038
1039    if target is None:
1040      raise TypeError("Target should be a list or nested structure"
1041                      " of Tensors or Variables to be differentiated,"
1042                      " but recieved %r" % (target))
1043
1044    flat_targets = []
1045    for t in nest.flatten(target):
1046      if not backprop_util.IsTrainable(t):
1047        logging.vlog(
1048            logging.WARN, "The dtype of the target tensor must be "
1049            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
1050            "got %r", t.dtype)
1051      if resource_variable_ops.is_resource_variable(t):
1052        with self:
1053          t = ops.convert_to_tensor(t)
1054      flat_targets.append(t)
1055
1056    flat_sources = nest.flatten(sources)
1057    flat_sources_raw = flat_sources
1058    flat_sources = [_handle_or_self(x) for x in flat_sources]
1059    for t in flat_sources_raw:
1060      if not backprop_util.IsTrainable(t):
1061        logging.vlog(
1062            logging.WARN, "The dtype of the source tensor must be "
1063            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
1064            "got %r", t.dtype)
1065      if getattr(t, "is_packed", False):
1066        raise ValueError(
1067            "GradientTape.gradient is not supported on packed EagerTensors yet."
1068        )
1069
1070    if output_gradients is not None:
1071      output_gradients = [None if x is None else ops.convert_to_tensor(x)
1072                          for x in nest.flatten(output_gradients)]
1073
1074    flat_grad = imperative_grad.imperative_grad(
1075        self._tape,
1076        flat_targets,
1077        flat_sources,
1078        output_gradients=output_gradients,
1079        sources_raw=flat_sources_raw,
1080        unconnected_gradients=unconnected_gradients)
1081
1082    if not self._persistent:
1083      # Keep track of watched variables before setting tape to None
1084      self._watched_variables = self._tape.watched_variables()
1085      self._tape = None
1086
1087    grad = nest.pack_sequence_as(sources, flat_grad)
1088    return grad
1089
1090  def jacobian(self,
1091               target,
1092               sources,
1093               unconnected_gradients=UnconnectedGradients.NONE,
1094               parallel_iterations=None,
1095               experimental_use_pfor=True):
1096    """Computes the jacobian using operations recorded in context of this tape.
1097
1098    Note: Unless you set `persistent=True` a GradientTape can only be used to
1099    compute one set of gradients (or jacobians).
1100
1101    Note: By default the jacobian implementation uses parallel for (pfor), which
1102    creates a tf.function under the hood for each jacobian call. For better
1103    performance, and to avoid recompilation and vectorization rewrites on each
1104    call, enclose GradientTape code in @tf.function.
1105
1106    See[wikipedia
1107    article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant)
1108    for the definition of a Jacobian.
1109
1110    Example usage:
1111
1112    ```python
1113    with tf.GradientTape() as g:
1114      x  = tf.constant([1.0, 2.0])
1115      g.watch(x)
1116      y = x * x
1117    jacobian = g.jacobian(y, x)
1118    # jacobian value is [[2., 0.], [0., 4.]]
1119    ```
1120
1121    Args:
1122      target: Tensor to be differentiated.
1123      sources: a list or nested structure of Tensors or Variables. `target`
1124        will be differentiated against elements in `sources`.
1125      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1126        alters the value which will be returned if the target and sources are
1127        unconnected. The possible values and effects are detailed in
1128        'UnconnectedGradients' and it defaults to 'none'.
1129      parallel_iterations: A knob to control how many iterations are dispatched
1130        in parallel. This knob can be used to control the total memory usage.
1131      experimental_use_pfor: If true, vectorizes the jacobian computation. Else
1132        falls back to a sequential while_loop. Vectorization can sometimes fail
1133        or lead to excessive memory usage. This option can be used to disable
1134        vectorization in such cases.
1135
1136    Returns:
1137      A list or nested structure of Tensors (or None), one for each element in
1138      `sources`. Returned structure is the same as the structure of `sources`.
1139      Note if any gradient is sparse (IndexedSlices), jacobian function
1140      currently makes it dense and returns a Tensor instead. This may change in
1141      the future.
1142
1143
1144    Raises:
1145      RuntimeError: If called on a used, non-persistent tape.
1146      RuntimeError: If called on a non-persistent tape with eager execution
1147        enabled and without enabling experimental_use_pfor.
1148      ValueError: If vectorization of jacobian computation fails.
1149    """
1150    if self._tape is None:
1151      raise RuntimeError("A non-persistent GradientTape can only be used to "
1152                         "compute one set of gradients (or jacobians)")
1153
1154    flat_sources = nest.flatten(sources)
1155    target_static_shape = target.shape
1156    target_shape = array_ops.shape(target)
1157    # Note that we push and pop the tape here and below. This is needed since we
1158    # need gradients through the enclosed operations.
1159    with self._ensure_recording():
1160      target = array_ops.reshape(target, [-1])
1161
1162    def loop_fn(i):
1163      with self._ensure_recording():
1164        y = array_ops.gather(target, i)
1165      return self.gradient(y, flat_sources,
1166                           unconnected_gradients=unconnected_gradients)
1167
1168    try:
1169      target_size = int(target.shape[0])
1170    except TypeError:
1171      target_size = array_ops.shape(target)[0]
1172
1173    if experimental_use_pfor:
1174      try:
1175        output = pfor_ops.pfor(loop_fn, target_size,
1176                               parallel_iterations=parallel_iterations)
1177      except ValueError as err:
1178        six.reraise(
1179            ValueError,
1180            ValueError(
1181                str(err) + "\nEncountered an exception while vectorizing the "
1182                "jacobian computation. Vectorization can be disabled by setting"
1183                " experimental_use_pfor to False."),
1184            sys.exc_info()[2])
1185    else:
1186      if context.executing_eagerly() and not self._persistent:
1187        raise RuntimeError(
1188            "GradientTape must be created with persistent=True"
1189            " to compute the jacobian with eager execution enabled and with "
1190            " experimental_use_pfor set to False.")
1191      output = pfor_ops.for_loop(
1192          loop_fn, [target.dtype] * len(flat_sources), target_size,
1193          parallel_iterations=parallel_iterations)
1194
1195    for i, out in enumerate(output):
1196      if out is not None:
1197        new_shape = array_ops.concat(
1198            [target_shape, array_ops.shape(out)[1:]], axis=0)
1199        out = array_ops.reshape(out, new_shape)
1200        if context.executing_eagerly():
1201          out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
1202      output[i] = out
1203
1204    return nest.pack_sequence_as(sources, output)
1205
1206  def batch_jacobian(self,
1207                     target,
1208                     source,
1209                     unconnected_gradients=UnconnectedGradients.NONE,
1210                     parallel_iterations=None,
1211                     experimental_use_pfor=True):
1212    """Computes and stacks per-example jacobians.
1213
1214    See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant)
1215    for the definition of a Jacobian. This function is essentially an efficient
1216    implementation of the following:
1217
1218    `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`.
1219
1220    Note that compared to `GradientTape.jacobian` which computes gradient of
1221    each output value w.r.t each input value, this function is useful when
1222    `target[i,...]` is independent of `source[j,...]` for `j != i`. This
1223    assumption allows more efficient computation as compared to
1224    `GradientTape.jacobian`. The output, as well as intermediate activations,
1225    are lower dimensional and avoid a bunch of redundant zeros which would
1226    result in the jacobian computation given the independence assumption.
1227
1228    Note: Unless you set `persistent=True` a GradientTape can only be used to
1229    compute one set of gradients (or jacobians).
1230
1231    Note: By default the batch_jacobian implementation uses parallel for (pfor),
1232    which creates a tf.function under the hood for each batch_jacobian call.
1233    For better performance, and to avoid recompilation and vectorization
1234    rewrites on each call, enclose GradientTape code in @tf.function.
1235
1236
1237    Example usage:
1238
1239    ```python
1240    with tf.GradientTape() as g:
1241      x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)
1242      g.watch(x)
1243      y = x * x
1244    batch_jacobian = g.batch_jacobian(y, x)
1245    # batch_jacobian is [[[2,  0], [0,  4]], [[6,  0], [0,  8]]]
1246    ```
1247
1248    Args:
1249      target: A tensor with rank 2 or higher and with shape [b, y1, ..., y_n].
1250        `target[i,...]` should only depend on `source[i,...]`.
1251      source: A tensor with rank 2 or higher and with shape [b, x1, ..., x_m].
1252      unconnected_gradients: a value which can either hold 'none' or 'zero' and
1253        alters the value which will be returned if the target and sources are
1254        unconnected. The possible values and effects are detailed in
1255        'UnconnectedGradients' and it defaults to 'none'.
1256      parallel_iterations: A knob to control how many iterations are dispatched
1257        in parallel. This knob can be used to control the total memory usage.
1258      experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else
1259        uses a tf.while_loop.
1260
1261    Returns:
1262      A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]`
1263      is the jacobian of `target[i, ...]` w.r.t. `source[i, ...]`, i.e. stacked
1264      per-example jacobians.
1265
1266    Raises:
1267      RuntimeError: If called on a used, non-persistent tape.
1268      RuntimeError: If called on a non-persistent tape with eager execution
1269        enabled and without enabling experimental_use_pfor.
1270      ValueError: If vectorization of jacobian computation fails or if first
1271        dimension of `target` and `source` do not match.
1272    """
1273    if self._tape is None:
1274      raise RuntimeError("A non-persistent GradientTape can only be used to"
1275                         "compute one set of gradients (or jacobians)")
1276    target_shape = target.shape
1277    if target_shape.rank is None:
1278      dim = tensor_shape.Dimension(None)
1279    else:
1280      dim = target_shape.dims[0]
1281    if not (target_shape.with_rank_at_least(2) and
1282            source.shape.with_rank_at_least(2) and
1283            dim.is_compatible_with(source.shape[0])):
1284      raise ValueError(
1285          "Need first dimension of target shape (%s) and "
1286          "source shape (%s) to match." % (target.shape, source.shape))
1287    if target_shape.is_fully_defined():
1288      batch_size = int(target_shape[0])
1289      target_row_size = target_shape.num_elements() // batch_size
1290    else:
1291      target_shape = array_ops.shape(target)
1292      batch_size = target_shape[0]
1293      target_row_size = array_ops.size(target) // batch_size
1294    source_shape = array_ops.shape(source)
1295    # Flatten target to 2-D.
1296    # Note that we push and pop the tape here and below. This is needed since we
1297    # need gradients through the enclosed operations.
1298    with self._ensure_recording():
1299      with ops.control_dependencies(
1300          [check_ops.assert_equal(batch_size, source_shape[0])]):
1301        target = array_ops.reshape(target, [batch_size, target_row_size])
1302
1303    def loop_fn(i):
1304      with self._ensure_recording():
1305        y = array_ops.gather(target, i, axis=1)
1306      return self.gradient(y, source,
1307                           unconnected_gradients=unconnected_gradients)
1308
1309    if experimental_use_pfor:
1310      try:
1311        output = pfor_ops.pfor(loop_fn, target_row_size,
1312                               parallel_iterations=parallel_iterations)
1313      except ValueError as err:
1314        six.reraise(
1315            ValueError,
1316            ValueError(
1317                str(err) + "\nEncountered an exception while vectorizing the "
1318                "batch_jacobian computation. Vectorization can be disabled by "
1319                "setting experimental_use_pfor to False."),
1320            sys.exc_info()[2])
1321    else:
1322      if context.executing_eagerly() and not self._persistent:
1323        raise RuntimeError(
1324            "GradientTape must be created with persistent=True"
1325            " to compute the batch_jacobian with eager execution enabled and "
1326            " with experimental_use_pfor set to False.")
1327      output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size,
1328                                 parallel_iterations=parallel_iterations)
1329    new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
1330    if output is None:
1331      # Note that this block is returning zeros when it could use `None` to
1332      # represent unconnected gradients. This is to maintain compatibility with
1333      # the previous behavior, which ignored `unconnected_gradients`.
1334      output = array_ops.zeros(new_shape, target.dtype)
1335      return output
1336    else:
1337      output = array_ops.reshape(output,
1338                                 [target_row_size, batch_size, -1])
1339      output = array_ops.transpose(output, [1, 0, 2])
1340
1341      output = array_ops.reshape(output, new_shape)
1342      return output
1343