1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import collections
24import re
25import six
26
27from tensorflow.python.client import session
28from tensorflow.python.eager import context
29from tensorflow.python.framework import ops
30from tensorflow.python.training.server_lib import ClusterSpec
31from tensorflow.python.util.tf_export import tf_export
32
33
# Captures the device type -- the text between 'device:' and the next ':' --
# from a device name, e.g. 'GPU' in '/job:worker/replica:0/task:0/device:GPU:0'.
# Written as a raw string so that regex syntax is never subject to Python
# string-escape processing.
DEVICE_TYPE_REGEX = re.compile(r'.*device:([^:]+).*')
35
36
def format_master_url(master, rpc_layer=None):
  """Prefixes `master` with `rpc_layer` as a URL scheme when one is given.

  Args:
    master: The address of the master.
    rpc_layer: (Optional) The protocol name to use as the scheme. When it is
      falsy (`None` or an empty string), `master` is returned unchanged.

  Returns:
    '<rpc_layer>://<master>' if `rpc_layer` is truthy, otherwise `master`.
  """
  if not rpc_layer:
    return master
  return '%s://%s' % (rpc_layer, master)
42
43
def get_accelerator_devices(master, config_proto):
  """Lists the devices available for a given master and configuration.

  When executing eagerly, the devices come from the eager context, CPU devices
  ('CPU' and 'XLA_CPU') are dropped, and `master` and `config_proto` are not
  used. Otherwise a temporary session is opened against `master` and whatever
  that session lists is returned without filtering.

  Args:
    master: The master to connect the temporary session to (graph mode only).
    config_proto: Configuration passed to the temporary session (graph mode
      only).

  Returns:
    A list of device attribute objects.
  """
  if not context.executing_eagerly():
    # Use a fresh graph so that the default graph is left untouched.
    with ops.Graph().as_default():
      with session.Session(master, config=config_proto) as sess:
        listed = sess.list_devices()
    return listed

  accelerators = []
  # context.list_devices() yields device names as plain strings.
  for device_name in context.list_devices():
    type_match = DEVICE_TYPE_REGEX.match(device_name)
    # Names that do not carry a 'device:<TYPE>' component are treated as GPUs.
    kind = type_match.group(1) if type_match else 'GPU'
    if kind in ('CPU', 'XLA_CPU'):
      # CPUs are not accelerators.
      continue
    accelerators.append(
        session._DeviceAttributes(device_name, kind, 0, 0))  # pylint: disable=protected-access
  return accelerators
