1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Experimental API for controlling distribution in `tf.data` pipelines."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import enum
21
22from tensorflow.core.framework import dataset_options_pb2
23from tensorflow.python.data.util import options
24from tensorflow.python.util.tf_export import tf_export
25
26
27@tf_export("data.experimental.AutoShardPolicy")
class AutoShardPolicy(enum.IntEnum):
  """Represents the type of auto-sharding we enable.

  See the `tf.data.experimental.DistributeOptions.auto_shard_policy`
  documentation for more information.

  Members:
    OFF: Do not auto-shard; each worker receives a copy of the full dataset.
    AUTO: Attempt to shard by FILE first and fall back to DATA if no set of
      files to shard can be found.
    FILE: Shard by files (each worker gets a set of files to process).
    DATA: Shard by the elements produced by the dataset.
  """
  OFF = -1
  AUTO = 0
  FILE = 1
  DATA = 2

  @classmethod
  def _to_proto(cls, obj):
    """Convert enum to proto.

    Args:
      obj: An `AutoShardPolicy` member, or a value that compares equal to one
        (members are ints, so a plain int such as `-1` also matches).

    Returns:
      The corresponding `dataset_options_pb2.AutoShardPolicy` value.

    Raises:
      ValueError: If `obj` does not compare equal to any member of this enum.
    """
    if obj == cls.OFF:
      return dataset_options_pb2.AutoShardPolicy.OFF
    if obj == cls.FILE:
      return dataset_options_pb2.AutoShardPolicy.FILE
    if obj == cls.DATA:
      return dataset_options_pb2.AutoShardPolicy.DATA
    if obj == cls.AUTO:
      return dataset_options_pb2.AutoShardPolicy.AUTO
    # Reaching this point means `obj` matched no member, so it may not be an
    # enum at all (e.g. an out-of-range int or a string) and may lack `.name`.
    # Fall back to the object itself so the intended ValueError is raised
    # rather than an AttributeError from building the message.
    raise ValueError("%s._to_proto() is called with undefined enum %s." %
                     (cls.__name__, getattr(obj, "name", obj)))

  @classmethod
  def _from_proto(cls, pb):
    """Convert proto to enum.

    Args:
      pb: A `dataset_options_pb2.AutoShardPolicy` value.

    Returns:
      The corresponding `AutoShardPolicy` member.

    Raises:
      ValueError: If `pb` does not compare equal to any of the proto values
        handled here.
    """
    if pb == dataset_options_pb2.AutoShardPolicy.OFF:
      return cls.OFF
    if pb == dataset_options_pb2.AutoShardPolicy.FILE:
      return cls.FILE
    if pb == dataset_options_pb2.AutoShardPolicy.DATA:
      return cls.DATA
    if pb == dataset_options_pb2.AutoShardPolicy.AUTO:
      return cls.AUTO
    raise ValueError("%s._from_proto() is called with undefined enum %s." %
                     (cls.__name__, pb))
