1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""An Optional type for representing potentially missing values."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21
22import six
23
24from tensorflow.python.data.util import structure
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.ops import gen_dataset_ops
29from tensorflow.python.util.tf_export import tf_export
30
31
@six.add_metaclass(abc.ABCMeta)
class Optional(object):
  """A nested structure of tensors whose presence is only known at runtime.

  An `Optional` lets an operation that might fail report that outcome as a
  value, instead of raising an exception and halting execution. As an example,
  `tf.data.experimental.get_next_as_optional` returns an `Optional` that holds
  the next element of a `tf.data.Iterator` when there is one, or a "none"
  value once the end of the sequence has been reached.
  """

  @abc.abstractmethod
  def has_value(self, name=None):
    """Builds a tensor that evaluates to `True` iff this optional has a value.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.bool`.
    """
    raise NotImplementedError("Optional.has_value()")

  @abc.abstractmethod
  def get_value(self, name=None):
    """Builds the nested structure of values wrapped by this optional.

    When this optional is empty (i.e. `self.has_value()` evaluates to
    `False`), this operation raises `tf.errors.InvalidArgumentError` at
    runtime.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.
    """
    raise NotImplementedError("Optional.get_value()")

  @abc.abstractproperty
  def value_structure(self):
    """The `Structure` describing the components of this optional.

    Returns:
      A `Structure` object representing the structure of the components of
        this optional.
    """
    raise NotImplementedError("Optional.value_structure")

  @staticmethod
  def from_value(value):
    """Wraps `value` in an `Optional` that has a value.

    Args:
      value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

    Returns:
      An `Optional` that wraps `value`.
    """
    with ops.name_scope("optional") as optional_scope:
      with ops.name_scope("value"):
        # Describe `value`, then flatten it into the tensor list consumed by
        # the op that builds the optional.
        structure_of_value = structure.Structure.from_value(value)
        # pylint: disable=protected-access
        flat_tensors = structure_of_value._to_tensor_list(value)
        # pylint: enable=protected-access

    # Built after leaving both scopes, using the captured "optional" scope
    # string as the op name.
    variant = gen_dataset_ops.optional_from_value(
        flat_tensors, name=optional_scope)
    return _OptionalImpl(variant, structure_of_value)

  @staticmethod
  def none_from_structure(value_structure):
    """Builds an `Optional` that holds no value.

    NOTE: `value_structure` describes the value that the returned `Optional`
    would contain if it had one. It is recorded on the returned object even
    though no value is present.

    Args:
      value_structure: A `Structure` object representing the structure of the
        components of this optional.

    Returns:
      An `Optional` that has no value.
    """
    empty_variant = gen_dataset_ops.optional_none()
    return _OptionalImpl(empty_variant, value_structure)
115
116
class _OptionalImpl(Optional):
  """The concrete implementation behind `tf.data.experimental.Optional`.

  NOTE(mrry): This class is private so that `Optional.__init__()` does not
  become part of the public API.
  """

  def __init__(self, variant_tensor, value_structure):
    # Tensor encoding the (possibly absent) value; it is what gets handed to
    # `gen_dataset_ops.optional_has_value` / `optional_get_value`.
    self._variant_tensor = variant_tensor
    # `Structure` used to type the decoded components and to rebuild the
    # nested value from them.
    self._value_structure = value_structure

  def has_value(self, name=None):
    variant = self._variant_tensor
    return gen_dataset_ops.optional_has_value(variant, name=name)

  def get_value(self, name=None):
    # TODO(b/110122868): Consolidate the restructuring logic with similar logic
    # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
    variant = self._variant_tensor
    value_structure = self._value_structure
    with ops.name_scope(name, "OptionalGetValue", [variant]) as scope:
      # pylint: disable=protected-access
      flat_components = gen_dataset_ops.optional_get_value(
          variant,
          name=scope,
          output_types=value_structure._flat_types,
          output_shapes=value_structure._flat_shapes)
      return value_structure._from_tensor_list(flat_components)

  @property
  def value_structure(self):
    return self._value_structure
147
148
@tf_export("data.experimental.OptionalStructure")
class OptionalStructure(structure.Structure):
  """Describes an optional that may contain a structured value.

  The optional itself flattens to a single scalar `tf.variant` component. The
  structure of the value it may contain is kept in `self._value_structure`.
  """

  def __init__(self, value_structure):
    # `Structure` of the value that the described optional may contain.
    self._value_structure = value_structure

  @property
  def _flat_shapes(self):
    # A single scalar component: the variant encoding the optional.
    scalar_shape = tensor_shape.scalar()
    return [scalar_shape]

  @property
  def _flat_types(self):
    variant_dtype = dtypes.variant
    return [variant_dtype]

  def is_compatible_with(self, other):
    """Returns whether `other` describes an optional with a compatible value."""
    if not isinstance(other, OptionalStructure):
      return False
    # pylint: disable=protected-access
    return self._value_structure.is_compatible_with(other._value_structure)

  def _to_tensor_list(self, value):
    # pylint: disable=protected-access
    variant_tensor = value._variant_tensor
    return [variant_tensor]

  def _to_batched_tensor_list(self, value):
    del value  # Unused.
    raise NotImplementedError(
        "Unbatching for `tf.data.experimental.Optional` objects.")

  def _from_tensor_list(self, flat_value):
    # Accept only exactly one tensor, of dtype variant, with a shape that is
    # compatible with a scalar.
    is_malformed = (
        len(flat_value) != 1 or
        flat_value[0].dtype != dtypes.variant or
        not flat_value[0].shape.is_compatible_with(tensor_shape.scalar()))
    if is_malformed:
      raise ValueError(
          "OptionalStructure corresponds to a single tf.variant scalar.")
    return self._from_compatible_tensor_list(flat_value)

  def _from_compatible_tensor_list(self, flat_value):
    # Only the first (and expected sole) component is used.
    variant_tensor = flat_value[0]
    return _OptionalImpl(variant_tensor, self._value_structure)

  @staticmethod
  def from_value(value):
    # Builds the structure for an `Optional` from its `value_structure`.
    return OptionalStructure(value.value_structure)

  # The legacy output types, shapes and classes of an optional are all
  # represented by this structure object itself.
  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self

  def _batch(self, batch_size):
    del batch_size  # Unused.
    raise NotImplementedError(
        "Batching for `tf.data.experimental.Optional` objects.")

  def _unbatch(self):
    raise NotImplementedError(
        "Unbatching for `tf.data.experimental.Optional` objects.")
207
208
# Register `OptionalStructure.from_value` as the custom converter for
# `Optional` instances (presumably consulted by `structure.Structure.from_value`
# when it is handed an `Optional`; confirm against structure.py).
structure.Structure._register_custom_converter(  # pylint: disable=protected-access
    Optional, OptionalStructure.from_value)
213