1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Contains the arg_scope used for scoping layers' arguments.
16
17  Allows one to define models much more compactly by eliminating boilerplate
18  code. This is accomplished through the use of argument scoping (arg_scope).
19
20  Example of how to use tf.contrib.framework.arg_scope:
21
22  ```
23  from third_party.tensorflow.contrib.layers.python import layers
24
25  arg_scope = tf.contrib.framework.arg_scope
26
27  with arg_scope([layers.conv2d], padding='SAME',
28                 initializer=layers.variance_scaling_initializer(),
29                 regularizer=layers.l2_regularizer(0.05)):
30    net = layers.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
31    net = layers.conv2d(net, 256, [5, 5], scope='conv2')
32  ```
33  The first call to conv2d will behave as follows:
34    layers.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
35                  initializer=layers.variance_scaling_initializer(),
36                  regularizer=layers.l2_regularizer(0.05), scope='conv1')
37
38  The second call to conv2d will also use the arg_scope's default for padding:
    layers.conv2d(net, 256, [5, 5], padding='SAME',
40                  initializer=layers.variance_scaling_initializer(),
41                  regularizer=layers.l2_regularizer(0.05), scope='conv2')
42
43  Example of how to reuse an arg_scope:
44
45  ```
46  with arg_scope([layers.conv2d], padding='SAME',
47                 initializer=layers.variance_scaling_initializer(),
48                 regularizer=layers.l2_regularizer(0.05)) as sc:
49    net = layers.conv2d(net, 256, [5, 5], scope='conv1')
50    ....
51
52  with arg_scope(sc):
53    net = layers.conv2d(net, 256, [5, 5], scope='conv2')
54  ```
55
56  Example of how to use tf.contrib.framework.add_arg_scope to enable your
57  function to be called within an arg_scope later:
58
59  @tf.contrib.framework.add_arg_scope
  def conv2d(*args, **kwargs):
    ...