66
67
68@tf_export("data.experimental.ExternalStatePolicy")
class ExternalStatePolicy(enum.Enum):
  """Represents how to handle external state during serialization.

  See the `tf.data.Options.experimental_external_state_policy` documentation
  for more information.
  """
  WARN = 0
  IGNORE = 1
  FAIL = 2

  @classmethod
  def _to_proto(cls, obj):
    """Convert enum to proto.

    Args:
      obj: An `ExternalStatePolicy` member. This is a plain `enum.Enum`, so
        raw ints do not compare equal to members and are rejected.

    Returns:
      The corresponding `dataset_options_pb2.ExternalStatePolicy` value.

    Raises:
      ValueError: If `obj` does not compare equal to any member of this enum.
    """
    if obj == cls.IGNORE:
      return dataset_options_pb2.ExternalStatePolicy.IGNORE
    if obj == cls.FAIL:
      return dataset_options_pb2.ExternalStatePolicy.FAIL
    if obj == cls.WARN:
      return dataset_options_pb2.ExternalStatePolicy.WARN
    # Reaching this point means `obj` matched no member, so it may not be an
    # enum at all (e.g. a raw int or a string) and may lack `.name`. Fall back
    # to the object itself so the intended ValueError is raised rather than an
    # AttributeError from building the message.
    raise ValueError("%s._to_proto() is called with undefined enum %s." %
                     (cls.__name__, getattr(obj, "name", obj)))

  @classmethod
  def _from_proto(cls, pb):
    """Convert proto to enum.

    Args:
      pb: A `dataset_options_pb2.ExternalStatePolicy` value.

    Returns:
      The corresponding `ExternalStatePolicy` member.

    Raises:
      ValueError: If `pb` does not compare equal to any of the proto values
        handled here.
    """
    if pb == dataset_options_pb2.ExternalStatePolicy.IGNORE:
      return cls.IGNORE
    if pb == dataset_options_pb2.ExternalStatePolicy.FAIL:
      return cls.FAIL
    if pb == dataset_options_pb2.ExternalStatePolicy.WARN:
      return cls.WARN
    raise ValueError("%s._from_proto() is called with undefined enum %s." %
                     (cls.__name__, pb))
102
103
104@tf_export("data.experimental.DistributeOptions")
class DistributeOptions(options.OptionsBase):
  """Options that control distributed processing of a `tf.data` pipeline.

  An instance of this class is available as the `experimental_distribute`
  property of `tf.data.Options`; assign to its attributes to configure how a
  dataset is distributed.

  ```python
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
  dataset = dataset.with_options(options)
  ```
  """

  # Sharding strategy used by auto-shard; defaults to `AutoShardPolicy.AUTO`.
  auto_shard_policy = options.create_option(
      name="auto_shard_policy",
      ty=AutoShardPolicy,
      docstring="The type of sharding that auto-shard should attempt. If this "
      "is set to FILE, then we will attempt to shard by files (each worker "
      "will get a set of files to process). If we cannot find a set of files "
      "to shard for at least one file per worker, we will error out. When this "
      "option is selected, make sure that you have enough files so that each "
      "worker gets at least one file. There will be a runtime error thrown if "
      "there are insufficient files. "
      "If this is set to DATA, then we will shard by elements produced by the "
      "dataset, and each worker will process the whole dataset and discard the "
      "portion that is not for itself. "
      "If this is set to OFF, then we will not autoshard, and each worker will "
      "receive a copy of the full dataset. "
      "This option is set to AUTO by default, AUTO will attempt to first shard "
      "by FILE, and fall back to sharding by DATA if we cannot find a set of "
      "files to shard.",
      default_factory=lambda: AutoShardPolicy.AUTO)

  # Device count for this input pipeline; no default factory is supplied.
  num_devices = options.create_option(
      name="num_devices",
      ty=int,
      docstring=
      "The number of devices attached to this input pipeline. This will be "
      "automatically set by MultiDeviceIterator.")

  def _to_proto(self):
    """Serializes these options into a `DistributeOptions` proto.

    Returns:
      A `dataset_options_pb2.DistributeOptions` message whose
      `auto_shard_policy` is always populated and whose `num_devices` is
      populated only when this object's `num_devices` is not `None`.
    """
    proto = dataset_options_pb2.DistributeOptions()
    proto.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy)  # pylint: disable=protected-access
    device_count = self.num_devices
    # Leave the proto field untouched when no device count is configured.
    if device_count is not None:
      proto.num_devices = device_count
    return proto

  def _from_proto(self, pb):
    """Populates these options in place from a `DistributeOptions` proto.

    Args:
      pb: A `dataset_options_pb2.DistributeOptions` message to read from.
    """
    self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy)  # pylint: disable=protected-access
    # Only copy `num_devices` when the `optional_num_devices` oneof reports
    # that a field has been set on the message.
    num_devices_is_set = pb.WhichOneof("optional_num_devices") is not None
    if num_devices_is_set:
      self.num_devices = pb.num_devices
157