1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Debugger: Tools for debugging gradients."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import re
22import uuid
23
24import six
25
26from tensorflow.python.debug.lib import debug_data
27from tensorflow.python.debug.lib import debug_graphs
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import gen_array_ops
30from tensorflow.python.ops import variables
31
# Name component placed directly in front of a GradientsDebugger's UUID in the
# names of the debug-gradient ops that debugger creates.
_GRADIENT_DEBUG_TAG = "gradient_debug_"

# Global registry: GradientsDebugger UUID (hex string) -> GradientsDebugger.
_gradient_debuggers = dict()
35
36
def _tensor_to_grad_debug_op_name(tensor, grad_debugger_uuid):
  """Build the name of the debug-gradient op for `tensor`.

  The resulting name has the form
  "<op_name>_<slot>/<_GRADIENT_DEBUG_TAG><grad_debugger_uuid>": the producing
  op's name and the tensor's output slot form the component before the "/",
  and the tag plus the debugger's UUID form the last component.

  Args:
    tensor: the `tf.Tensor` whose gradient is to be debugged.
    grad_debugger_uuid: the UUID of the owning `GradientsDebugger`.

  Returns:
    The debug op name as a `str`.
  """
  base_name, output_slot = debug_graphs.parse_node_or_tensor_name(tensor.name)
  parent_component = "%s_%d" % (base_name, output_slot)
  return "%s/%s%s" % (parent_component, _GRADIENT_DEBUG_TAG,
                      grad_debugger_uuid)
40
41
def _parse_grad_debug_op_name(op_name):
  """Parse the name of a debug gradient op.

  This inverts `_tensor_to_grad_debug_op_name`. `op_name` is expected to look
  like "[<scope>/]<base_op_name>_<slot>/<_GRADIENT_DEBUG_TAG><uuid>". The last
  component may additionally carry a "_<suffix>" after the UUID (presumably
  added by TensorFlow's op-name uniquification); that suffix is discarded.

  Args:
    op_name: the name of the debug gradient op.

  Returns:
    1) The UUID of the GradientsDebugger that created the debug gradient op.
    2) Name of the original tensor whose gradient is debugged by the debug
       gradient op.

  Raises:
    ValueError: If `op_name` does not have the format described above.
  """
  name_items = op_name.split("/")
  # Explicit checks instead of `assert`, which is stripped under `python -O`.
  if len(name_items) < 2 or not name_items[-1].startswith(_GRADIENT_DEBUG_TAG):
    raise ValueError(
        "Not the name of a gradient-debugging op: %s" % op_name)

  # `uuid4().hex` contains no underscores, so the first "_" (if any) marks the
  # start of a suffix that is not part of the UUID.
  grad_debugger_uuid = (
      name_items[-1][len(_GRADIENT_DEBUG_TAG):].partition("_")[0])

  # The second-to-last item is "<base_op_name>_<slot>".
  orig_base_op_name, separator, slot_str = name_items[-2].rpartition("_")
  if not separator:
    raise ValueError(
        "Cannot extract an output slot from the debug op name: %s" % op_name)
  orig_tensor_slot = int(slot_str)
  orig_tensor_name = ("/".join(name_items[:-2] + [orig_base_op_name]) +
                      ":%d" % orig_tensor_slot)

  return grad_debugger_uuid, orig_tensor_name
66
67
class GradientsDebugger(object):
  """Gradients Debugger.

  Allows retrieval of gradient tensors created by TensorFlow's automatic
  differentiation algorithm, i.e., `tf.gradients` and optimizer classes that
  use it.

  Every instance registers itself in the module-level `_gradient_debuggers`
  dict under a random hex UUID. The UUID is embedded in the names of the
  debug-gradient ops the instance creates. The gradient functions of those
  ops use it to find the instance and register gradient tensors with it.
  """
  # TODO(cais): Add examples code in the doc string?

  def __init__(self, y_tensor=None):
    """Constructor of GradientsDebugger.

    Args:
      y_tensor: optional: the `tf.Tensor` to be differentiated, i.e., the tensor
        on the numerator of the differentiation. If given, this debugger is
        bound to `y_tensor.graph`.
    """

    self._uuid = uuid.uuid4().hex
    _gradient_debuggers[self._uuid] = self

    # A dict mapping x-tensor names to gradient tensor. x-tensor refers to the
    # independent tf.Tensor, i.e., the tensor on the denominator of the
    # differentiation.
    self._gradient_tensors = {}
    self._y_tensor = y_tensor

    # Compare against None explicitly: `tf.Tensor` does not support Python
    # truthiness in graph mode (its `__bool__` raises TypeError).
    self._graph = None
    if y_tensor is not None:
      self._graph = y_tensor.graph

    # True while inside this debugger's `with` block; consulted by
    # `register_gradient_tensor`.
    self._is_active_context = False

  @property
  def y_tensor(self):
    """The `y_tensor` passed to the constructor (None if not given)."""
    return self._y_tensor

  @property
  def graph(self):
    """The `tf.Graph` this debugger is bound to, or None if not yet known."""
    return self._graph

  def __enter__(self):
    """Make this debugger an active receiver of gradient registrations.

    Returns:
      This `GradientsDebugger` instance, so that `with ... as debugger:` works.
    """
    self._is_active_context = True
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Deactivate this debugger. Exceptions are not suppressed."""
    self._is_active_context = False

  def identify_gradient(self, input_tensor):
    """Create a debug identity tensor that registers and forwards gradients.

    The side effect of this method is that when gradient tensor(s) are created
    with respect to any paths that include the `input_tensor`, the gradient
    tensor(s) with respect to `input_tensor` will be registered with this
    `GradientsDebugger` instance and can later be retrieved, with the
    methods `gradient_tensor` and `gradient_tensors`.

    Example:

    ```python
    x = tf.Variable(1.0)
    y = tf.add(x, x)

    grad_debugger = tf_debug.GradientsDebugger()
    debug_y = grad_debugger.identify_gradient(y)
    z = tf.square(debug_y)

    # Create a train op under the grad_debugger context.
    with grad_debugger:
      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)

    # Now we can reflect through grad_debugger to get the gradient tensor
    # with respect to y.
    y_grad = grad_debugger.gradient_tensor(y)
    ```

    Args:
      input_tensor: the input `tf.Tensor` object whose related gradient tensors
        are to be registered with this `GradientsDebugger` instance when they
        are created, e.g., during `tf.gradients` calls or the construction
        of optimization (training) op that uses `tf.gradients`.

    Returns:
      A forwarded identity of `input_tensor`, as a `tf.Tensor`.

    Raises:
      ValueError: If an op with name that duplicates the gradient-debugging op
        already exists in the graph (highly unlikely).
    """
    # TODO(cais): Allow overriding gradient.
    # TODO(cais): Implement value_stack.
    grad_debug_op_name = _tensor_to_grad_debug_op_name(input_tensor, self._uuid)
    # Ref-typed inputs use the ref variant of the debug identity op.
    # pylint: disable=protected-access
    identity_op = (
        gen_array_ops.debug_gradient_ref_identity
        if input_tensor.dtype._is_ref_dtype else
        gen_array_ops.debug_gradient_identity)
    # pylint: enable=protected-access
    debug_grad_identity = identity_op(input_tensor, name=grad_debug_op_name)
    assert debug_grad_identity.dtype == input_tensor.dtype
    # If the requested name was already taken, the created op ends up with a
    # different name; treat that as a collision.
    if debug_grad_identity.op.name != grad_debug_op_name:
      raise ValueError(
          "The graph already contains an op named %s" % grad_debug_op_name)
    return debug_grad_identity

  def watch_gradients_by_tensors(self, graph, tensors):
    """Watch gradient tensors by x-tensor(s).

    The side effect of this method is that when gradient tensor(s) are created
    with respect to any paths that include the `x_tensor`s, the gradient
    tensor(s) with respect to the tensor will be registered with this
    `GradientsDebugger` instance and can later be retrieved, with the
    methods `gradient_tensor` and `gradient_tensors`.

    Unlike the method `identify_gradient`, this method is used to retrieve
    gradient tensors after the construction of the forward subgraph has
    completed (but before the construction of the backward subgraph).

    This method is the same as `watch_gradients_by_tensor_names` except that
    the tensors are specified by the Python `tf.Tensor` or `tf.Variable`
    objects, instead of by name patterns.

    Example:

    ```python
    x = tf.Variable(1.0)
    y = tf.add(x, x, name="y")
    z = tf.square(y)

    # Create a train op under the grad_debugger context.
    grad_debugger = tf_debug.GradientsDebugger()
    with grad_debugger.watch_gradients_by_tensors(
        tf.get_default_graph(), y):
      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)

    # Now we can reflect through grad_debugger to get the gradient tensor
    # with respect to y.
    y_grad = grad_debugger.gradient_tensor(y)
    # or
    y_grad = grad_debugger.gradient_tensor("y:0")
    ```

    Args:
      graph: the `tf.Graph` to watch the gradients on.
      tensors: a `tf.Tensor` or `tf.Variable` object, or a list or tuple of
        such objects. An empty list or tuple watches nothing.

    Returns:
      The GradientsDebugger instance itself.
    """
    if not isinstance(tensors, (list, tuple)):
      tensors = [tensors]
    if not tensors:
      # The empty alternation "()" would match the start of every tensor name
      # and thus watch every tensor in the graph.
      return self

    tensor_name_regex = "(" + "|".join(
        re.escape(tensor.name) + "$" for tensor in tensors) + ")"
    return self.watch_gradients_by_tensor_names(graph, tensor_name_regex)

  def watch_gradients_by_tensor_names(self, graph, tensor_name_regex):
    """Watch gradient tensors by name(s) of the x-tensor(s).

    The side effect of this method is that when gradient tensor(s) are created
    with respect to the x-tensors, the gradient tensor(s) will be registered
    with this `GradientsDebugger` instance and can later be retrieved.

    Unlike the `identify_gradient` method, this method is used after the
    construction of the forward graph has completed. Unlike the
    `watch_gradients_by_tensors` method, this method does not use handles to
    the tensors of interest; it uses their names.

    This method is the same as `watch_gradients_by_tensors` except that the
    x-tensors are specified by name patterns, instead of `tf.Tensor` or
    `tf.Variable` objects.

    Example:

    ```python
    x = tf.Variable(1.0, name="x")
    y = tf.add(x, x, name="y")
    z = tf.square(y)

    # Create a train op under the grad_debugger context.
    grad_debugger = tf_debug.GradientsDebugger()
    with grad_debugger.watch_gradients_by_tensor_names(
        tf.get_default_graph(), r"(x|y):0$"):
      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)

    # Now we can reflect through grad_debugger to get the gradient tensor
    # with respect to x and y.
    x_grad = grad_debugger.gradient_tensor("x:0")
    y_grad = grad_debugger.gradient_tensor("y:0")
    ```

    Args:
      graph: the `tf.Graph` to watch the gradients on.
      tensor_name_regex: the regular-expression pattern of the name(s) of the
        x-tensor(s) to watch. x-tensor refers to the tensors on the denominator
        of the differentiation. The pattern is applied with `re.match`, i.e.,
        anchored at the start of the tensor name.

    Returns:
      The GradientsDebugger instance itself.
    """
    tensor_name_pattern = re.compile(tensor_name_regex)
    with graph.as_default():
      for op in graph.get_operations():
        for output in op.outputs:
          if not tensor_name_pattern.match(output.name):
            continue
          debug_op = self.identify_gradient(output)

          # Make a copy of output.consumers() since we'll modify the consumers
          # TODO(skyewm): this is unnecessary once the C API is enabled
          for consumer in list(output.consumers()):
            if consumer == debug_op.op:
              continue

            # Reroute every input slot of `consumer` that reads `output` so
            # that it reads the debug identity tensor instead.
            for i, consumer_input in enumerate(consumer.inputs):
              if consumer_input == output:
                consumer._update_input(i, debug_op)  # pylint: disable=protected-access
    return self

  def _check_same_graph(self, tensor):
    """Bind to `tensor.graph` if unbound; raise ValueError on a mismatch."""
    if self._graph is None:
      self._graph = tensor.graph
    elif self._graph != tensor.graph:
      raise ValueError(
          "The graph of the value (%s) is not the same as the graph %s" %
          (tensor.graph, self._graph))

  def register_gradient_tensor(self,
                               x_tensor_name,
                               gradient_tensor):
    """Register the gradient tensor for an x-tensor.

    The registration is accepted only if this is the only `GradientsDebugger`
    in the global registry, or if this instance is inside its `with` context.
    Otherwise the call does nothing.

    Args:
      x_tensor_name: (`str`) the name of the independent `tf.Tensor`, i.e.,
        the tensor on the denominator of the differentiation.
      gradient_tensor: the gradient `tf.Tensor`.

    Raises:
      ValueError: If `gradient_tensor` belongs to a different graph than the
        one this debugger is bound to.
    """
    if len(_gradient_debuggers) == 1 or self._is_active_context:
      self._check_same_graph(gradient_tensor)
      self._gradient_tensors[x_tensor_name] = gradient_tensor

  def gradient_tensor(self, x_tensor):
    """Get the gradient tensor of an x-tensor.

    Args:
      x_tensor: (`tf.Tensor`, `tf.Variable` or `str`) The x-tensor object or its
        name. x-tensor refers to the independent `tf.Tensor`, i.e., the tensor
        on the denominator of the differentiation.

    Returns:
      If found, the gradient tensor.

    Raises:
      TypeError: If `x_tensor` is not a `tf.Tensor`, `tf.Variable` or `str`.
      LookupError: If the `x_tensor` has not been registered with a gradient
        tensor.
    """
    x_tensor_name = self._get_tensor_name(x_tensor)
    if x_tensor_name not in self._gradient_tensors:
      raise LookupError(
          "This GradientsDebugger has not received any gradient tensor for "
          "x-tensor %s" % x_tensor_name)
    return self._gradient_tensors[x_tensor_name]

  def gradient_tensors(self):
    """Get the gradient tensors that this object is aware of.

    Returns:
      A dict mapping x-tensor names to gradient tensor objects. x-tensor refers
      to the tensors on the denominator of the differentiation. The returned
      dict is this instance's internal storage, not a copy.
    """
    return self._gradient_tensors

  def _get_tensor_name(self, tensor):
    """Return the name of a `tf.Tensor`/`tf.Variable`, or a `str` unchanged.

    Raises:
      TypeError: If `tensor` is none of the accepted types.
    """
    if isinstance(tensor, (ops.Tensor, variables.Variable)):
      return tensor.name
    elif isinstance(tensor, six.string_types):
      return tensor
    else:
      raise TypeError(
          "x_tensor must be a str or tf.Tensor or tf.Variable, "
          "but instead has type %s" % type(tensor))
349
350
def clear_gradient_debuggers():
  """Remove every GradientsDebugger from the module-level registry.

  The gradient functions of debug-gradient ops look up their debugger in this
  registry by UUID. After this call, that lookup fails for ops created by any
  previously constructed debugger.
  """
  registry = _gradient_debuggers
  registry.clear()
354
355
@ops.RegisterGradient("DebugGradientIdentity")
def _identify_gradient_grad(op, dy):
  """Gradient function for the DebugGradientIdentity op.

  Registers the incoming gradient `dy` with the `GradientsDebugger` whose UUID
  is encoded in `op`'s name, then passes `dy` through unchanged. The forward
  op is an identity, so its gradient is the identity as well.

  Args:
    op: the debug-gradient identity op.
    dy: the gradient flowing back into the op's output.

  Returns:
    `dy`, unchanged.

  Raises:
    KeyError: If the `GradientsDebugger` that created `op` is no longer
      registered, e.g., after `clear_gradient_debuggers()` was called.
  """
  # TODO(cais): Allow overriding gradient.
  grad_debugger_uuid, orig_tensor_name = _parse_grad_debug_op_name(op.name)
  grad_debugger = _gradient_debuggers.get(grad_debugger_uuid)
  if grad_debugger is None:
    # Still a KeyError (as a plain dict lookup would raise), but with context.
    raise KeyError(
        "No GradientsDebugger with UUID %s is registered (required to compute "
        "the gradient of op %s); was clear_gradient_debuggers() called?" %
        (grad_debugger_uuid, op.name))
  grad_debugger.register_gradient_tensor(orig_tensor_name, dy)
  return dy
364
365
@ops.RegisterGradient("DebugGradientRefIdentity")
def _identify_gradient_grad_ref(op, dy):
  """Gradient function for the DebugGradientRefIdentity op.

  The ref-typed variant is handled exactly like DebugGradientIdentity, so this
  delegates to `_identify_gradient_grad` and returns its result.
  """
  forwarded_grad = _identify_gradient_grad(op, dy)
  return forwarded_grad
370
371
def gradient_values_from_dump(grad_debugger, x_tensor, dump):
  """Find gradient values from a `DebugDumpDir` object.

  Args:
    grad_debugger: the `tf_debug.GradientsDebugger` instance to be used.
    x_tensor: (`tf.Tensor`, `tf.Variable` or `str`) The x-tensor object or its
      name. x-tensor refers to the independent `tf.Tensor`, i.e., the tensor
      on the denominator of the differentiation.
    dump: A `tfdbg.DebugDumpDir` object.

  Returns:
    If `grad_debugger` has the gradient tensor of `x_tensor` registered: a
    list of `numpy.ndarray`s holding the "DebugIdentity"-dumped values of
    that gradient tensor in `dump`.
    - The list is empty if `dump` has no such dump, e.g., because the gradient
      tensor was not executed in the `tf.Session.run()` call that generated
      `dump`.
    - The list may hold several values, e.g., if the gradient tensor was
      computed repeatedly in a `tf.while_loop` during that run.

  Raises:
    LookupError: If `grad_debugger` does not have the gradient tensor of
      `x_tensor` registered.
    ValueError: If both `grad_debugger` and `dump` know their Python
      `tf.Graph` and the two graphs differ.
    TypeError: If `x_tensor` is not a `tf.Tensor`, `tf.Variable` or `str`.
  """
  # TODO(cais): Use this method in LocalCLIDebugWrapperSession to present the
  # gradient tensors to the TFDBG CLI.

  # The graph comparison only happens when both sides know their graph.
  dump_graph = dump.python_graph
  if dump_graph and grad_debugger.graph and dump_graph != grad_debugger.graph:
    raise ValueError(
        "This GradientsDebugger instance has a graph (%s) that differs from "
        "the graph of the DebugDumpDir object (%s)." %
        (grad_debugger.graph, dump_graph))

  grad_tensor = grad_debugger.gradient_tensor(x_tensor)
  watched_node, watched_slot = debug_graphs.parse_node_or_tensor_name(
      grad_tensor.name)

  try:
    dumped_values = dump.get_tensors(watched_node, watched_slot,
                                     "DebugIdentity")
  except debug_data.WatchKeyDoesNotExistInDebugDumpDirError:
    dumped_values = []
  return dumped_values
418