61"""
62from __future__ import absolute_import
63from __future__ import division
64from __future__ import print_function
65
66from tensorflow.python.util import tf_contextlib
67from tensorflow.python.util import tf_decorator
68
# Public API of this module, in the original declaration order.
__all__ = [
    'arg_scope',
    'add_arg_scope',
    'current_arg_scope',
    'has_arg_scope',
    'arg_scoped_arguments',
    'arg_scope_func_key',
]
73
# Stack of scope dicts; the last element is the innermost active arg_scope.
_ARGSTACK = [dict()]

# Maps arg_scope_func_key(op) of each @add_arg_scope-decorated op to the
# names returned by _kwarg_names(op) at decoration time.
_DECORATED_OPS = dict()
77
78
def _get_arg_stack():
  """Returns the module-level scope stack, re-seeding it when it is empty.

  Guarantees the returned list holds at least one (possibly empty) dict, so
  that `current_arg_scope()` can always read its last element.
  """
  if not _ARGSTACK:
    _ARGSTACK.append({})
  return _ARGSTACK
85
86
def current_arg_scope():
  """Returns the innermost active scope: a dict from op key to default kwargs."""
  return _get_arg_stack()[-1]
90
91
def arg_scope_func_key(op):
  """Returns the key under which `op` is stored in scope dicts.

  Wrappers produced by @add_arg_scope carry a `_key_op` attribute holding the
  key of the function they wrap; any other object is keyed by its `str()`.

  Args:
    op: a function or other object.

  Returns:
    `op._key_op` when that attribute exists, otherwise `str(op)`.
  """
  fallback_key = str(op)
  return getattr(op, '_key_op', fallback_key)
94
95
def _name_op(op):
  """Returns a `(module, name)` pair identifying `op` for error messages."""
  module_name = op.__module__
  op_name = op.__name__
  return module_name, op_name
98
99
def _kwarg_names(func):
  """Returns the names of `func`'s positional parameters that have defaults.

  Args:
    func: a plain Python function (anything exposing `__code__` and
      `__defaults__`).

  Returns:
    A tuple with the names of the trailing positional parameters that carry
    default values, in declaration order. Empty when `func` has no defaults.
  """
  code = func.__code__
  num_defaults = len(func.__defaults__) if func.__defaults__ else 0
  # `co_varnames` starts with the positional parameters and continues with
  # other parameters and local variables, so the defaulted parameters must be
  # sliced relative to `co_argcount`, not to the end of the tuple. (A
  # `-num_defaults` start returned every positional parameter when there were
  # no defaults, because -0 == 0, and the wrong names whenever `func` had
  # local variables.)
  return code.co_varnames[code.co_argcount - num_defaults:code.co_argcount]
103
104
def _add_op(op):
  """Registers `op` in `_DECORATED_OPS` with the names from `_kwarg_names`."""
  registry_key = arg_scope_func_key(op)
  registered_names = _kwarg_names(op)
  _DECORATED_OPS[registry_key] = registered_names
108
109
@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
  """Stores the default arguments for the given set of list_ops.

  For usage, please see examples at top of the file.

  Args:
    list_ops_or_scope: List or tuple of operations to set argument scope for or
      a dictionary containing the current scope. When list_ops_or_scope is a
      dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,
      then every op in it needs to be decorated with @add_arg_scope to work.
    **kwargs: keyword=value that will define the defaults for each op in
              list_ops. All the ops need to accept the given set of arguments.

  Yields:
    the current_scope, which is a dictionary of
    {arg_scope_func_key(op): {arg: value}}.

  Raises:
    TypeError: if list_ops_or_scope is neither a list, a tuple nor a dict.
    ValueError: if any op in list_ops has not been decorated with
      @add_arg_scope, or if kwargs are given together with a dict scope.
  """
  if isinstance(list_ops_or_scope, dict):
    # Re-using a previously yielded scope: push a copy of it unchanged.
    if kwargs:
      raise ValueError('When attempting to re-use a scope by supplying a '
                       'dictionary, kwargs must be empty.')
    current_scope = list_ops_or_scope.copy()
  else:
    # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
    if not isinstance(list_ops_or_scope, (list, tuple)):
      raise TypeError('list_ops_or_scope must either be a list/tuple or reused '
                      'scope (i.e. dict)')
    # Start from the enclosing scope so nested arg_scopes accumulate defaults.
    current_scope = current_arg_scope().copy()
    for op in list_ops_or_scope:
      if not has_arg_scope(op):
        raise ValueError('%s is not decorated with @add_arg_scope' %
                         (_name_op(op),))
      key = arg_scope_func_key(op)
      if key in current_scope:
        # Copy before updating so the enclosing scope's dict is not mutated.
        current_kwargs = current_scope[key].copy()
        current_kwargs.update(kwargs)
        current_scope[key] = current_kwargs
      else:
        current_scope[key] = kwargs.copy()
  # Push only after all validation succeeded, so the pop in `finally` always
  # removes exactly the scope this call pushed. Raising inside the `try`
  # would otherwise pop the enclosing scope off the stack.
  _get_arg_stack().append(current_scope)
  try:
    yield current_scope
  finally:
    _get_arg_stack().pop()
163
164
def add_arg_scope(func):
  """Decorates a function with args so it can be used within an arg_scope.

  Args:
    func: function to decorate.

  Returns:
    The wrapper built by `tf_decorator.make_decorator` around
    `func_with_args()`. The wrapper calls `func` with the defaults registered
    for it in the current arg_scope, overridden by any explicitly passed
    kwargs.
  """
  # `func` is fixed for the lifetime of the wrapper, so compute its key once.
  key_func = arg_scope_func_key(func)

  def func_with_args(*args, **kwargs):
    current_scope = current_arg_scope()
    current_args = kwargs
    if key_func in current_scope:
      # Scope defaults first; kwargs passed at the call site take precedence.
      current_args = current_scope[key_func].copy()
      current_args.update(kwargs)
    return func(*args, **current_args)

  _add_op(func)
  # Lets arg_scope_func_key(wrapper) resolve to the same key as `func`, so
  # the decorated function can be listed in arg_scope([...]).
  setattr(func_with_args, '_key_op', key_func)
  return tf_decorator.make_decorator(func, func_with_args)
187
188
def has_arg_scope(func):
  """Tells whether `func` was registered through @add_arg_scope.

  Both the original function and its decorated wrapper resolve to the same
  key, so either one can be passed.

  Args:
    func: function (or decorated wrapper) to look up.

  Returns:
    True if the key of `func` is present in `_DECORATED_OPS`, else False.
  """
  lookup_key = arg_scope_func_key(func)
  return lookup_key in _DECORATED_OPS
199
200
def arg_scoped_arguments(func):
  """Lists the keyword argument names recorded for `func` at decoration time.

  Args:
    func: function which has been decorated with @add_arg_scope.

  Returns:
    The names stored for `func` in `_DECORATED_OPS`, i.e. the tuple produced
    by `_kwarg_names(func)` when it was decorated.

  Raises:
    AssertionError: if `func` was not decorated. Under `python -O` the assert
      is stripped and the dictionary lookup raises KeyError instead.
  """
  assert has_arg_scope(func)
  registry_key = arg_scope_func_key(func)
  return _DECORATED_OPS[registry_key]
212