1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python API for executing a tf.data.Dataset using a tf.data service."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21
22import six
23
24from tensorflow.python import tf2
25from tensorflow.python.data.experimental.ops import compression_ops
26from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
27from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
28from tensorflow.python.data.ops import dataset_ops
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.ops import gen_experimental_dataset_ops
33from tensorflow.python.util.tf_export import tf_export
34
35
class ProcessingMode(object):
  """tf.data service processing modes."""

  PARALLEL_EPOCHS = "parallel_epochs"
  DISTRIBUTED_EPOCH = "distributed_epoch"

  @staticmethod
  def validate(mode):
    """Raises a ValueError if the given object is not a valid processing mode."""
    valid_modes = [
        ProcessingMode.PARALLEL_EPOCHS, ProcessingMode.DISTRIBUTED_EPOCH
    ]
    # Accept known modes silently; anything else is rejected with the list of
    # accepted values.
    if mode in valid_modes:
      return
    raise ValueError(
        "{0} is not a valid processing mode. Valid modes: {1}".format(
            mode, valid_modes))
52
53
class _DataServiceDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` that reads elements from the tf.data service."""

  def __init__(self,
               dataset_id,
               processing_mode,
               address,
               protocol,
               data_transfer_protocol,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None):
    """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A string specifying the policy for how data should be
        processed by tf.data workers. Can be either "parallel_epochs" to have
        each tf.data worker process a copy of the dataset, or
        "distributed_epoch" to split a single iteration of the dataset across
        all the workers.
      address: The tf.data service address, e.g. "localhost:5000".
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      data_transfer_protocol: The protocol to use for transferring data with the
        tf.data service, e.g. "grpc".
      job_name: (Optional.) The name of the job. This argument makes it possible
        for multiple datasets to share the same job. The default behavior is
        that the dataset creates anonymous, exclusively owned jobs.
      consumer_index: (Optional.) The index of the consumer in the range from
        `0` to `num_consumers`. Must be specified alongside `num_consumers`.
        When specified, consumers will read from the job in a strict round-robin
        order, instead of the default first-come-first-served order.
      num_consumers: (Optional.) The number of consumers which will consume from
        the job. Must be specified alongside `consumer_index`. When specified,
        consumers will read from the job in a strict round-robin order, instead
        of the default first-come-first-served order. When `num_consumers` is
        specified, the dataset must have infinite cardinality to prevent a
        producer from running out of data early and causing consumers to go out
        of sync.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.

    Raises:
      ValueError: If exactly one of `consumer_index` and `num_consumers` is
        specified, or if `num_consumers` is specified without a `job_name`.
    """
    # BUG FIX: the unparenthesized form `consumer_index is None !=
    # num_consumers is None` is a chained comparison --
    # `(consumer_index is None) and (None != num_consumers) and
    # (num_consumers is None)` -- which is always False, so this validation
    # could never fire. Parenthesizing yields the intended XOR.
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both consumer_index and num_consumers, or neither. "
          "consumer_index: {}, num_consumers: {}".format(
              consumer_index, num_consumers))
    if num_consumers is not None and job_name is None:
      raise ValueError("job_name must be set when setting num_consumers")

    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

    self._dataset_id = ops.convert_to_tensor(
        dataset_id, dtype=dtypes.int64, name="dataset_id")
    self._processing_mode = ops.convert_to_tensor(
        processing_mode, dtype=dtypes.string, name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    # -1 signals "unset" to the kernel for the round-robin parameters.
    self._consumer_index = ops.convert_to_tensor(
        -1 if consumer_index is None else consumer_index,
        dtype=dtypes.int64,
        name="consumer_index")
    self._num_consumers = ops.convert_to_tensor(
        -1 if num_consumers is None else num_consumers,
        dtype=dtypes.int64,
        name="num_consumers")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    # Datasets executed by the tf.data service produce compressed elements
    # represented by scalar DT_VARIANTs.
    self._element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)

    compat_kwargs = {}
    if data_transfer_protocol is not None:
      compat_kwargs["data_transfer_protocol"] = data_transfer_protocol

    # The v2 kernel is only needed for round-robin (num_consumers) reads.
    if num_consumers is None:
      variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
          dataset_id=self._dataset_id,
          processing_mode=self._processing_mode,
          address=self._address,
          protocol=self._protocol,
          job_name=self._job_name,
          max_outstanding_requests=self._max_outstanding_requests,
          task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
          iteration_counter=gen_experimental_dataset_ops
          .dummy_iteration_counter(),
          **compat_kwargs,
          **self._flat_structure)
    else:
      variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2(
          dataset_id=self._dataset_id,
          processing_mode=self._processing_mode,
          address=self._address,
          protocol=self._protocol,
          job_name=self._job_name,
          consumer_index=self._consumer_index,
          num_consumers=self._num_consumers,
          max_outstanding_requests=self._max_outstanding_requests,
          task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
          iteration_counter=gen_experimental_dataset_ops
          .dummy_iteration_counter(),
          **compat_kwargs,
          **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    # Scalar variant: elements arrive compressed and are uncompressed by the
    # caller (see _from_dataset_id).
    return self._element_spec
181
182
class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` that executes its input through the tf.data service."""

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, dataset_id, processing_mode, address, protocol,
               data_transfer_protocol, job_name, consumer_index, num_consumers,
               max_outstanding_requests, task_refresh_interval_hint_ms):
    # All real work happens in the V2 dataset; this class merely adapts it to
    # the TF 1.x Dataset API.
    wrapped = _DataServiceDatasetV2(
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        data_transfer_protocol=data_transfer_protocol,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
    self._wrapped = wrapped
    super(_DataServiceDatasetV1, self).__init__(wrapped)
203
204
# Select the implementation matching the enabled TensorFlow API (2.x vs 1.x).
_DataServiceDataset = (
    _DataServiceDatasetV2 if tf2.enabled() else _DataServiceDatasetV1)
209
210
def _parse_service(service):
  """Parses a tf.data service string into a (protocol, address) tuple.

  Args:
    service: A string in the format "protocol://address".

  Returns:
    The parsed (protocol, address) tuple

  Raises:
    ValueError: If `service` is not a non-empty string of the form
      "<protocol>://<address>".
  """
  if not isinstance(service, six.string_types):
    raise ValueError(
        "service must be a string, but service was of type {0}. service={1}"
        .format(type(service), service))
  if not service:
    raise ValueError("service must not be empty")
  protocol, separator, address = service.partition("://")
  if not separator:
    # No "://" present at all.
    raise ValueError("service string %s does not begin with a protocol. "
                     "The service should be in the format "
                     "<protocol>://<address>, e.g. grpc://localhost:5000" %
                     service)
  if "://" in address:
    raise ValueError("malformed service string has multiple '://': %s" %
                     service)
  return [protocol, address]
236
237
def _from_dataset_id(processing_mode,
                     service,
                     dataset_id,
                     element_spec,
                     job_name=None,
                     consumer_index=None,
                     num_consumers=None,
                     max_outstanding_requests=None,
                     task_refresh_interval_hint_ms=None,
                     data_transfer_protocol=None):
  """Creates a dataset which reads data from the tf.data service.

  This transformation is similar to `from_dataset_id`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have
      each tf.data worker process a copy of the dataset, or
      "distributed_epoch" to split a single iteration of the dataset across
      all the workers.
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "<protocol>://<address>", e.g.
      "grpc://localhost:5000".
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. Use `tf.data.Dataset.element_spec` to
      see the element spec for a given dataset.
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  ProcessingMode.validate(processing_mode)
  if job_name is not None:
    if not isinstance(job_name, six.string_types):
      raise ValueError("job_name must be a string, but job_name was of type "
                       "{0}. job_name={1}".format(type(job_name), job_name))
    if not job_name:
      raise ValueError("job_name must not be empty")
  if element_spec is None:
    raise ValueError("element_spec must not be None")
  protocol, address = _parse_service(service)

  ds = _DataServiceDataset(
      dataset_id=dataset_id,
      processing_mode=processing_mode,
      address=address,
      protocol=protocol,
      data_transfer_protocol=data_transfer_protocol,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
  # The service transmits compressed elements; restore them to the structure
  # described by element_spec.
  ds = ds.map(
      lambda x: compression_ops.uncompress(x, output_spec=element_spec),
      num_parallel_calls=dataset_ops.AUTOTUNE)

  if not job_name:
    return ds
  # A shared job already splits its data across consumers, so autosharding
  # would be redundant; disable it.
  options = dataset_ops.Options()
  options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
  return ds.with_options(options)
326
327
def _distribute(processing_mode,
                service,
                job_name=None,
                consumer_index=None,
                num_consumers=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None,
                data_transfer_protocol=None):
  """A transformation that moves dataset processing to the tf.data service.

  This transformation is similar to `distribute`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have
      each tf.data worker process a copy of the dataset, or
      "distributed_epoch" to split a single iteration of the dataset across
      all the workers.
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "<protocol>://<address>", e.g.
      "grpc://localhost:5000".
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  # Validate eagerly so callers get an error at pipeline-construction time
  # rather than when the returned function is applied.
  ProcessingMode.validate(processing_mode)

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    # Register the dataset with the service, then read it back by id.
    return _from_dataset_id(
        processing_mode,
        service,
        register_dataset(service, dataset),
        dataset.element_spec,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        data_transfer_protocol=data_transfer_protocol)

  return _apply_fn
393
394
@tf_export("data.experimental.service.distribute")
def distribute(processing_mode,
               service,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               data_transfer_protocol=None):
  """A transformation that moves dataset processing to the tf.data service.

  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The tf.data service uses a cluster of workers to prepare data for training
  your model.
  The `processing_mode` argument to `tf.data.experimental.service.distribute`
  describes how to leverage multiple workers to process the input dataset.
  Currently, there are two processing modes to choose from: "distributed_epoch"
  and "parallel_epochs".

  "distributed_epoch" means that the dataset will be split across all tf.data
  service workers.
  The dispatcher produces "splits" for the dataset and sends them to workers for
  further processing. For example, if a dataset begins with a list of filenames,
  the dispatcher will iterate through the filenames and send the filenames to
  tf.data workers, which will perform the rest of the dataset transformations on
  those files. "distributed_epoch" is useful when your model needs to see each
  element of the dataset exactly once, or if it needs to see the data in a
  generally-sequential order. "distributed_epoch" only works for datasets with
  splittable sources, such as `Dataset.from_tensor_slices`,
  `Dataset.list_files`, or `Dataset.range`.

  "parallel_epochs" means that the entire input dataset will be processed
  independently by each of the tf.data service workers.
  For this reason, it is important to shuffle data (e.g. filenames)
  non-deterministically, so that each worker will process the elements of the
  dataset in a different order. "parallel_epochs" can be used to distribute
  datasets that aren't splittable.

  With two workers, "parallel_epochs" will produce every element of the dataset
  twice:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> # Start two workers
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]

  "distributed_epoch", on the other hand, will still produce each element once:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="distributed_epoch", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When using `apply(tf.data.experimental.service.distribute(...))`, the dataset
  before the `apply` transformation executes within the tf.data service, while
  the operations after `apply` happen within the local process.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(5)
  >>> dataset = dataset.map(lambda x: x*x)
  >>> dataset = dataset.apply(
  ...    tf.data.experimental.service.distribute("parallel_epochs",
  ...                                            dispatcher.target))
  >>> dataset = dataset.map(lambda x: x+1)
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [1, 1, 2, 2, 5, 5, 10, 10, 17, 17]

  In the above example, the dataset operations (before applying the `distribute`
  function on the elements) will be executed on the tf.data workers,
  and the elements are provided over RPC. The remaining transformations
  (after the call to `distribute`) will be executed locally. The dispatcher
  and the workers will bind to unused free ports (which are chosen at random),
  in order to communicate with each other. However, to bind them to specific
  ports, the `port` parameter can be passed.

  The `job_name` argument allows jobs to be shared across multiple
  datasets. Instead of each dataset creating its own job, all
  datasets with the same `job_name` will consume from the same job. A new job
  will be created for each iteration of the dataset (with each repetition of
  `Dataset.repeat` counting as a new iteration). Suppose the `DispatchServer`
  is serving on `localhost:5000` and two training workers (in either a single
  client or multi-client setup) iterate over the below dataset, and there is a
  single tf.data worker:

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "grpc://localhost:5000", job_name="my_job_name"))
  for iteration in range(3):
    print(list(dataset))
  ```

  The elements of each job will be split between the two processes, with
  elements being consumed by the processes on a first-come first-served basis.
  One possible result is that process 1 prints

  ```
  [0, 2, 4]
  [0, 1, 3]
  [1]
  ```

  and process 2 prints

  ```
  [1, 3]
  [2, 4]
  [0, 2, 3, 4]
  ```

  Job names must not be re-used across different training jobs within the
  lifetime of the tf.data service. In general, the tf.data service is expected
  to live for the duration of a single training job.
  To use the tf.data service with multiple training jobs, make sure to use
  different job names to avoid conflicts. For example, suppose a training job
  calls `distribute` with `job_name="job"` and reads until end of input. If
  another independent job connects to the same tf.data service and tries to read
  from `job_name="job"`, it will immediately receive end of input, without
  getting any data.

  **Round Robin data consumption**

  By default, when multiple consumers read from the same job, they receive data
  on a first-come first-served basis. In some use cases, it works better to use
  a strict round-robin order. For example, the tf.data service can be used to
  coordinate example sizes across a cluster during synchronous training, so that
  during each step all replicas train on similar-sized elements. To achieve
  this, define a dataset which generates rounds of `num_consumers` consecutive
  similar-sized batches, then enable round-robin reads by setting
  `consumer_index` and `num_consumers`.

  Consumers read data by cycling through all workers, reading one element from
  each. First, each consumer will read an element from the first worker, then
  each consumer will read an element from the second worker, and so on.

  NOTE: To keep consumers in sync, round robin data consumption requires that
  the dataset have infinite cardinality. You can get this by adding `.repeat()`
  at the end of the dataset definition.

  **Keras and Distribution Strategies**

  The dataset produced by the `distribute` transformation can be passed to
  Keras' `Model.fit` or Distribution Strategy's
  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
  `distribute` so that if there are multiple workers, they read data from the
  same job. Note that the autosharding normally performed by
  `experimental_distribute_dataset` will be disabled when setting a `job_name`,
  since sharing the job already results in splitting data across the workers.
  When using a shared job, data will be dynamically balanced across workers, so
  that they reach end of input about the same time. This results in better
  worker utilization than with autosharding, where each worker processes an
  independent set of files, and some workers may run out of data earlier than
  others.

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have
      each tf.data worker process a copy of the dataset, or
      "distributed_epoch" to split a single iteration of the dataset across
      all the workers.
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "protocol://address", e.g.
      "grpc://localhost:5000".
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service, e.g. "grpc".

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  # Thin public wrapper: delegates to the internal _distribute, which accepts
  # extra parameters not yet exposed in the public API.
  return _distribute(
      processing_mode=processing_mode,
      service=service,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol)
616
617
@tf_export("data.experimental.service.register_dataset")
def register_dataset(service, dataset):
  """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "protocol://address", e.g.
      "grpc://localhost:5000".
    dataset: A `tf.data.Dataset` to register with the tf.data service.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    # Default to warning (rather than failing) on stateful transformations.
    external_state_policy = ExternalStatePolicy.WARN

  # Compress the dataset elements to reduce the amount of data that needs to
  # be sent over the network.
  compressed = dataset.map(
      lambda *x: compression_ops.compress(x),
      num_parallel_calls=dataset_ops.AUTOTUNE)
  compressed = compressed.prefetch(dataset_ops.AUTOTUNE)
  # Apply options so that the dataset executed in the tf.data service will
  # be optimized and support autotuning.
  compressed = compressed._apply_options()  # pylint: disable=protected-access

  return gen_experimental_dataset_ops.register_dataset(
      compressed._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value)
680
681
@tf_export("data.experimental.service.from_dataset_id")
def from_dataset_id(processing_mode,
                    service,
                    dataset_id,
                    element_spec=None,
                    job_name=None,
                    consumer_index=None,
                    num_consumers=None,
                    max_outstanding_requests=None):
  """Creates a dataset which reads data from the tf.data service.

  This is useful when the dataset is registered by one process, then used in
  another process. When the same process is both registering and reading from
  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
  instead.

  Before calling `from_dataset_id`, the dataset must have been registered with
  the tf.data service via `tf.data.experimental.service.register_dataset`.
  `register_dataset` returns a dataset id for the registered dataset, and that
  id is what should be passed here as `dataset_id`.

  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. Currently `element_spec` must be explicitly
  specified, and match the dataset registered under `dataset_id`. `element_spec`
  defaults to `None` so that in the future we can support automatically
  discovering the `element_spec` by querying the tf.data service.

  `tf.data.experimental.service.distribute` is a convenience method which
  combines `register_dataset` and `from_dataset_id` into a dataset
  transformation; its documentation describes in more detail how
  `from_dataset_id` works.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have
      each tf.data worker process a copy of the dataset, or
      "distributed_epoch" to split a single iteration of the dataset across
      all the workers.
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "protocol://address", e.g.
      "grpc://localhost:5000".
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. Use `tf.data.Dataset.element_spec` to
      see the element spec for a given dataset.
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  # This public API is a thin wrapper: forward every argument by keyword to
  # the private implementation.
  kwargs = dict(
      processing_mode=processing_mode,
      service=service,
      dataset_id=dataset_id,
      element_spec=element_spec,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests)
  return _from_dataset_id(**kwargs)
777