1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""
16
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import json
23import os
24
25from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
26from tensorflow.python.training.server_lib import ClusterSpec
27from tensorflow.python.util.tf_export import tf_export
28
29_TF_CONFIG_ENV = 'TF_CONFIG'
30_SESSION_MASTER_KEY = 'session_master'
31_RPC_LAYER_KEY = 'rpc_layer'
32_TASK_KEY = 'task'
33
34
def format_master_url(master, rpc_layer=None):
  """Builds a master URL, optionally prefixed with an RPC scheme.

  Args:
    master: The master address, e.g. 'localhost:12345'.
    rpc_layer: Optional protocol name (e.g. 'grpc'). When falsy (None or an
      empty string), `master` is returned unchanged.

  Returns:
    '<rpc_layer>://<master>' when `rpc_layer` is truthy, otherwise `master`.
  """
  return '%s://%s' % (rpc_layer, master) if rpc_layer else master
40
41
def _load_tf_config():
  """Parses the TF_CONFIG environment variable as JSON.

  Returns:
    The decoded JSON value of TF_CONFIG, or an empty dict when the variable is
    unset or set to an empty / whitespace-only string.

  Raises:
    ValueError: (json.JSONDecodeError) if TF_CONFIG is non-blank but is not
      valid JSON.
  """
  raw_config = os.environ.get(_TF_CONFIG_ENV, '')
  # A variable that is set but blank carries no configuration. json.loads('')
  # would raise, so treat it exactly like an unset variable.
  if not raw_config.strip():
    return {}
  return json.loads(raw_config)
44
45
def _get_value_in_tfconfig(key, default=None):
  """Looks up a single top-level entry of the parsed TF_CONFIG.

  Args:
    key: The top-level key to read.
    default: Value returned when `key` is absent.

  Returns:
    The value stored under `key`, or `default` if it is not present.
  """
  config = _load_tf_config()
  if key not in config:
    return default
  return config[key]
