# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util.tf_export import tf_export

_TF_CONFIG_ENV = 'TF_CONFIG'
_SESSION_MASTER_KEY = 'session_master'
_RPC_LAYER_KEY = 'rpc_layer'
_TASK_KEY = 'task'


def format_master_url(master, rpc_layer=None):
  """Prefixes `master` with `<rpc_layer>://` when an rpc layer is given.

  Args:
    master: (String) The master address, e.g. `'localhost:12345'`.
    rpc_layer: (String, optional) Protocol to prefix. When falsy (None or an
      empty string) `master` is returned unchanged.

  Returns:
    `'<rpc_layer>://<master>'` if `rpc_layer` is truthy, otherwise `master`.
  """
  if rpc_layer:
    return '%s://%s' % (rpc_layer, master)
  return master


def _load_tf_config():
  """Parses the TF_CONFIG environment variable as JSON.

  An unset, empty, or whitespace-only TF_CONFIG is treated as `{}`.

  Returns:
    The decoded JSON value (expected to be a dict).

  Raises:
    ValueError: If TF_CONFIG is non-blank but not valid JSON
      (`json.JSONDecodeError` is a subclass of `ValueError`).
  """
  raw_config = os.environ.get(_TF_CONFIG_ENV, '')
  # Some launchers export TF_CONFIG="" instead of leaving it unset;
  # json.loads('') would raise, so treat a blank value like a missing one.
  if not raw_config.strip():
    return {}
  return json.loads(raw_config)


def _get_value_in_tfconfig(key, default=None):
  """Returns `TF_CONFIG[key]`, or `default` when `key` is not present.

  TF_CONFIG is re-read on every call, so later changes to the environment
  variable are picked up.
  """
  tf_config = _load_tf_config()
  return tf_config[key] if key in tf_config else default


@tf_export('distribute.cluster_resolver.TFConfigClusterResolver')
class TFConfigClusterResolver(ClusterResolver):
  """Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar.

  This is an implementation of cluster resolvers when using TF_CONFIG to set
  information about the cluster. The cluster spec returned will be
  initialized from the TF_CONFIG environment variable.

  An example to set TF_CONFIG is:

  ```Python
  os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
  })
  ```

  However, sometimes the container orchestration framework will set TF_CONFIG
  for you. In this case, you can just create an instance without passing in any
  arguments. You can find an example here to let Kubernetes set TF_CONFIG for
  you: https://github.com/tensorflow/ecosystem/tree/master/kubernetes. Then you
  can use it with `tf.distribute.Strategy` as:

  ```Python
  # `TFConfigClusterResolver` is already the default one in the following
  # strategy.
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      cluster_resolver=TFConfigClusterResolver())
  ```
  """

  def __init__(self,
               task_type=None,
               task_id=None,
               rpc_layer=None,
               environment=None):
    """Creates a new TFConfigClusterResolver.

    Args:
      task_type: (String, optional) Overrides the task type specified in the
        TF_CONFIG environment variable.
      task_id: (Integer, optional) Overrides the task index specified in the
        TF_CONFIG environment variable.
      rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses.
      environment: (String, optional) Overrides the environment TensorFlow
        operates in.
    """
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._environment = environment

  @property
  def task_type(self):
    """The task type as a string, or None if it cannot be determined.

    Uses the constructor/setter override when set; otherwise reads
    `TF_CONFIG['task']['type']` on each access.
    """
    if self._task_type is None:
      task_info = _get_value_in_tfconfig(_TASK_KEY, {})
      return str(task_info['type']) if 'type' in task_info else None
    else:
      return str(self._task_type)

  @property
  def task_id(self):
    """The task index as an int, or None if it cannot be determined.

    Uses the constructor/setter override when set; otherwise reads
    `TF_CONFIG['task']['index']` on each access.
    """
    if self._task_id is None:
      task_info = _get_value_in_tfconfig(_TASK_KEY, {})
      return int(task_info['index']) if 'index' in task_info else None
    else:
      return int(self._task_id)

  @task_type.setter
  def task_type(self, task_type):
    # Setting None falls back to TF_CONFIG on the next read.
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    # Setting None falls back to TF_CONFIG on the next read.
    self._task_id = task_id

  @property
  def environment(self):
    """The environment override passed to the constructor (may be None)."""
    return self._environment

  @property
  def rpc_layer(self):
    """The rpc layer override if set, else `TF_CONFIG['rpc_layer']` or None."""
    if self._rpc_layer is None:
      return _get_value_in_tfconfig(_RPC_LAYER_KEY)
    else:
      return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Delegates to `ClusterResolver.num_accelerators`.

    `task_type` and `task_id` default to this resolver's own `task_type` and
    `task_id` (override or TF_CONFIG) when not given.
    """
    task_type = self.task_type if task_type is None else task_type
    task_id = self.task_id if task_id is None else task_id
    return super(TFConfigClusterResolver, self).num_accelerators(
        task_type, task_id, config_proto)

  def cluster_spec(self):
    """Returns a ClusterSpec based on the TF_CONFIG environment variable.

    Returns:
      A ClusterSpec built from `TF_CONFIG['cluster']`, or an empty ClusterSpec
      if TF_CONFIG has no `cluster` entry.
    """
    tf_config = _load_tf_config()
    if 'cluster' not in tf_config:
      return ClusterSpec({})
    return ClusterSpec(tf_config['cluster'])

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a TensorFlow session.

    Note: this is only useful for TensorFlow 1.x.

    Resolution order:
      1. `TF_CONFIG['session_master']`, returned verbatim if present.
      2. `''` if the cluster spec is empty or has a single job with a single
         task.
      3. The address of the (task_type, task_id) task, prefixed with the rpc
         layer if one is set.

    Args:
      task_type: (String, optional) Overrides and sets the task_type of the
        master.
      task_id: (Integer, optional) Overrides and sets the task id of the
        master.
      rpc_layer: (String, optional) Overrides and sets the protocol over which
        TensorFlow nodes communicate with each other.

    Returns:
      The address of the master.

    Raises:
      ValueError: Propagated from `ClusterSpec.task_address` (assumed; confirm
        against server_lib) when the task type or id cannot be resolved,
        i.e. it was neither passed in nor present in `TF_CONFIG`, or does not
        name a task in the cluster spec. This method does not itself raise
        `RuntimeError`.
    """

    # If `session_master` is set, just use that.
    session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY)
    if session_master is not None:
      return session_master

    # Return an empty string if we are the only job in the ClusterSpec.
    cluster_spec = self.cluster_spec()
    if (not cluster_spec.jobs or
        (len(cluster_spec.jobs) == 1 and
         len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
      return ''

    # We try to auto-detect the task type and id, but use the user-supplied one
    # where available.
    task_type = task_type if task_type is not None else self.task_type
    task_id = task_id if task_id is not None else self.task_id
    rpc_layer = rpc_layer if rpc_layer is not None else self.rpc_layer

    return format_master_url(cluster_spec.task_address(task_type, task_id),
                             rpc_layer)