1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A Python interface for creating dataset servers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
24from tensorflow.core.protobuf import service_config_pb2
25from tensorflow.python import pywrap_tensorflow
26from tensorflow.python.data.experimental.service import _pywrap_server_lib
27from tensorflow.python.util.tf_export import tf_export
28
29
@tf_export("data.experimental.service.DispatcherConfig")
class DispatcherConfig(
    collections.namedtuple("DispatcherConfig", [
        "port", "protocol", "work_dir", "fault_tolerant_mode",
        "job_gc_check_interval_ms", "job_gc_timeout_ms"
    ])):
  """Configuration class for tf.data service dispatchers.

  Fields:
    port: Specifies the port to bind to. A value of 0 indicates that the server
      may bind to any available port. Defaults to 0.
    protocol: The protocol to use for communicating with the tf.data service.
      Defaults to `"grpc"`.
    work_dir: A directory to store dispatcher state in. This
      argument is required for the dispatcher to be able to recover from
      restarts.
    fault_tolerant_mode: Whether the dispatcher should write its state to a
      journal so that it can recover from restarts. Dispatcher state, including
      registered datasets and created jobs, is synchronously written to the
      journal before responding to RPCs. If `True`, `work_dir` must also be
      specified. Defaults to `False`.
    job_gc_check_interval_ms: How often the dispatcher should scan through to
      delete old and unused jobs, in milliseconds. If not set (`None`), it
      defaults to 10 minutes. A higher value will reduce load on the
      dispatcher, while a lower value will reduce the time it takes for the
      dispatcher to garbage collect expired jobs.
    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
      candidate for garbage collection, in milliseconds. If not set (`None`), it
      defaults to 5 minutes. A higher value will cause jobs to stay around
      longer with no consumers. This is useful if there is a large gap in time
      between when consumers read from the job. A lower value will reduce the
      time it takes to reclaim the resources from expired jobs.
  """

  def __new__(cls,
              port=0,
              protocol="grpc",
              work_dir=None,
              fault_tolerant_mode=False,
              job_gc_check_interval_ms=None,
              job_gc_timeout_ms=None):
    # Replace unset garbage-collection settings with concrete millisecond
    # values so the stored tuple never contains `None` for them.
    check_interval = (10 * 60 * 1000  # 10 minutes.
                      if job_gc_check_interval_ms is None else
                      job_gc_check_interval_ms)
    gc_timeout = (5 * 60 * 1000  # 5 minutes.
                  if job_gc_timeout_ms is None else job_gc_timeout_ms)
    parent = super(DispatcherConfig, cls)
    return parent.__new__(
        cls,
        port=port,
        protocol=protocol,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode,
        job_gc_check_interval_ms=check_interval,
        job_gc_timeout_ms=gc_timeout)
79
80
@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer(object):
  """An in-process tf.data service dispatch server.

  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
  `tf.data.experimental.service.WorkerServer`s. When the workers start, they
  register themselves with the dispatcher.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...     dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data dispatch process, use join() to block
  indefinitely after starting up the server.

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(port=5050))
  dispatcher.join()
  ```

  To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and
  `fault_tolerant_mode` like below:

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(
          port=5050,
          work_dir="gs://my-bucket/dispatcher/work_dir",
          fault_tolerant_mode=True))
  ```
  """

  def __init__(self, config=None, start=True):
    """Creates a new dispatch server.

    Args:
      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
        configuration. If `None`, the dispatcher will use default
        configuration values.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.

    Raises:
      ValueError: If `config.fault_tolerant_mode` is enabled but
        `config.work_dir` is not set.
    """
    config = config or DispatcherConfig()
    if config.fault_tolerant_mode and not config.work_dir:
      raise ValueError(
          "Cannot enable fault tolerant mode without configuring a work_dir")
    self._config = config
    # The native server is configured from a serialized proto, so each field of
    # the Python config is copied into the corresponding proto field.
    config_proto = service_config_pb2.DispatcherConfig(
        port=config.port,
        protocol=config.protocol,
        work_dir=config.work_dir,
        fault_tolerant_mode=config.fault_tolerant_mode,
        job_gc_check_interval_ms=config.job_gc_check_interval_ms,
        job_gc_timeout_ms=config.job_gc_timeout_ms)
    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
    >>> dispatcher.start()

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated dispatch process.

    ```
    dispatcher = tf.data.experimental.service.DispatchServer(
        tf.data.experimental.service.DispatcherConfig(port=5050))
    dispatcher.join()
    ```

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  @property
  def target(self):
    """Returns a target that can be used to connect to the server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer()
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
    ...     processing_mode="parallel_epochs", service=dispatcher.target))

    The returned string will be in the form protocol://address, e.g.
    "grpc://localhost:5050".
    """
    return "{0}://localhost:{1}".format(self._config.protocol,
                                        self._server.bound_port())

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    # `__init__` may raise (e.g. the fault-tolerant-mode `ValueError`) before
    # `self._server` is assigned. There is then nothing to stop, and calling
    # `_stop()` would only emit a spurious "Exception ignored in __del__"
    # AttributeError.
    if getattr(self, "_server", None) is not None:
      self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_workers(self):
    """Returns the number of workers registered with the dispatcher."""
    return self._server.num_workers()
215
216
@tf_export("data.experimental.service.WorkerConfig")
class WorkerConfig(
    collections.namedtuple("WorkerConfig", [
        "dispatcher_address", "worker_address", "port", "protocol",
        "heartbeat_interval_ms", "dispatcher_timeout_ms"
    ])):
  """Configuration class for tf.data service workers.

  Fields:
    dispatcher_address: Specifies the address of the dispatcher, e.g.
      `"localhost:5050"` (for a `DispatchServer` named `dispatcher`, this is
      `dispatcher.target.split("://")[1]`).
    worker_address: Specifies the address of the worker server. This address is
      passed to the dispatcher so that the dispatcher can tell clients how to
      connect to this worker. Defaults to `"localhost:%port%"`; the `%port%`
      placeholder is presumably substituted with the worker's bound port by the
      native server.
    port: Specifies the port to bind to. A value of 0 indicates that the worker
      can bind to any available port. Defaults to 0.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Defaults to `"grpc"`.
    heartbeat_interval_ms: How often the worker should heartbeat to the
      dispatcher, in milliseconds. If not set (`None`), it defaults to 30
      seconds. A higher value will reduce the load on the dispatcher, while a
      lower value will reduce the time it takes to reclaim resources from
      finished jobs.
    dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
      dispatcher before giving up and reporting an error. Defaults to 1 hour.
  """

  def __new__(cls,
              dispatcher_address,
              worker_address=None,
              port=0,
              protocol="grpc",
              heartbeat_interval_ms=None,
              dispatcher_timeout_ms=None):
    if worker_address is None:
      worker_address = "localhost:%port%"
    if heartbeat_interval_ms is None:
      heartbeat_interval_ms = 30 * 1000  # 30 seconds
    if dispatcher_timeout_ms is None:
      dispatcher_timeout_ms = 60 * 60 * 1000  # 1 hour

    return super(WorkerConfig,
                 cls).__new__(cls, dispatcher_address, worker_address, port,
                              protocol, heartbeat_interval_ms,
                              dispatcher_timeout_ms)
261
262
@tf_export("data.experimental.service.WorkerServer", v1=[])
class WorkerServer(object):
  """An in-process tf.data service worker server.

  A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
  processing for user-defined datasets, and provides the resulting elements over
  RPC. A worker is associated with a single
  `tf.data.experimental.service.DispatchServer`.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data worker process, use join() to block
  indefinitely after starting up the server.

  ```
  worker = tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(
          dispatcher_address="localhost:5050", port=5051))
  worker.join()
  ```
  """

  def __init__(self, config, start=True):
    """Creates a new worker server.

    Args:
      config: A `tf.data.experimental.service.WorkerConfig` configuration.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.

    Raises:
      ValueError: If `config.dispatcher_address` is `None`.
    """
    if config.dispatcher_address is None:
      raise ValueError("must specify a dispatcher_address")
    self._config = config
    # The native server is configured from a serialized proto.
    # `data_transfer_protocol` is not exposed by `WorkerConfig`, so it is
    # explicitly left unset here.
    config_proto = service_config_pb2.WorkerConfig(
        dispatcher_address=config.dispatcher_address,
        worker_address=config.worker_address,
        port=config.port,
        protocol=config.protocol,
        heartbeat_interval_ms=config.heartbeat_interval_ms,
        dispatcher_timeout_ms=config.dispatcher_timeout_ms,
        data_transfer_protocol=None)
    self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated worker process.

    ```
    worker_server = tf.data.experimental.service.WorkerServer(
        tf.data.experimental.service.WorkerConfig(
            dispatcher_address="localhost:5050", port=5051))
    worker_server.join()
    ```

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    # `__init__` may raise (e.g. the missing-dispatcher_address `ValueError`)
    # before `self._server` is assigned. There is then nothing to stop, and
    # calling `_stop()` would only emit a spurious "Exception ignored in
    # __del__" AttributeError.
    if getattr(self, "_server", None) is not None:
      self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_tasks(self):
    """Returns the number of tasks currently being executed on the worker."""
    return self._server.num_tasks()
368