# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating TensorFlow servers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or
      `tf.train.ClusterDef` protocol buffer, or a
      `tf.train.ClusterSpec` object, describing the server to be
      defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
      defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc"`. Defaults to the value in
      `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    config: (Optional.) A `tf.ConfigProto` that specifies default configuration
      options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
  """
  server_def = tensorflow_server_pb2.ServerDef()
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    # Start from the given ServerDef and let the explicit arguments
    # override individual fields.
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  else:
    # Otherwise interpret the argument as a cluster description, from
    # which the job/task identity must be inferable or given explicitly.
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
    if job_name is None:
      # The job is only inferable when the cluster has exactly one job.
      if len(cluster_spec.jobs) == 1:
        job_name = cluster_spec.jobs[0]
      else:
        raise ValueError("Must specify an explicit `job_name`.")
    if task_index is None:
      # The task is only inferable when the job has exactly one task.
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) == 1:
        task_index = task_indices[0]
      else:
        raise ValueError("Must specify an explicit `task_index`.")
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name, task_index=task_index, protocol=protocol)
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  return server_def


@tf_export("train.Server")
class Server(object):
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.train.Server` instance encapsulates a set of devices and a
  @{tf.Session} target that
  can participate in distributed training. A server belongs to a
  cluster (specified by a @{tf.train.ClusterSpec}), and
  corresponds to a particular task in a named job. The server can
  communicate with any other server in the same cluster.
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or
        `tf.train.ClusterDef` protocol buffer, or a
        `tf.train.ClusterSpec` object, describing the server to be
        created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc"`. Defaults to the value in
        `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
      config: (Optional.) A `tf.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    self._server_def = _make_server_def(server_or_cluster_def,
                                        job_name, task_index, protocol, config)
    with errors.raise_exception_on_not_ok_status() as status:
      self._server = pywrap_tensorflow.PyServer_New(
          self._server_def.SerializeToString(), status)
    if start:
      self.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.PyServer_Start(self._server, status)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.PyServer_Join(self._server, status)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.Session` to connect to this server.

    To create a
    @{tf.Session} that
    connects to this server, use the following snippet:

    ```python
    server = tf.train.Server(...)
    with tf.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return self._server.target()

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.train.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"local"`.

    Args:
      config: (Optional.) A `tf.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Returns:
      A local `tf.train.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
                  start=start)


@tf_export("train.ClusterSpec")
class ClusterSpec(object):
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  @{tf.train.Server} is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a
        list of network addresses, or (ii) a dictionary mapping integer
        task indices to network addresses; or a `tf.train.ClusterDef`
        protocol buffer.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf.
    """
    if isinstance(cluster, dict):
      # Normalize each job to a {task_index: address} dict so that dense
      # (list) and sparse (dict) job specifications are handled uniformly.
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          job_tasks = dict(enumerate(tasks))
        elif isinstance(tasks, dict):
          job_tasks = dict(tasks)
        else:
          raise TypeError("The tasks for job %r must be a list or a dictionary "
                          "from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      self._cluster_def = cluster
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        # `job_def.tasks` is a protobuf map field; copy it into a plain dict.
        self._cluster_spec[job_def.name] = dict(job_def.tasks.items())
    elif isinstance(cluster, ClusterSpec):
      # Deep-copy the other spec's ClusterDef so the two objects do not
      # share mutable protobuf state.
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = dict(job_def.tasks.items())
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  def __nonzero__(self):
    # An empty cluster (no jobs) is falsy.
    return bool(self._cluster_spec)

  # Python 3.x
  __bool__ = __nonzero__

  def __eq__(self, other):
    # NOTE: compares the internal dict against `other` directly, so a
    # ClusterSpec also compares equal to a plain {job: {index: address}} dict.
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __str__(self):
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return sorted(job.keys())

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
        or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r"
                       % (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the @{tf.train.ClusterSpec.num_tasks} method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    # Pre-fill with `None` so sparse indices leave holes in the list.
    ret = [None] * (max(job.keys()) + 1)
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
        of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError(
              "Task address %r must be bytes or unicode" % task_address)
        job_def.tasks[i] = task_address