49
50
@tf_export('distribute.cluster_resolver.TFConfigClusterResolver')
class TFConfigClusterResolver(ClusterResolver):
  """Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar.

  This is an implementation of cluster resolvers when using TF_CONFIG to set
  information about the cluster. The cluster spec returned will be
  initialized from the TF_CONFIG environment variable.

  An example to set TF_CONFIG is:

    ```Python
    os.environ['TF_CONFIG'] = json.dumps({
      'cluster': {
          'worker': ["localhost:12345", "localhost:23456"]
      },
      'task': {'type': 'worker', 'index': 0}
    })
    ```

  However, sometimes the container orchestration framework will set TF_CONFIG
  for you. In this case, you can just create an instance without passing in any
  arguments. You can find an example here to let Kubernetes set TF_CONFIG for
  you: https://github.com/tensorflow/ecosystem/tree/master/kubernetes. Then you
  can use it with `tf.distribute.Strategy` as:

    ```Python
    # `TFConfigClusterResolver` is already the default one in the following
    # strategy.
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=TFConfigClusterResolver())
    ```

  Values passed to the constructor, or assigned through the `task_type`,
  `task_id` and `rpc_layer` setters, take precedence over TF_CONFIG. When no
  override is set, TF_CONFIG is re-read on each access, so later changes to
  the environment variable are picked up.
  """

  def __init__(self,
               task_type=None,
               task_id=None,
               rpc_layer=None,
               environment=None):
    """Creates a new TFConfigClusterResolver.

    Args:
      task_type: (String, optional) Overrides the task type specified in the
        TF_CONFIG environment variable.
      task_id: (Integer, optional) Overrides the task index specified in the
        TF_CONFIG environment variable.
      rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses.
      environment: (String, optional) Overrides the environment TensorFlow
        operates in.
    """
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._environment = environment

  @property
  def task_type(self):
    """Task type of this node, as a string, or None if unknown.

    Uses the override if one was set. Otherwise reads `type` from the `task`
    section of TF_CONFIG.
    """
    if self._task_type is None:
      task_info = _get_value_in_tfconfig(_TASK_KEY, {})
      return str(task_info['type']) if 'type' in task_info else None
    else:
      return str(self._task_type)

  @property
  def task_id(self):
    """Task index of this node, as an int, or None if unknown.

    Uses the override if one was set. Otherwise reads `index` from the `task`
    section of TF_CONFIG.
    """
    if self._task_id is None:
      task_info = _get_value_in_tfconfig(_TASK_KEY, {})
      return int(task_info['index']) if 'index' in task_info else None
    else:
      return int(self._task_id)

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    """The environment passed to the constructor (never read from TF_CONFIG)."""
    return self._environment

  @property
  def rpc_layer(self):
    """RPC protocol: the override if set, else TF_CONFIG['rpc_layer'] or None."""
    if self._rpc_layer is None:
      return _get_value_in_tfconfig(_RPC_LAYER_KEY)
    else:
      return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns accelerator counts for a task, defaulting to this node's task.

    Any of `task_type` / `task_id` left as None is filled in from this
    resolver's own `task_type` / `task_id` before delegating to
    `ClusterResolver.num_accelerators`.
    """
    task_type = self.task_type if task_type is None else task_type
    task_id = self.task_id if task_id is None else task_id
    return super(TFConfigClusterResolver, self).num_accelerators(
        task_type, task_id, config_proto)

  def cluster_spec(self):
    """Returns a ClusterSpec based on the TF_CONFIG environment variable.

    Returns:
      A ClusterSpec with information from the TF_CONFIG environment variable,
      or an empty ClusterSpec if TF_CONFIG has no `cluster` section.
    """
    tf_config = _load_tf_config()
    if 'cluster' not in tf_config:
      return ClusterSpec({})
    return ClusterSpec(tf_config['cluster'])

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a TensorFlow session.

    Note: this is only useful for TensorFlow 1.x.

    Resolution order:
      1. `session_master` from TF_CONFIG, returned verbatim if present.
      2. An empty string if the cluster has no jobs or exactly one task.
      3. The address of the given (or auto-detected) task in the cluster
         spec, prefixed with `rpc_layer://` when an rpc layer is known.

    Args:
      task_type: (String, optional) Overrides and sets the task_type of the
        master.
      task_id: (Integer, optional) Overrides and sets the task id of the
        master.
      rpc_layer: (String, optional) Overrides and sets the protocol over which
        TensorFlow nodes communicate with each other.

    Returns:
      The address of the master.

    Raises:
      RuntimeError: If the task_type or task_id is not specified and the
        `TF_CONFIG` environment variable does not contain a task section.
        This only applies when step 3 above is reached.
      Errors raised by `ClusterSpec.task_address` for an unknown job or task
      index propagate unchanged.
    """

    # If `session_master` is set, just use that.
    session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY)
    if session_master is not None:
      return session_master

    # Return an empty string if we are the only job in the ClusterSpec.
    cluster_spec = self.cluster_spec()
    if (not cluster_spec.jobs or
        (len(cluster_spec.jobs) == 1 and
         len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
      return ''

    # We try to auto-detect the task type and id, but use the user-supplied
    # ones where available.
    task_type = task_type if task_type is not None else self.task_type
    task_id = task_id if task_id is not None else self.task_id
    rpc_layer = rpc_layer if rpc_layer is not None else self.rpc_layer

    # Without both values we would ask task_address() to look up a `None` job
    # or index. Raise the documented, actionable error instead.
    if task_type is None or task_id is None:
      raise RuntimeError(
          'Unable to determine the master address: task_type=%r, task_id=%r. '
          'Pass them explicitly or set the "task" section of the %s '
          'environment variable.' % (task_type, task_id, _TF_CONFIG_ENV))

    return format_master_url(cluster_spec.task_address(task_type, task_id),
                             rpc_layer)
205