1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Library for controlling the TensorFlow/XLA JIT compiler."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22
23from tensorflow.core.framework import attr_value_pb2
24from tensorflow.python.framework import ops
25
26
27_XLA_SCOPE_KEY = ("__xla_scope",)
28
29
30class _XlaScope(object):
31  """Keeps track of previous XLA scope calls, and depth of current call."""
32
33  def __init__(self, count, depth):
34    self.count = count
35    self.depth = depth
36
37
@contextlib.contextmanager
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
  """Enable or disable JIT compilation of operators within the scope.

  NOTE: This is an experimental feature.

  The compilation is a hint and only supported on a best-effort basis.

  Example usage:
    with tf.contrib.compiler.experimental_jit_scope():
      c = tf.matmul(a, b)  # compiled
    with tf.contrib.compiler.experimental_jit_scope(compile_ops=False):
      d = tf.matmul(a, c)  # not compiled
    with tf.contrib.compiler.experimental_jit_scope(
        compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
      e = tf.matmul(a, b) + d  # matmul is compiled, the addition is not.

  Example of separate_compiled_gradients:
    # In the example below, the computations for f, g and h will all be compiled
    # in separate scopes.
    with tf.contrib.compiler.experimental_jit_scope(
        separate_compiled_gradients=True):
      f = tf.matmul(a, b)
    g = tf.gradients([f], [a, b], name='mygrads1')
    h = tf.gradients([f], [a, b], name='mygrads2')

  Args:
    compile_ops: Whether to enable or disable compilation in the scope.
      Either a Python bool, or a callable that accepts the parameter
      `node_def` and returns a python bool. Values are coerced with `bool()`.
    separate_compiled_gradients: If true put each gradient subgraph into a
      separate compilation scope. This gives fine-grained control over which
      portions of the graph will be compiled as a single unit. Compiling
      gradients separately may yield better performance for some graphs.
      The scope is named based on the scope of the forward computation as well
      as the name of the gradients. As a result, the gradients will be compiled
      in a scope that is separate from both the forward computation, and from
      other gradients.

  Yields:
    None. Ops created inside the `with` block receive the `_XlaCompile`,
    `_XlaSeparateCompiledGradients` and (for an outermost scope) `_XlaScope`
    attributes.
  """
  if callable(compile_ops):
    def xla_compile(node_def):
      return attr_value_pb2.AttrValue(b=bool(compile_ops(node_def)))
  else:
    xla_compile = attr_value_pb2.AttrValue(b=bool(compile_ops))

  attrs = {
      "_XlaCompile":
          xla_compile,
      "_XlaSeparateCompiledGradients":
          attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
  }

  # Find the singleton counter for the current scoped graph.  If it
  # doesn't exist, create one.
  xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
  if not xla_scope_counter:
    xla_scope_counter = _XlaScope(0, 0)
    ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
  else:
    xla_scope_counter = xla_scope_counter[0]

  if xla_scope_counter.depth == 0:
    # If we're at the root xla scope, we can increase the counter so
    # future calls to jit_scope use a different scope value.
    # If we're already within a scope, we'll be fusing using the scope
    # controlled by the parent.
    attrs["_XlaScope"] = attr_value_pb2.AttrValue(
        s=("jit_scope_%d" % xla_scope_counter.count).encode())
    xla_scope_counter.count += 1

  xla_scope_counter.depth += 1
  try:
    # pylint: disable=protected-access
    with ops.get_default_graph()._attr_scope(attrs):
      yield
    # pylint: enable=protected-access
  finally:
    # Restore the nesting depth even if the body raised. Otherwise the
    # counter would stay "nested" forever, and later outermost scopes on
    # this graph would never be assigned their own _XlaScope value.
    xla_scope_counter.depth -= 1
119