1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Helper library for functions used during TPU compilation."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import contextlib
23import threading
24
25
class TpuContext(threading.local):
  """Per-thread record of the TPU computation currently being constructed.

  Since this class derives from `threading.local`, each thread that touches
  an instance sees its own independent copy of the stored shard count.
  """

  def __init__(self):
    """Initializes the context with no shard count recorded (None)."""
    self._number_of_shards = None

  def set_number_of_shards(self, number_of_shards):
    """Stores `number_of_shards` for the calling thread (None clears it)."""
    self._number_of_shards = number_of_shards

  @property
  def number_of_shards(self):
    """The shard count set by this thread, or None if none is set."""
    shards = self._number_of_shards
    return shards
39
40
# Module-wide TPU context shared by this file's helpers. TpuContext is a
# threading.local subclass, so each thread sees its own shard count: the
# value set while that thread builds a sharded computation, otherwise None.
_current_tpu_context = TpuContext()
44
45
@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
  """Context manager that records the shard count for the current thread.

  While the `with` body runs, `get_tpu_context().number_of_shards` returns
  `number_of_shards`. On exit, including exit via an exception, the value is
  reset to None.

  Args:
    number_of_shards: The shard count to expose while the context is active.

  Yields:
    None.

  Raises:
    NotImplementedError: If a shard count is already set on this thread,
      i.e. the context is being nested.
  """
  if _current_tpu_context.number_of_shards is not None:
    # Note the trailing space after "nested." -- the literals below are
    # concatenated, and without it the two sentences ran together.
    raise NotImplementedError(
        "tpu_shard_context cannot be nested. "
        "If you're using TPUEstimator with inference_on_tpu, "
        "make sure you have set "
        "export_saved_model_api_version=ExportSavedModelApiVersion.V2 in "
        "the creation of TPUEstimator.")
  try:
    _current_tpu_context.set_number_of_shards(number_of_shards)
    yield
  finally:
    # Clear unconditionally so a failure while building the computation does
    # not leave this thread looking as if it is still inside a shard context.
    _current_tpu_context.set_number_of_shards(None)
61
62
def get_tpu_context():
  """Returns the module-wide TpuContext instance.

  The instance is thread-local, so `number_of_shards` read from it reflects
  the value set by the calling thread.
  """
  context = _current_tpu_context
  return context
65
66
def on_device_training_loop(func):
  """Decorator tagging a TPU computation that contains a training loop.

  Apply it to the computation function passed to `tpu.rewrite()`. If that
  function embeds a training loop, the tag lets trace tools generate a step
  marker for each iteration of the loop.

  Args:
    func: The computation function to tag. It must accept attribute
      assignment.

  Returns:
    The same `func` object, now carrying a `step_marker_location` attribute.
  """
  # Value for this attribute is from xla.DebugOptions.StepMarkerLocation.
  func.step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP"
  return func
74