# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for executing a tf.data.Dataset using a tf.data service."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import six

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export


class ProcessingMode(object):
  """tf.data service processing modes."""

  PARALLEL_EPOCHS = "parallel_epochs"
  DISTRIBUTED_EPOCH = "distributed_epoch"

  @staticmethod
  def validate(mode):
    """Raises a ValueError if the given object is not a valid processing mode."""
    valid_modes = [
        ProcessingMode.PARALLEL_EPOCHS, ProcessingMode.DISTRIBUTED_EPOCH
    ]
    if mode not in valid_modes:
      raise ValueError(
          "{0} is not a valid processing mode. Valid modes: {1}".format(
              mode, valid_modes))


class _DataServiceDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` that reads elements from the tf.data service."""

  def __init__(self,
               dataset_id,
               processing_mode,
               address,
               protocol,
               data_transfer_protocol,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None):
    """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A string specifying the policy for how data should be
        processed by tf.data workers. Can be either "parallel_epochs" to have
        each tf.data worker process a copy of the dataset, or
        "distributed_epoch" to split a single iteration of the dataset across
        all the workers.
      address: The tf.data service address, e.g. "localhost:5000".
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      data_transfer_protocol: The protocol to use for transferring data with the
        tf.data service, e.g. "grpc".
      job_name: (Optional.) The name of the job. This argument makes it possible
        for multiple datasets to share the same job. The default behavior is
        that the dataset creates anonymous, exclusively owned jobs.
      consumer_index: (Optional.) The index of the consumer in the range from
        `0` to `num_consumers`. Must be specified alongside `num_consumers`.
        When specified, consumers will read from the job in a strict round-robin
        order, instead of the default first-come-first-served order.
      num_consumers: (Optional.) The number of consumers which will consume from
        the job. Must be specified alongside `consumer_index`. When specified,
        consumers will read from the job in a strict round-robin order, instead
        of the default first-come-first-served order. When `num_consumers` is
        specified, the dataset must have infinite cardinality to prevent a
        producer from running out of data early and causing consumers to go out
        of sync.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
    """
    # BUG FIX: the original check was written without parentheses as
    # `consumer_index is None != num_consumers is None`, which Python parses
    # as a chained comparison that is always False, so the "both or neither"
    # validation never fired. Parenthesizing restores the intended XOR.
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both consumer_index and num_consumers, or neither. "
          "consumer_index: {}, num_consumers: {}".format(
              consumer_index, num_consumers))
    if num_consumers is not None and job_name is None:
      raise ValueError("job_name must be set when setting num_consumers")

    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

    self._dataset_id = ops.convert_to_tensor(
        dataset_id, dtype=dtypes.int64, name="dataset_id")
    self._processing_mode = ops.convert_to_tensor(
        processing_mode, dtype=dtypes.string, name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    # -1 signals "unset" to the kernel for the round-robin parameters.
    self._consumer_index = ops.convert_to_tensor(
        -1 if consumer_index is None else consumer_index,
        dtype=dtypes.int64,
        name="consumer_index")
    self._num_consumers = ops.convert_to_tensor(
        -1 if num_consumers is None else num_consumers,
        dtype=dtypes.int64,
        name="num_consumers")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    # Datasets executed by the tf.data service produce compressed elements
    # represented by scalar DT_VARIANTs.
    self._element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)

    compat_kwargs = {}
    if data_transfer_protocol is not None:
      compat_kwargs["data_transfer_protocol"] = data_transfer_protocol

    if num_consumers is None:
      variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
          dataset_id=self._dataset_id,
          processing_mode=self._processing_mode,
          address=self._address,
          protocol=self._protocol,
          job_name=self._job_name,
          max_outstanding_requests=self._max_outstanding_requests,
          task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
          iteration_counter=gen_experimental_dataset_ops
          .dummy_iteration_counter(),
          **compat_kwargs,
          **self._flat_structure)
    else:
      variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2(
          dataset_id=self._dataset_id,
          processing_mode=self._processing_mode,
          address=self._address,
          protocol=self._protocol,
          job_name=self._job_name,
          consumer_index=self._consumer_index,
          num_consumers=self._num_consumers,
          max_outstanding_requests=self._max_outstanding_requests,
          task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
          iteration_counter=gen_experimental_dataset_ops
          .dummy_iteration_counter(),
          **compat_kwargs,
          **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` that executes its input through the tf.data service."""

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, dataset_id, processing_mode, address, protocol,
               data_transfer_protocol, job_name, consumer_index, num_consumers,
               max_outstanding_requests, task_refresh_interval_hint_ms):

    self._wrapped = _DataServiceDatasetV2(
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        data_transfer_protocol=data_transfer_protocol,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
    super(_DataServiceDatasetV1, self).__init__(self._wrapped)


if tf2.enabled():
  _DataServiceDataset = _DataServiceDatasetV2
else:
  _DataServiceDataset = _DataServiceDatasetV1


def _parse_service(service):
  """Parses a tf.data service string into a (protocol, address) tuple.

  Args:
    service: A string in the format "protocol://address".

  Returns:
    The parsed (protocol, address) tuple
  """
  if not isinstance(service, six.string_types):
    raise ValueError(
        "service must be a string, but service was of type {0}. service={1}"
        .format(type(service), service))
  if not service:
    raise ValueError("service must not be empty")
  parts = service.split("://")
  if len(parts) == 1:
    raise ValueError("service string %s does not begin with a protocol. "
                     "The service should be in the format "
                     "<protocol>://<address>, e.g. grpc://localhost:5000" %
                     service)
  if len(parts) > 2:
    raise ValueError("malformed service string has multiple '://': %s" %
                     service)
  return parts


def _from_dataset_id(processing_mode,
                     service,
                     dataset_id,
                     element_spec,
                     job_name=None,
                     consumer_index=None,
                     num_consumers=None,
                     max_outstanding_requests=None,
                     task_refresh_interval_hint_ms=None,
                     data_transfer_protocol=None):
  """Creates a dataset which reads data from the tf.data service.

  This transformation is similar to `from_dataset_id`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have
      each tf.data worker process a copy of the dataset, or
      "distributed_epoch" to split a single iteration of the dataset across
      all the workers.
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "<protocol>://<address>", e.g.
      "grpc://localhost:5000".
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. Use `tf.data.Dataset.element_spec` to
      see the element spec for a given dataset.
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  ProcessingMode.validate(processing_mode)
  if job_name is not None:
    if not isinstance(job_name, six.string_types):
      raise ValueError("job_name must be a string, but job_name was of type "
                       "{0}. job_name={1}".format(type(job_name), job_name))
    if not job_name:
      raise ValueError("job_name must not be empty")
  if element_spec is None:
    raise ValueError("element_spec must not be None")
  protocol, address = _parse_service(service)

  dataset = _DataServiceDataset(
      dataset_id=dataset_id,
      processing_mode=processing_mode,
      address=address,
      protocol=protocol,
      data_transfer_protocol=data_transfer_protocol,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
  # The service transmits compressed elements; uncompress them back into the
  # caller-provided element_spec structure.
  dataset = dataset.map(
      lambda x: compression_ops.uncompress(x, output_spec=element_spec),
      num_parallel_calls=dataset_ops.AUTOTUNE)

  # Disable autosharding for shared jobs.
  if job_name:
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
  return dataset


def _distribute(processing_mode,
                service,
                job_name=None,
                consumer_index=None,
                num_consumers=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None,
                data_transfer_protocol=None):
  """A transformation that moves dataset processing to the tf.data service.

  This transformation is similar to `distribute`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have
      each tf.data worker process a copy of the dataset, or
      "distributed_epoch" to split a single iteration of the dataset across
      all the workers.
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "<protocol>://<address>", e.g.
      "grpc://localhost:5000".
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
      dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  ProcessingMode.validate(processing_mode)

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    dataset_id = register_dataset(service, dataset)
    return _from_dataset_id(
        processing_mode,
        service,
        dataset_id,
        dataset.element_spec,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        data_transfer_protocol=data_transfer_protocol)

  return _apply_fn


@tf_export("data.experimental.service.distribute")
def distribute(processing_mode,
               service,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               data_transfer_protocol=None):
  """A transformation that moves dataset processing to the tf.data service.

  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The tf.data service uses a cluster of workers to prepare data for training
  your model.
  The `processing_mode` argument to `tf.data.experimental.service.distribute`
  describes how to leverage multiple workers to process the input dataset.
  Currently, there are two processing modes to choose from: "distributed_epoch"
  and "parallel_epochs".

  "distributed_epoch" means that the dataset will be split across all tf.data
  service workers.
  The dispatcher produces "splits" for the dataset and sends them to workers for
  further processing. For example, if a dataset begins with a list of filenames,
  the dispatcher will iterate through the filenames and send the filenames to
  tf.data workers, which will perform the rest of the dataset transformations on
  those files. "distributed_epoch" is useful when your model needs to see each
  element of the dataset exactly once, or if it needs to see the data in a
  generally-sequential order. "distributed_epoch" only works for datasets with
  splittable sources, such as `Dataset.from_tensor_slices`,
  `Dataset.list_files`, or `Dataset.range`.

  "parallel_epochs" means that the entire input dataset will be processed
  independently by each of the tf.data service workers.
  For this reason, it is important to shuffle data (e.g. filenames)
  non-deterministically, so that each worker will process the elements of the
  dataset in a different order. "parallel_epochs" can be used to distribute
  datasets that aren't splittable.

  With two workers, "parallel_epochs" will produce every element of the dataset
  twice:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> # Start two workers
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]

  "distributed_epoch", on the other hand, will still produce each element once:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="distributed_epoch", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When using `apply(tf.data.experimental.service.distribute(...))`, the dataset
  before the `apply` transformation executes within the tf.data service, while
  the operations after `apply` happen within the local process.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(5)
  >>> dataset = dataset.map(lambda x: x*x)
  >>> dataset = dataset.apply(
  ...    tf.data.experimental.service.distribute("parallel_epochs",
  ...                                            dispatcher.target))
  >>> dataset = dataset.map(lambda x: x+1)
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [1, 1, 2, 2, 5, 5, 10, 10, 17, 17]

  In the above example, the dataset operations (before applying the `distribute`
  function on the elements) will be executed on the tf.data workers,
  and the elements are provided over RPC. The remaining transformations
  (after the call to `distribute`) will be executed locally. The dispatcher
  and the workers will bind to unused free ports (which are chosen at random),
  in order to communicate with each other. However, to bind them to specific
  ports, the `port` parameter can be passed.

  The `job_name` argument allows jobs to be shared across multiple
  datasets. Instead of each dataset creating its own job, all
  datasets with the same `job_name` will consume from the same job. A new job
  will be created for each iteration of the dataset (with each repetition of
  `Dataset.repeat` counting as a new iteration). Suppose the `DispatchServer`
  is serving on `localhost:5000` and two training workers (in either a single
  client or multi-client setup) iterate over the below dataset, and there is a
  single tf.data worker:

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "grpc://localhost:5000", job_name="my_job_name"))
  for iteration in range(3):
    print(list(dataset))
  ```

  The elements of each job will be split between the two processes, with
  elements being consumed by the processes on a first-come first-served basis.
  One possible result is that process 1 prints

  ```
  [0, 2, 4]
  [0, 1, 3]
  [1]
  ```

  and process 2 prints

  ```
  [1, 3]
  [2, 4]
  [0, 2, 3, 4]
  ```

  Job names must not be re-used across different training jobs within the
  lifetime of the tf.data service. In general, the tf.data service is expected
  to live for the duration of a single training job.
  To use the tf.data service with multiple training jobs, make sure to use
  different job names to avoid conflicts. For example, suppose a training job
  calls `distribute` with `job_name="job"` and reads until end of input. If
  another independent job connects to the same tf.data service and tries to read
  from `job_name="job"`, it will immediately receive end of input, without
  getting any data.

  **Round Robin data consumption**

  By default, when multiple consumers read from the same job, they receive data
  on a first-come first-served basis. In some use cases, it works better to use
  a strict round-robin order. For example, the tf.data service can be used to
  coordinate example sizes across a cluster during synchronous training, so that
  during each step all replicas train on similar-sized elements. To achieve
  this, define a dataset which generates rounds of `num_consumers` consecutive
  similar-sized batches, then enable round-robin reads by setting
  `consumer_index` and `num_consumers`.

  Consumers read data by cycling through all workers, reading one element from
  each. First, each consumer will read an element from the first worker, then
  each consumer will read an element from the second worker, and so on.

  NOTE: To keep consumers in sync, round robin data consumption requires that
  the dataset have infinite cardinality. You can get this by adding `.repeat()`
  at the end of the dataset definition.

  **Keras and Distribution Strategies**

  The dataset produced by the `distribute` transformation can be passed to
  Keras' `Model.fit` or Distribution Strategy's
  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
  `distribute` so that if there are multiple workers, they read data from the
  same job. Note that the autosharding normally performed by
  `experimental_distribute_dataset` will be disabled when setting a `job_name`,
  since sharing the job already results in splitting data across the workers.
  When using a shared job, data will be dynamically balanced across workers, so
  that they reach end of input about the same time. This results in better
  worker utilization than with autosharding, where each worker processes an
  independent set of files, and some workers may run out of data earlier than
  others.

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have
      each tf.data worker process a copy of the dataset, or
      "distributed_epoch" to split a single iteration of the dataset across
      all the workers.
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "protocol://address", e.g.
      "grpc://localhost:5000".
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service, e.g. "grpc".

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  return _distribute(
      processing_mode=processing_mode,
      service=service,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol)


@tf_export("data.experimental.service.register_dataset")
def register_dataset(service, dataset):
  """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "protocol://address", e.g.
      "grpc://localhost:5000".
    dataset: A `tf.data.Dataset` to register with the tf.data service.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  # Compress the dataset elements to reduce the amount of data that needs to
  # be sent over the network.
  dataset = dataset.map(
      lambda *x: compression_ops.compress(x),
      num_parallel_calls=dataset_ops.AUTOTUNE)
  dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
  # Apply options so that the dataset executed in the tf.data service will
  # be optimized and support autotuning.
  dataset = dataset._apply_options()  # pylint: disable=protected-access

  dataset_id = gen_experimental_dataset_ops.register_dataset(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value)

  return dataset_id


@tf_export("data.experimental.service.from_dataset_id")
def from_dataset_id(processing_mode,
                    service,
                    dataset_id,
                    element_spec=None,
                    job_name=None,
                    consumer_index=None,
                    num_consumers=None,
                    max_outstanding_requests=None):
  """Creates a dataset which reads data from the tf.data service.

  This is useful when the dataset is registered by one process, then used in
  another process. When the same process is both registering and reading from
  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
  instead.

  Before using `from_dataset_id`, the dataset must have been registered with the
  tf.data service using `tf.data.experimental.service.register_dataset`.
  `register_dataset` returns a dataset id for the registered dataset. That is
  the `dataset_id` which should be passed to `from_dataset_id`.

  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. Currently `element_spec` must be explicitly
  specified, and match the dataset registered under `dataset_id`. `element_spec`
  defaults to `None` so that in the future we can support automatically
  discovering the `element_spec` by querying the tf.data service.

  `tf.data.experimental.service.distribute` is a convenience method which
  combines `register_dataset` and `from_dataset_id` into a dataset
  transformation.
  See the documentation for `tf.data.experimental.service.distribute` for more
  detail about how `from_dataset_id` works.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have
      each tf.data worker process a copy of the dataset, or
      "distributed_epoch" to split a single iteration of the dataset across
      all the workers.
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "protocol://address", e.g.
      "grpc://localhost:5000".
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. Use `tf.data.Dataset.element_spec` to
      see the element spec for a given dataset.
    job_name: (Optional.) The name of the job. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  return _from_dataset_id(
      processing_mode=processing_mode,
      service=service,
      dataset_id=dataset_id,
      element_spec=element_spec,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests)