1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradient tape utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22
23from tensorflow.python import pywrap_tfe
24from tensorflow.python.util.lazy_loader import LazyLoader
25
# This module, ops.py and distribution_strategy_context import each other, so
# distribution_strategy_context is loaded lazily on first attribute access.
# TODO(b/117329403): Remove this circular dependency.
distribution_strategy_context = LazyLoader(
    "distribution_strategy_context",
    globals(),
    "tensorflow.python.distribute.distribution_strategy_context",
)
33
34
class Tape(object):
  """Lightweight Python handle wrapping a native gradient tape object."""

  # Only the native handle is stored; no per-instance __dict__.
  __slots__ = ["_tape"]

  def __init__(self, tape):
    """Stores `tape`, typically the object from `TFE_Py_TapeSetNew`."""
    self._tape = tape

  def watched_variables(self):
    """Returns what `TFE_Py_TapeWatchedVariables` reports for this tape."""
    native_tape = self._tape
    return pywrap_tfe.TFE_Py_TapeWatchedVariables(native_tape)
45
46
def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Creates a new tape, pushes it onto the tape stack and returns it.

  Args:
    persistent: forwarded unchanged to `TFE_Py_TapeSetNew`.
    watch_accessed_variables: forwarded unchanged to `TFE_Py_TapeSetNew`.

  Returns:
    A `Tape` wrapping the newly created native tape.
  """
  native_tape = pywrap_tfe.TFE_Py_TapeSetNew(
      persistent, watch_accessed_variables)
  return Tape(native_tape)
51
52
def push_tape(tape):
  """Pushes an already existing `Tape` onto the tape stack.

  Args:
    tape: a `Tape`, for example one returned earlier by `push_new_tape`.
  """
  native_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetAdd(native_tape)
56
57
def watch(tape, tensor):
  """Tells `tape` to watch `tensor`.

  Args:
    tape: the `Tape` that should start watching.
    tensor: the tensor to mark as watched.
  """
  native_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeWatch(native_tape, tensor)
61
62
class VariableWatcher(object):
  """A scope that tracks all trainable variable accesses within it.

  This explicitly ignores variables that are not marked as trainable.

  Sample usage:

  var = tf.Variable(0.0)
  with VariableWatcher() as variable_watcher:
    var.assign_add(1.0)

  assert variable_watcher.watched_variables() == (var,)
  """

  __slots__ = ["_variable_watcher"]

  def __init__(self):
    # Native watcher handle; stays None until the scope is entered.
    self._variable_watcher = None

  def __enter__(self):
    """Registers a new native variable watcher and returns this scope."""
    self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
    return self

  def __exit__(self, typ, value, traceback):
    """Unregisters the native watcher; exceptions are not suppressed."""
    # The handle is deliberately kept so that watched_variables() can still be
    # queried after the `with` block, as in the sample usage above.
    pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)

  def watched_variables(self):
    """Returns a tuple of variables accessed under this scope."""
    return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
        self._variable_watcher)
93
94
def watch_variable(tape, variable):
  """Makes `tape` watch `variable`, resolving it for distribution first.

  In a replica context the variable is replaced by
  `strategy.extended.value_container(variable)`; otherwise every entry of
  `strategy.experimental_local_results(variable)` is watched. Each resolved
  variable is also reported to the active variable watchers.

  Args:
    tape: the `Tape` that should watch the variable.
    variable: the variable to watch.
  """
  strategy, replica_ctx = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if not replica_ctx:
    resolved_vars = strategy.experimental_local_results(variable)
  else:
    resolved_vars = [strategy.extended.value_container(variable)]
  for resolved in resolved_vars:
    pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, resolved)  # pylint: disable=protected-access
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)
106
107
def variable_accessed(variable):
  """Reports an access of `variable` to every tape and variable watcher.

  In a replica context the variable is first mapped through
  `strategy.extended.value_container`; otherwise each entry of
  `strategy.experimental_local_results(variable)` is reported.

  Args:
    variable: variable to be watched.
  """
  strategy, replica_ctx = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if not replica_ctx:
    resolved_vars = strategy.experimental_local_results(variable)
  else:
    resolved_vars = [strategy.extended.value_container(variable)]
  for resolved in resolved_vars:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(resolved)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)
123
124
def variables_accessed(variables):
  """Reports accesses of the trainable entries of `variables` to all tapes.

  Variables whose `trainable` attribute is falsy are skipped. The remaining
  ones are resolved for distribution (value container in a replica context,
  local results otherwise) and reported to every tape and variable watcher.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  strategy, replica_ctx = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if replica_ctx:
    resolved_vars = [
        strategy.extended.value_container(v) for v in variables if v.trainable
    ]
  else:
    resolved_vars = []
    for v in variables:
      if not v.trainable:
        continue
      resolved_vars.extend(strategy.experimental_local_results(v))

  for resolved in resolved_vars:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(resolved)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)
