1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for training."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.eager import context
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import graph_io
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import init_ops
25from tensorflow.python.ops import resource_variable_ops
26from tensorflow.python.ops import state_ops
27from tensorflow.python.ops import variable_scope
28from tensorflow.python.ops import variables
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.util.tf_export import tf_export
31
# Graph collection key under which the cached global step read tensor is
# stored. The key is deliberately long to minimize the chance of colliding
# with user-defined collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'

# Module-level alias of `graph_io.write_graph`, kept for legacy callers.
# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph
38
39
@tf_export(v1=['train.global_step'])
def global_step(sess, global_step_tensor):
  """Fetches the current value of the global step as a Python `int`.

  When eager execution is enabled, the value is read directly from
  `global_step_tensor` through its `numpy()` method and `sess` is not used.
  Otherwise the value is obtained with `sess.run(global_step_tensor)`.

  ```python
  # Create a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Create a session.
  sess = tf.compat.v1.Session()
  # Initialize the variable
  sess.run(global_step_tensor.initializer)
  # Get the variable value.
  print('global_step: %s' %
        tf.compat.v1.train.global_step(sess, global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object. Not used when executing eagerly.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step. When executing eagerly it must provide a `numpy()`
      method.

  Returns:
    The global step value, converted to `int`.
  """
  if context.executing_eagerly():
    step_value = global_step_tensor.numpy()
  else:
    step_value = sess.run(global_step_tensor)
  return int(step_value)
69
70
@tf_export(v1=['train.get_global_step'])
def get_global_step(graph=None):
  """Looks up the global step tensor of a graph.

  The lookup first consults the `GLOBAL_STEP` collection of the graph. If the
  collection is empty, the tensor named `global_step:0` is used instead. The
  tensor that is found is validated with `assert_global_step` before it is
  returned.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step tensor or variable. `None` is returned when nothing was
    found, and also (after logging an error) when the `GLOBAL_STEP` collection
    holds more than one entry.

  Raises:
    TypeError: If the object found is not a `Variable` or `Tensor`, if it has a
      non-integer type, or if it has a fully defined non-scalar shape.
  """
  if not graph:
    graph = ops.get_default_graph()

  candidates = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(candidates) > 1:
    # An ambiguous collection is reported but not treated as fatal.
    logging.error('Multiple tensors in global_step collection.')
    return None

  if candidates:
    step = candidates[0]
  else:
    # Nothing registered in the collection: fall back to the well-known name.
    try:
      step = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None

  assert_global_step(step)
  return step
104
105
@tf_export(v1=['train.create_global_step'])
def create_global_step(graph=None):
  """Creates the global step variable.

  The variable is a scalar `int64`, zero-initialized and non-trainable, and is
  added to both the `GLOBAL_VARIABLES` and `GLOBAL_STEP` collections. When
  executing eagerly it is created under the `cpu:0` device; otherwise it is
  created inside `graph`, in the base name scope.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.
  """
  target_graph = graph or ops.get_default_graph()
  if get_global_step(target_graph) is not None:
    raise ValueError('"global_step" already exists.')

  def _new_step_variable():
    # Shared by the eager and graph paths; must be called inside the
    # device / graph context the variable should be created in.
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])

  if context.executing_eagerly():
    with ops.device('cpu:0'):
      return _new_step_variable()

  # Create in proper graph and base name_scope.
  with target_graph.as_default() as g, g.name_scope(None):
    return _new_step_variable()
145
146
@tf_export(v1=['train.get_or_create_global_step'])
def get_or_create_global_step(graph=None):
  """Returns the global step tensor, creating it if necessary.

  Args:
    graph: The graph in which to look up or create the global step tensor. If
      missing, use default graph.

  Returns:
    The global step tensor.
  """
  graph = graph or ops.get_default_graph()
  existing_step = get_global_step(graph)
  return create_global_step(graph) if existing_step is None else existing_step
