1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utility functions for training."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import graph_io
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import init_ops
26from tensorflow.python.ops import resource_variable_ops
27from tensorflow.python.ops import state_ops
28from tensorflow.python.ops import variable_scope
29from tensorflow.python.ops import variables
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.util.tf_export import tf_export
32
# Graph collection key under which the cached global step read tensor is
# stored (see `_get_global_step_read` / `_get_or_create_global_step_read`).
# Picked a long key value to minimize the chance of collision with user defined
# collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'


# Module-level alias of `graph_io.write_graph`, kept for legacy callers.
# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph
40
41
@tf_export(v1=['train.global_step'])
def global_step(sess, global_step_tensor):
  """Fetches the current value of the global step as a Python `int`.

  ```python
  # Create a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Create a session.
  sess = tf.Session()
  # Initialize the variable
  sess.run(global_step_tensor.initializer)
  # Get the variable value.
  print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object. It is only used when not executing
      eagerly.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step. When executing eagerly, its `numpy()` method is used
      to read the value.

  Returns:
    The global step value, converted with `int()`.
  """
  if not context.executing_eagerly():
    # Graph mode: evaluate the tensor through the supplied session.
    return int(sess.run(global_step_tensor))
  # Eager mode: read the value directly; `sess` is not consulted.
  return int(global_step_tensor.numpy())
70
71
@tf_export(v1=['train.get_global_step'])
def get_global_step(graph=None):
  """Looks up the global step tensor of a graph.

  The lookup first consults the `GLOBAL_STEP` collection. If that collection
  is empty, the tensor named `global_step:0` is used instead. The tensor that
  is found is validated with `assert_global_step` before being returned.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step tensor, or `None` if none was found. `None` is also
    returned (after logging an error) when the `GLOBAL_STEP` collection holds
    more than one entry.

  Raises:
    TypeError: If `assert_global_step` rejects the tensor that was found, e.g.
      because it has a non-integer type.
  """
  graph = graph or ops.get_default_graph()
  candidates = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)

  # An ambiguous collection is reported but not treated as fatal.
  if len(candidates) > 1:
    logging.error('Multiple tensors in global_step collection.')
    return None

  if candidates:
    step = candidates[0]
  else:
    # Nothing registered in the collection: fall back to the well-known name.
    try:
      step = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None

  assert_global_step(step)
  return step
105
106
@tf_export(v1=['train.create_global_step'])
def create_global_step(graph=None):
  """Creates the global step variable.

  The variable is a scalar `int64`, zero-initialized, non-trainable, and is
  added to both the `GLOBAL_VARIABLES` and `GLOBAL_STEP` collections.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.
  """
  graph = graph or ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')

  def _new_step_variable():
    # The variable definition is the same in eager and graph mode; only the
    # surrounding context differs.
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES,
                     ops.GraphKeys.GLOBAL_STEP])

  if context.executing_eagerly():
    # In eager mode the variable is created under the 'cpu:0' device scope.
    with ops.device('cpu:0'):
      return _new_step_variable()

  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return _new_step_variable()
146
147
@tf_export(v1=['train.get_or_create_global_step'])
def get_or_create_global_step(graph=None):
  """Returns the global step tensor, creating it first if necessary.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    The global step tensor.
  """
  graph = graph or ops.get_default_graph()
  existing_step = get_global_step(graph)
  if existing_step is not None:
    return existing_step
  # No global step in this graph yet, so make one.
  return create_global_step(graph)
164
165
@tf_export(v1=['train.assert_global_step'])
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor` to test.

  Raises:
    TypeError: If `global_step_tensor` is not a `Variable`, `Tensor` or
      resource variable, if its base dtype is not an integer type, or if its
      shape is fully defined and not scalar.
  """
  is_supported_type = (
      isinstance(global_step_tensor, (variables.Variable, ops.Tensor)) or
      resource_variable_ops.is_resource_variable(global_step_tensor))
  if not is_supported_type:
    raise TypeError(
        'Existing "global_step" must be a Variable or Tensor: %s.' %
        global_step_tensor)

  step_dtype = global_step_tensor.dtype
  if not step_dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    step_dtype)

  # Only a fully defined shape can be rejected here; shapes that are not fully
  # defined are let through.
  # NOTE(review): this also accepts a known-rank shape with unknown dimensions
  # (e.g. `[None]`) -- confirm that this leniency is intended.
  step_shape = global_step_tensor.get_shape()
  if step_shape.is_fully_defined() and step_shape.ndims != 0:
    raise TypeError('Existing "global_step" is not scalar: %s' % step_shape)
188
189
def _get_global_step_read(graph=None):
  """Looks up the cached global step read tensor of a graph.

  Args:
    graph: The graph whose `GLOBAL_STEP_READ_KEY` collection is inspected. If
      missing, use default graph.

  Returns:
    The single tensor stored in the `GLOBAL_STEP_READ_KEY` collection, or
    `None` if the collection is empty.

  Raises:
    RuntimeError: if multiple items found in collection GLOBAL_STEP_READ_KEY.
  """
  graph = graph or ops.get_default_graph()
  cached_reads = graph.get_collection(GLOBAL_STEP_READ_KEY)
  if not cached_reads:
    return None
  if len(cached_reads) == 1:
    return cached_reads[0]
  # More than one entry means the cache is in an inconsistent state.
  raise RuntimeError('There are multiple items in collection {}. '
                     'There should be only one.'.format(GLOBAL_STEP_READ_KEY))
212
213
def _get_or_create_global_step_read(graph=None):
  """Gets the global step read tensor of a graph, creating it if needed.

  A newly created read tensor is `global_step + 0`, built under a name scope
  derived from the global step's op name, and is registered in the
  `GLOBAL_STEP_READ_KEY` collection so later calls return the same tensor.

  Args:
    graph: The graph in which to create the global step read tensor. If missing,
      use default graph.

  Returns:
    Global step read tensor if there is global_step_tensor else return None.
  """
  graph = graph or ops.get_default_graph()

  cached_read = _get_global_step_read(graph)
  if cached_read is not None:
    return cached_read

  step = get_global_step(graph)
  if step is None:
    return None

  with graph.as_default() as g, g.name_scope(None):
    read_scope = step.op.name + '/'
    with g.name_scope(read_scope):
      # using initialized_value to ensure that global_step is initialized before
      # this run. This is needed for example Estimator makes all model_fn build
      # under global_step_read_tensor dependency.
      if isinstance(step, variables.Variable):
        step_value = step.initialized_value()
      else:
        step_value = step
      # add 'zero' so that it will create a copy of variable as Tensor.
      ops.add_to_collection(GLOBAL_STEP_READ_KEY, step_value + 0)

  # Read back through the collection so the single-entry check is applied.
  return _get_global_step_read(graph)
242
243
def _increment_global_step(increment, graph=None):
  """Adds `increment` to the global step of a graph.

  The `assign_add` is built under a control dependency on the global step read
  tensor, inside a name scope derived from the global step's op name.

  Args:
    increment: The amount passed to `state_ops.assign_add`.
    graph: The graph holding the global step. If missing, use default graph.

  Returns:
    The result of `state_ops.assign_add` on the global step tensor.

  Raises:
    ValueError: if the graph has no global step tensor.
  """
  graph = graph or ops.get_default_graph()
  step = get_global_step(graph)
  if step is None:
    raise ValueError(
        'Global step tensor should be created by '
        'tf.train.get_or_create_global_step before calling increment.')

  read_before_update = _get_or_create_global_step_read(graph)
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(step.op.name + '/'), ops.control_dependencies(
        [read_before_update]):
      return state_ops.assign_add(step, increment)
256