147
148
def pop_tape(tape):
  """Removes `tape` from the tape stack.

  Args:
    tape: the `Tape` to remove.
  """
  native_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetRemove(native_tape)
152
153
@contextlib.contextmanager
def stop_recording():
  """Pauses all gradient recording (backprop and forwardprop) in a scope.

  Recording is stopped via `TFE_Py_TapeSetStopOnThread` on entry and resumed
  via `TFE_Py_TapeSetRestartOnThread` on exit. If recording was already
  stopped when the scope was entered, neither call is made.
  """
  must_toggle = not pywrap_tfe.TFE_Py_TapeSetIsStopped()
  try:
    if must_toggle:
      pywrap_tfe.TFE_Py_TapeSetStopOnThread()
    yield
  finally:
    # Only restart what this scope itself stopped, so nested scopes do not
    # resume recording while an outer scope still expects it paused.
    if must_toggle:
      pywrap_tfe.TFE_Py_TapeSetRestartOnThread()
165
166
def should_record_backprop(tensors):
  """Checks whether any tape on the stack watches one of `tensors`.

  Only GradientTapes are consulted; forward accumulators are ignored.

  Args:
    tensors: Tensors to check, usually the inputs of an operation.

  Returns:
    Boolean indicating whether some tape watches any of `tensors`.
  """
  is_watched = pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)
  return is_watched
179
180
def record_operation(op_type, output_tensors, input_tensors, backward_function,
                     forward_function=None):
  """Records one operation on every tape in the stack.

  Args:
    op_type: string naming the operation type.
    output_tensors: tensors produced by the operation.
    input_tensors: tensors consumed by the operation.
    backward_function: callable producing input gradients from output
      gradients.
    forward_function: optional callable passed through unchanged to
      `TFE_Py_TapeSetRecordOperation` (None by default).
  """
  recorded = (op_type, output_tensors, input_tensors, backward_function,
              forward_function)
  pywrap_tfe.TFE_Py_TapeSetRecordOperation(*recorded)
187
188
def record_operation_backprop_only(op_type, output_tensors, input_tensors,
                                   backward_function):
  """Records one operation only on the backward tapes in the stack.

  Args:
    op_type: string naming the operation type.
    output_tensors: tensors produced by the operation.
    input_tensors: tensors consumed by the operation.
    backward_function: callable producing input gradients from output
      gradients.
  """
  pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(
      op_type,
      output_tensors,
      input_tensors,
      backward_function,
  )
195
196
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
                                      backward_function,
                                      forwardprop_output_indices):
  """Records one operation only on the forward accumulators in the stack.

  Args:
    op_type: string naming the operation type, used by the backprop code.
    output_tensors: list of Python Tensor objects produced by the operation.
    input_tensors: list of the Tensors consumed by the recorded operation.
    backward_function: callable that maps gradients of `output_tensors` to
      gradients of `input_tensors`. It is transposed automatically so that
      input gradients can be turned into output gradients.
    forwardprop_output_indices: positions of any `output_tensors` that hold
      JVPs, typically produced by TFE_Py_PackForwardGradients. May be None
      or an empty sequence when the operation emits no JVP outputs.
  """
  recorded = (op_type, output_tensors, input_tensors, backward_function,
              forwardprop_output_indices)
  pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(*recorded)
217
218
def delete_trace(tensor_id):
  """Drops the traces of the tensor identified by `tensor_id` from all tapes.

  Args:
    tensor_id: identifier of the tensor whose traces should be removed.
  """
  pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id)
222
223
def could_possibly_record():
  """Tells whether at least one tape is currently active.

  Returns:
    True when the tape set is non-empty, False otherwise.
  """
  tape_set_is_empty = pywrap_tfe.TFE_Py_TapeSetIsEmpty()
  return not tape_set_is_empty
227