1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU specific APIs to be used in conjunction with TPU Strategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import config_pb2
22from tensorflow.python.client import session as session_lib
23from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
24from tensorflow.python.eager import context
25from tensorflow.python.eager import function
26from tensorflow.python.framework import device as tf_device
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.tpu import functional as tpu_functional_ops
31from tensorflow.python.tpu import topology
32from tensorflow.python.tpu import tpu
33from tensorflow.python.util import compat
34from tensorflow.python.util.tf_export import tf_export
35
36
def get_first_tpu_host_device(cluster_resolver):
  """Returns the CPU device string of the host that owns the first TPU.

  Args:
    cluster_resolver: A cluster resolver exposing `get_master()` and
      `get_job_name()` (e.g. a `TPUClusterResolver`).

  Returns:
    "/replica:0/task:0/device:CPU:0" when the resolver's master is "" or
    "local". Otherwise "/job:<job>/task:<task>/device:CPU:0", where <job> is
    the resolver's job name (falling back to "tpu_worker") and <task> is the
    task of the lexicographically first TPU device in eager mode, or 0 in
    graph mode.

  Raises:
    RuntimeError: In eager mode, if the eager context lists no TPU devices.
  """
  if not context.executing_eagerly():
    # Session master needs to be configured and the coordinator is not part
    # of the cluster, so the first host is always task 0.
    first_task = 0
  else:
    tpu_device_names = [
        name for name in context.list_devices() if "device:TPU:" in name
    ]
    if not tpu_device_names:
      raise RuntimeError("Could not find any TPU devices")
    # min() picks the same entry that sorted(...)[0] would.
    first_device = tf_device.DeviceSpec.from_string(min(tpu_device_names))
    first_task = first_device.task

  master_address = cluster_resolver.get_master()
  if master_address in ("", "local"):
    return "/replica:0/task:0/device:CPU:0"

  job = cluster_resolver.get_job_name() or "tpu_worker"
  return "/job:%s/task:%d/device:CPU:0" % (job, first_task)
54
55
@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initializes the TPU system and returns its topology.

  Args:
    cluster_resolver: Optional `tf.distribute.cluster_resolver.TPUClusterResolver`
      describing the TPU cluster. When None, `TPUClusterResolver("")` is used.

  Returns:
    A `topology.Topology` built from the serialized topology produced by the
    TPU initialization op.
  """
  resolver = (
      TPUClusterResolver("") if cluster_resolver is None else cluster_resolver)

  logging.info("Initializing the TPU system.")

  if not context.executing_eagerly():
    # Graph mode: run the initialization op in a fresh graph through a
    # session connected to the resolver's master. The master is resolved
    # before the new graph becomes the default.
    master = resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default(), session_lib.Session(
        config=session_config, target=master) as sess:
      serialized_topology = sess.run(tpu.initialize_system())
  else:
    # tpu.initialize_system() only emits a dummy op whose purpose is to
    # trigger DistributedTPURewritePass, which adds the ops that really
    # initialize the TPU system. It therefore cannot simply be run eagerly.
    # Instead it is wrapped in a defun and executed via TPUPartitionedCall,
    # which triggers the rewrite passes on the function body.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # The defun is never called directly. Tracing it adds it to the eager
    # context so it can be referenced by its assigned name below.
    # pylint: disable=protected-access
    concrete_fn = _tpu_init_fn._get_concrete_function_internal()
    registered_name = compat.as_str(concrete_fn._inference_function.name)
    # pylint: enable=protected-access

    with ops.device(get_first_tpu_host_device(resolver)):
      call_outputs = tpu_functional_ops.TPUPartitionedCall(
          args=[], device_ordinal=0, Tout=[dtypes.string],
          f=registered_name)
    serialized_topology = call_outputs[0].numpy()

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
104