63
64
@tf_export('distribute.cluster_resolver.ClusterResolver')
@six.add_metaclass(abc.ABCMeta)
class ClusterResolver(object):
  """Abstract class for all implementations of ClusterResolvers.

  This defines the skeleton for all implementations of ClusterResolvers.
  ClusterResolvers are a way for TensorFlow to communicate with various cluster
  management systems (e.g. GCE, AWS, etc...).

  By letting TensorFlow communicate with these systems, we will be able to
  automatically discover and resolve IP addresses for various TensorFlow
  workers. This will eventually allow us to automatically recover from
  underlying machine failures and scale TensorFlow worker clusters up and down.

  Note to Implementors: In addition to these abstract methods, you must also
  implement the task_type, task_id, and rpc_layer attributes. You may choose
  to implement them either as properties with getters or setters or directly
  set the attributes.

  - task_type is the name of the server's current named job (e.g. 'worker',
     'ps' in a distributed parameterized training job).
  - task_id is the ordinal index of the server within the task type.
  - rpc_layer is the protocol used by TensorFlow to communicate with other
      TensorFlow servers in a distributed environment.
  """

  @abc.abstractmethod
  def cluster_spec(self):
    """Retrieve the current state of the cluster and returns a ClusterSpec.

    Implementors of this function must take care in ensuring that the
    ClusterSpec returned is up-to-date at the time of calling this function.
    This usually means retrieving the information from the underlying cluster
    management system every time this function is invoked and reconstructing
    a cluster_spec, rather than attempting to cache anything.

    Returns:
      A ClusterSpec representing the state of the cluster at the moment this
      function is called.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Retrieves the name or URL of the session master.

    Implementors of this function must take care in ensuring that the master
    returned is up-to-date at the time of calling this function. This usually
    means retrieving the master every time this function is invoked.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster.

    Returns:
      The name or URL of the session master.
    """
    raise NotImplementedError()

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of accelerator cores per worker.

    This returns the number of accelerator cores (such as GPUs and TPUs)
    available per worker.

    Optionally, we allow callers to specify the task_type, and task_id, for
    if they want to target a specific TensorFlow process to query
    the number of accelerators. This is to support heterogeneous environments,
    where the number of accelerators cores per host is different.

    Devices are only filtered by task when *both* task_type and task_id are
    given; otherwise every device that was listed is counted.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the machine we
        want to query.
      task_id: (Optional) The index of the TensorFlow task of the machine we
        want to query.
      config_proto: (Optional) Configuration for starting a new session to
        query how many accelerator cores it has.

    Returns:
      A map of accelerator types to number of cores. The map is a
      `collections.defaultdict(int)`, so looking up an absent type yields 0.
    """
    master = self.master(task_type, task_id)
    devices = get_accelerator_devices(master, config_proto)

    # Device names are '/'-separated components, e.g.
    # '/job:worker/replica:0/task:1/device:GPU:0'. Whole components are
    # compared instead of doing a substring search, so that '/task:1' does not
    # also select '/task:10' and '/job:worker' does not also select
    # '/job:worker_eval'.
    if task_type is not None and task_id is not None:
      required_components = {'job:%s' % task_type, 'task:%s' % task_id}
    else:
      required_components = None

    mapping = collections.defaultdict(int)
    for device in devices:
      if (required_components is not None and
          not required_components.issubset(device.name.split('/'))):
        continue
      mapping[device.device_type] += 1
    return mapping

  @property
  def environment(self):
    """Returns the current environment which TensorFlow is running in.

    There are two possible return values, "google" (when TensorFlow is running
    in a Google-internal environment) or an empty string (when TensorFlow is
    running elsewhere).

    If you are implementing a ClusterResolver that works in both the Google
    environment and the open-source world (for instance, a TPU ClusterResolver
    or similar), you will have to return the appropriate string depending on the
    environment, which you will have to detect.

    Otherwise, if you are implementing a ClusterResolver that will only work
    in open-source TensorFlow, you do not need to implement this property.
    """
    return ''
179
180
@tf_export('distribute.cluster_resolver.SimpleClusterResolver')
class SimpleClusterResolver(ClusterResolver):
  """Simple implementation of ClusterResolver that accepts a ClusterSpec."""

  def __init__(self, cluster_spec, master='', task_type=None, task_id=None,
               environment='', num_accelerators=None,
               rpc_layer=None):
    """Creates a SimpleClusterResolver from a ClusterSpec.

    Args:
      cluster_spec: The `ClusterSpec` this resolver returns from
        `cluster_spec()`.
      master: (Optional) The master address returned by `master()` when no
        task is specified. Must be a string.
      task_type: (Optional) Initial value of the `task_type` property.
      task_id: (Optional) Initial value of the `task_id` property.
      environment: (Optional) Value returned by the `environment` property.
      num_accelerators: (Optional) Value returned as-is by
        `num_accelerators()`; when `None`, an empty dict is returned instead.
      rpc_layer: (Optional) Initial value of the `rpc_layer` property, used as
        the default RPC layer by `master()`.

    Raises:
      TypeError: If `cluster_spec` is not a `ClusterSpec` or `master` is not
        a string.
    """
    super(SimpleClusterResolver, self).__init__()

    self._task_type = task_type
    self._task_id = task_id
    self._environment = environment

    self._num_accelerators = num_accelerators
    self._rpc_layer = rpc_layer

    if not isinstance(cluster_spec, ClusterSpec):
      raise TypeError('cluster_spec must be a ClusterSpec.')
    self._cluster_spec = cluster_spec

    # six.string_types is (str,) on Python 3 and (basestring,) on Python 2, so
    # unicode masters are accepted on Python 2 as well instead of being
    # rejected by a bare `str` check.
    if not isinstance(master, six.string_types):
      raise TypeError('master must be a string.')
    self._master = master

  def cluster_spec(self):
    """Returns the ClusterSpec passed into the constructor."""
    return self._cluster_spec

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    If a task_type and task_id is given, this will override the `master`
    string passed into the initialization function.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC used by distributed TensorFlow. Falls back
        to the resolver's own `rpc_layer` when not given.

    Returns:
      The name or URL of the session master.
    """
    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
    else:
      master = self._master

    return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer)

  @property
  def task_type(self):
    """The task type currently set on this resolver."""
    return self._task_type

  @property
  def task_id(self):
    """The task index currently set on this resolver."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    """The environment string passed into the constructor."""
    return self._environment

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of accelerator cores per worker.

    The SimpleClusterResolver does not do automatic detection of accelerators,
    so a TensorFlow session will never be created, and thus all arguments are
    unused and we simply return the value provided to us in the constructor.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Unused.

    Returns:
      The `num_accelerators` value passed into the constructor, or an empty
      dict if none was given.
    """
    # Unused
    del task_type, task_id, config_proto
    if self._num_accelerators is None:
      return {}
    return self._num_accelerators

  @property
  def rpc_layer(self):
    """The RPC layer currently set on this resolver."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer
280
281
@tf_export('distribute.cluster_resolver.UnionResolver')
class UnionClusterResolver(ClusterResolver):
  """Performs a union on underlying ClusterResolvers.

  This class performs a union given two or more existing ClusterResolvers. It
  merges the underlying ClusterResolvers, and returns one unified ClusterSpec
  when cluster_spec is called. The details of the merge function is
  documented in the cluster_spec function.

  For additional Cluster Resolver properties such as task type, task index,
  rpc layer, environment, etc..., we will return the value from the first
  ClusterResolver in the union, unless an override for task type, task index
  or rpc layer was given to the constructor or set through the corresponding
  property.
  """

  def __init__(self, *args, **kwargs):
    """Initializes a UnionClusterResolver with other ClusterResolvers.

    Args:
      *args: `ClusterResolver` objects to be unionized.
      **kwargs:
        rpc_layer - (Optional) Override value for the RPC layer used by
          TensorFlow.
        task_type - (Optional) Override value for the current task type.
        task_id - (Optional) Override value for the current task index.

    Raises:
      TypeError: If any argument is not a subclass of `ClusterResolvers`.
      ValueError: If there are no arguments passed, or if an unexpected
        keyword argument is passed.
    """
    super(UnionClusterResolver, self).__init__()

    self._rpc_layer = kwargs.pop('rpc_layer', None)
    self._task_type = kwargs.pop('task_type', None)
    self._task_id = kwargs.pop('task_id', None)

    # Anything left over after popping the known overrides is unsupported.
    if kwargs:
      raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs))

    if not args:
      raise ValueError('At least one ClusterResolver is required.')

    for cluster_resolver in args:
      if not isinstance(cluster_resolver, ClusterResolver):
        raise TypeError('All arguments must be a sub-class of '
                        '`ClusterResolver.`')
    self._cluster_resolvers = args

  def cluster_spec(self):
    """Returns a union of all the ClusterSpecs from the ClusterResolvers.

    Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
    same job name, we will merge the list/dict of workers.

    If *all* underlying ClusterSpecs expose the set of workers as lists, we will
    concatenate the lists of workers, starting with the list of workers from
    the first ClusterResolver passed into the constructor.

    If *any* of the ClusterSpecs expose the set of workers as a dict, we will
    treat all the sets of workers as dicts (even if they are returned as lists)
    and will only merge them into a dict if there is no conflicting keys. If
    there is a conflicting key, we will raise a `KeyError`.

    Returns:
      A ClusterSpec containing host information merged from all the underlying
      ClusterResolvers.

    Raises:
      KeyError: If there are conflicting keys detected when merging two or
      more dictionaries, this exception is raised.
    """
    # Query every underlying resolver exactly once. Both passes below must see
    # the same snapshot: a resolver whose answer changed between two separate
    # queries could otherwise introduce a job the first pass never saw, or
    # switch a job between list and dict form after its container was chosen.
    cluster_dicts = [
        cluster_resolver.cluster_spec().as_dict()
        for cluster_resolver in self._cluster_resolvers
    ]

    merged_cluster = {}

    # We figure out whether it is all lists for a particular job, or whether
    # there are dicts inside.
    for cluster_dict in cluster_dicts:
      for job_name, tasks in cluster_dict.items():
        if job_name in merged_cluster:
          # If we see a dict, then we write a dict out regardless.
          if isinstance(tasks, dict):
            merged_cluster[job_name] = {}
        else:
          # We take whichever type is present.
          if isinstance(tasks, list):
            merged_cluster[job_name] = []
          else:
            merged_cluster[job_name] = {}

    # We then do the merge as appropriate in merged_cluster[job].
    for cluster_dict in cluster_dicts:
      for job_name, tasks in cluster_dict.items():
        if isinstance(merged_cluster[job_name], list):
          # We all have lists, we can just concatenate and be done.
          merged_cluster[job_name].extend(tasks)
        else:
          if isinstance(tasks, list):
            # We convert to a dictionary keyed by list position if the type
            # is a list.
            task_dict = dict(enumerate(tasks))
          else:
            # We can simply make a copy (for update) and be done.
            task_dict = tasks.copy()

          # We detect if there are duplicates, and raise an error if so.
          intersected_keys = set(task_dict).intersection(
              merged_cluster[job_name])
          if intersected_keys:
            raise KeyError('Duplicate keys detected when merging two '
                           'ClusterSpecs: %s' % repr(intersected_keys))

          # We do the merge after all the processing.
          merged_cluster[job_name].update(task_dict)

    return ClusterSpec(merged_cluster)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    This usually returns the master from the first ClusterResolver passed in,
    but you can override this by specifying the task_type and task_id.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster. Falls back
        to this resolver's RPC layer override when not given.

    Returns:
      The name or URL of the session master.
    """
    # Apply this resolver's RPC layer override on both paths. When neither the
    # argument nor the override is set, None is forwarded and the first
    # resolver applies its own default.
    effective_rpc_layer = rpc_layer or self._rpc_layer

    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      return format_master_url(master, effective_rpc_layer)

    return self._cluster_resolvers[0].master(rpc_layer=effective_rpc_layer)

  @property
  def task_type(self):
    """The task type override if set, else the first resolver's task type."""
    return self._task_type or self._cluster_resolvers[0].task_type

  @property
  def task_id(self):
    """The task index override if set, else the first resolver's task index."""
    # Compare against None rather than relying on truthiness: 0 is a valid
    # task index and must not fall through to the first resolver's value.
    if self._task_id is not None:
      return self._task_id
    return self._cluster_resolvers[0].task_id

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    """The environment reported by the first resolver in the union."""
    return self._cluster_resolvers[0].environment

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the accelerator counts reported by the first resolver.

    Args:
      task_type: (Optional) Forwarded to the first resolver.
      task_id: (Optional) Forwarded to the first resolver.
      config_proto: (Optional) Forwarded to the first resolver.

    Returns:
      Whatever the first resolver's `num_accelerators` returns.
    """
    return self._cluster_resolvers[0].num_accelerators(
        task_type, task_id, config_proto)

  @property
  def rpc_layer(self):
    """The RPC layer override if set, else the first resolver's RPC layer."""
    return self._rpc_layer or self._cluster_resolvers[0].rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer
457