1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Helper library for functions used during TPU compilation."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import contextlib
23
24
class TpuContext(object):
  """Holds state about the TPU computation currently being built.

  The only state tracked is the number of shards. It starts out as None and
  changes only through `set_number_of_shards`.
  """

  def __init__(self):
    """Creates a new TpuContext with no shard count recorded."""
    # Backing field for the read-only `number_of_shards` property.
    self._number_of_shards = None

  @property
  def number_of_shards(self):
    """The recorded shard count, or None if none is currently set."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Records `number_of_shards` as the current shard count.

    Args:
      number_of_shards: the value to store. Passing None restores the
        initial "no shard count" state.
    """
    self._number_of_shards = number_of_shards
38
39
# Module-level TpuContext shared by tpu_shard_context() and get_tpu_context().
# Its number_of_shards holds the shard count while a sharded computation is
# being built, and is None when no computation is being built.
_current_tpu_context = TpuContext()
43
44
@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
  """Records the shard count for the duration of a `with` block.

  On entry, the module-level TpuContext's shard count is set to
  `number_of_shards`. On exit it is reset to None, whether the block
  finished normally or raised.

  Args:
    number_of_shards: the shard count to record while the context is active.

  Yields:
    None.

  Raises:
    NotImplementedError: if a shard count is already set on the module-level
      TpuContext (e.g. when this context is nested inside another).
  """
  shard_count_already_set = _current_tpu_context.number_of_shards is not None
  if shard_count_already_set:
    raise NotImplementedError("tpu_shard_context cannot be nested.")
  try:
    _current_tpu_context.set_number_of_shards(number_of_shards)
    yield
  finally:
    # Clear unconditionally so an exception inside the block cannot leave a
    # stale shard count behind.
    _current_tpu_context.set_number_of_shards(None)
54
55
def get_tpu_context():
  """Returns the shared module-level TpuContext instance.

  tpu_shard_context() sets this object's shard count on entry and resets it
  to None on exit.
  """
  return _current_tpu_context
58
59
def on_device_training_loop(func):
  """Marks a TPU computation function as containing a training loop.

  Intended for the computation function passed to tpu.rewrite(). When that
  function embeds a training loop, trace tools can use the attribute set here
  to generate a step marker for each iteration of the loop.

  Args:
    func: the function (or any object that accepts attribute assignment) to
      annotate.

  Returns:
    `func` itself, with its `step_marker_location` attribute set to
    "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP".
  """
  # The attribute value is a name from xla.DebugOptions.StepMarkerLocation.
  func.step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP"
  return func
67