163
164
@tf_export(v1=['train.assert_global_step'])
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  The shape check only rejects shapes that are fully defined and not of rank 0;
  a shape that is unknown or only partially defined is accepted.

  Args:
    global_step_tensor: `Tensor` to test.

  Raises:
    TypeError: If `global_step_tensor` is not a `Variable`, `Tensor` or
      resource variable, if its base dtype is not an integer type, or if its
      shape is fully defined and not scalar.
  """
  is_supported_kind = (
      isinstance(global_step_tensor, (variables.Variable, ops.Tensor)) or
      resource_variable_ops.is_resource_variable(global_step_tensor))
  if not is_supported_kind:
    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
                    global_step_tensor)

  step_dtype = global_step_tensor.dtype
  if not step_dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    step_dtype)

  step_shape = global_step_tensor.get_shape()
  if step_shape.ndims != 0 and step_shape.is_fully_defined():
    raise TypeError('Existing "global_step" is not scalar: %s' % step_shape)
186
187
def _get_global_step_read(graph=None):
  """Looks up the cached global step read tensor of a graph.

  Args:
    graph: The graph whose `GLOBAL_STEP_READ_KEY` collection is inspected. If
      missing, use default graph.

  Returns:
    The single tensor stored in the collection, or `None` if the collection is
    empty.

  Raises:
    RuntimeError: if multiple items found in collection GLOBAL_STEP_READ_KEY.
  """
  graph = graph or ops.get_default_graph()
  cached_reads = graph.get_collection(GLOBAL_STEP_READ_KEY)
  if not cached_reads:
    return None
  if len(cached_reads) == 1:
    return cached_reads[0]
  raise RuntimeError('There are multiple items in collection {}. '
                     'There should be only one.'.format(GLOBAL_STEP_READ_KEY))
210
211
def _get_or_create_global_step_read(graph=None):
  """Gets or creates global step read tensor in graph.

  An already cached read tensor is returned as is. Otherwise, if the graph has
  a global step, a read tensor is built from it, stored in the
  `GLOBAL_STEP_READ_KEY` collection, and then returned.

  Args:
    graph: The graph in which to create the global step read tensor. If missing,
      use default graph.

  Returns:
    Global step read tensor if there is global_step_tensor else return None.
  """
  graph = graph or ops.get_default_graph()

  cached_read = _get_global_step_read(graph)
  if cached_read is not None:
    return cached_read

  step = get_global_step(graph)
  if step is None:
    return None

  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(step.op.name + '/'):
      # using initialized_value to ensure that global_step is initialized before
      # this run. This is needed for example Estimator makes all model_fn build
      # under global_step_read_tensor dependency.
      if isinstance(step, variables.Variable):
        step_value = step.initialized_value()
      else:
        step_value = step
      # add 'zero' so that it will create a copy of variable as Tensor.
      ops.add_to_collection(GLOBAL_STEP_READ_KEY, step_value + 0)

  # Re-read through the getter so that the collection's single-entry invariant
  # is checked on the way out.
  return _get_global_step_read(graph)
240
241
def _increment_global_step(increment, graph=None):
  """Creates an op that adds `increment` to the global step of a graph.

  The assign-add op is created under a control dependency on the global step
  read tensor (which is created first if it does not exist yet), inside a name
  scope derived from the global step's op name.

  Args:
    increment: The amount passed to `state_ops.assign_add`.
    graph: The graph holding the global step. If missing, use default graph.

  Returns:
    The result of `state_ops.assign_add` on the global step tensor.

  Raises:
    ValueError: if the graph has no global step tensor.
  """
  graph = graph or ops.get_default_graph()

  step = get_global_step(graph)
  if step is None:
    raise ValueError(
        'Global step tensor should be created by '
        'tf.train.get_or_create_global_step before calling increment.')

  read_dependency = _get_or_create_global_step_read(graph)
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(step.op.name + '/'):
      with ops.control_dependencies([read_dependency]):
        return state_ops.assign_add(step, increment)
254