1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Gradient tape utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22
23from tensorflow.python import pywrap_tensorflow
24from tensorflow.python.util.lazy_loader import LazyLoader
25
# This module, ops.py and distribution_strategy_context depend on one another
# circularly, so distribution_strategy_context is brought in through a
# LazyLoader rather than a plain import.
# TODO(b/117329403): Remove this circular dependency.
distribution_strategy_context = LazyLoader(
    "distribution_strategy_context",
    globals(),
    "tensorflow.python.distribute.distribution_strategy_context")
33
34
class Tape(object):
  """Python-side wrapper around an opaque gradient tape handle.

  `push_new_tape` creates instances of this class. The wrapped handle is what
  gets passed to the `pywrap_tensorflow` tape functions.
  """

  def __init__(self, tape):
    # Opaque handle, e.g. the value returned by TFE_Py_TapeSetNew.
    self._tape = tape

  def watched_variables(self):
    """Returns what `TFE_Py_TapeWatchedVariables` reports for this tape.

    Presumably the variables this tape is watching; see the C API to confirm.
    """
    handle = self._tape
    return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(handle)
43
44
def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Creates a new tape, pushes it onto the tape stack and wraps it.

  Args:
    persistent: forwarded unchanged to `TFE_Py_TapeSetNew`.
    watch_accessed_variables: forwarded unchanged to `TFE_Py_TapeSetNew`.

  Returns:
    A `Tape` wrapping the handle returned by `TFE_Py_TapeSetNew`.
  """
  return Tape(
      pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
                                          watch_accessed_variables))
50
51
def push_tape(tape):
  """Pushes an already-created tape onto the tape stack.

  Args:
    tape: a `Tape`; its underlying handle is passed to `TFE_Py_TapeSetAdd`.
  """
  handle = tape._tape  # pylint: disable=protected-access
  pywrap_tensorflow.TFE_Py_TapeSetAdd(handle)
55
56
def watch(tape, tensor):
  """Asks the given tape to watch `tensor`.

  Args:
    tape: a `Tape` whose underlying handle should watch `tensor`.
    tensor: the tensor to watch; passed unchanged to `TFE_Py_TapeWatch`.
  """
  handle = tape._tape  # pylint: disable=protected-access
  pywrap_tensorflow.TFE_Py_TapeWatch(handle, tensor)
60
61
def _variables_to_track(strategy, replica_context, variable):
  """Maps `variable` to the objects the tape functions should receive.

  Shared by `watch_variable`, `variable_accessed` and `variables_accessed` so
  the distribution-strategy handling lives in one place.

  Args:
    strategy: the strategy from
      `distribution_strategy_context.get_strategy_and_replica_context()`.
    replica_context: the replica context from that same call.
    variable: the variable being watched or accessed.

  Returns:
    A one-element list holding `strategy.extended.value_container(variable)`
    when `replica_context` is truthy; otherwise whatever iterable
    `strategy.unwrap(variable)` returns.
  """
  if replica_context:
    # NOTE(review): presumably maps a per-replica value back to the container
    # variable that owns it -- confirm against the strategy API.
    return [strategy.extended.value_container(variable)]
  return strategy.unwrap(variable)


def watch_variable(tape, variable):
  """Marks this variable to be watched by the given tape.

  Args:
    tape: a `Tape` whose underlying handle should watch `variable`.
    variable: the variable to watch. Under a distribution strategy it is first
      mapped to the object(s) chosen by `_variables_to_track`.
  """
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  for var in _variables_to_track(strategy, context, variable):
    pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var)  # pylint: disable=protected-access


def variable_accessed(variable):
  """Notifies all tapes in the stack that a variable has been accessed.

  Unlike `variables_accessed`, no `trainable` filtering is applied here.

  Args:
    variable: the variable that was accessed.
  """
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  for var in _variables_to_track(strategy, context, variable):
    pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)


def variables_accessed(variables):
  """Notifies all tapes in the stack that variables have been accessed.

  Only trainable variables are marked as accessed.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  # Resolve every trainable variable first, then notify the tapes. This keeps
  # the original ordering of strategy calls relative to tape notifications.
  accessed = []
  for variable in variables:
    if variable.trainable:
      accessed.extend(_variables_to_track(strategy, context, variable))

  for var in accessed:
    pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
111
112
def pop_tape(tape):
  """Removes the given tape from the tape stack.

  The tape to remove is passed in explicitly and `tape._tape` is handed to
  `TFE_Py_TapeSetRemove`. This function does not check that `tape` is the top
  of the stack; any such requirement is up to the C implementation.

  Args:
    tape: a `Tape` previously pushed with `push_new_tape` or `push_tape`.
  """
  pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape)  # pylint: disable=protected-access
116
117
@contextlib.contextmanager
def stop_recording():
  """Context manager that pauses tape recording for the current thread.

  Calls `TFE_Py_TapeSetStopOnThread` on entry. On exit it calls
  `TFE_Py_TapeSetRestartOnThread`, even if the body raises.

  Yields:
    None.
  """
  # Stop before entering `try`: if stopping itself fails, nothing was paused,
  # so the restart in `finally` must not run.
  pywrap_tensorflow.TFE_Py_TapeSetStopOnThread()
  try:
    yield
  finally:
    # NOTE(review): recording is restarted unconditionally. If the C side keeps
    # a single flag rather than a depth counter, a nested `stop_recording`
    # re-enables recording while an outer one is still active -- confirm
    # whether nesting must be supported.
    pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread()
125
126
def should_record(tensors):
  """Reports whether any tape in the stack watches any of `tensors`.

  Args:
    tensors: passed unchanged to `TFE_Py_TapeSetShouldRecord`.

  Returns:
    The result of `TFE_Py_TapeSetShouldRecord(tensors)`.
  """
  decision = pywrap_tensorflow.TFE_Py_TapeSetShouldRecord(tensors)
  return decision
130
131
def record_operation(op_type, output_tensors, input_tensors, backward_function):
  """Records one operation on every tape in the stack.

  Args:
    op_type: the operation's type; forwarded unchanged.
    output_tensors: the operation's outputs; forwarded unchanged.
    input_tensors: the operation's inputs; forwarded unchanged.
    backward_function: the gradient function; forwarded unchanged.
  """
  record_args = (op_type, output_tensors, input_tensors, backward_function)
  pywrap_tensorflow.TFE_Py_TapeSetRecordOperation(*record_args)
136
137
def delete_trace(tensor_id):
  """Drops the traces of one tensor from every tape in the stack.

  Args:
    tensor_id: identifier of the tensor; passed to
      `TFE_Py_TapeSetDeleteTrace`.
  """
  delete = pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace
  delete(tensor_id)
141
142
def could_possibly_record():
  """Reports whether at least one tape is currently active.

  Returns:
    False when `TFE_Py_TapeSetIsEmpty()` is truthy, True otherwise.
  """
  tape_set_is_empty = pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
  return not tape_set_is_empty
146