1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""An Optional type for representing potentially missing values.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21 22import six 23 24from tensorflow.python.data.util import structure 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.ops import gen_dataset_ops 29from tensorflow.python.util.tf_export import tf_export 30 31 32@six.add_metaclass(abc.ABCMeta) 33class Optional(object): 34 """Wraps a nested structure of tensors that may/may not be present at runtime. 35 36 An `Optional` can represent the result of an operation that may fail as a 37 value, rather than raising an exception and halting execution. For example, 38 `tf.data.experimental.get_next_as_optional` returns an `Optional` that either 39 contains the next value from a `tf.data.Iterator` if one exists, or a "none" 40 value that indicates the end of the sequence has been reached. 41 """ 42 43 @abc.abstractmethod 44 def has_value(self, name=None): 45 """Returns a tensor that evaluates to `True` if this optional has a value. 46 47 Args: 48 name: (Optional.) A name for the created operation. 49 50 Returns: 51 A scalar `tf.Tensor` of type `tf.bool`. 52 """ 53 raise NotImplementedError("Optional.has_value()") 54 55 @abc.abstractmethod 56 def get_value(self, name=None): 57 """Returns a nested structure of values wrapped by this optional. 58 59 If this optional does not have a value (i.e. `self.has_value()` evaluates 60 to `False`), this operation will raise `tf.errors.InvalidArgumentError` 61 at runtime. 62 63 Args: 64 name: (Optional.) A name for the created operation. 65 66 Returns: 67 A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects. 68 """ 69 raise NotImplementedError("Optional.get_value()") 70 71 @abc.abstractproperty 72 def value_structure(self): 73 """The structure of the components of this optional. 74 75 Returns: 76 A `Structure` object representing the structure of the components of this 77 optional. 78 """ 79 raise NotImplementedError("Optional.value_structure") 80 81 @staticmethod 82 def from_value(value): 83 """Returns an `Optional` that wraps the given value. 84 85 Args: 86 value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects. 87 88 Returns: 89 An `Optional` that wraps `value`. 90 """ 91 with ops.name_scope("optional") as scope: 92 with ops.name_scope("value"): 93 value_structure = structure.Structure.from_value(value) 94 encoded_value = value_structure._to_tensor_list(value) # pylint: disable=protected-access 95 96 return _OptionalImpl( 97 gen_dataset_ops.optional_from_value(encoded_value, name=scope), 98 value_structure) 99 100 @staticmethod 101 def none_from_structure(value_structure): 102 """Returns an `Optional` that has no value. 103 104 NOTE: This method takes an argument that defines the structure of the value 105 that would be contained in the returned `Optional` if it had a value. 106 107 Args: 108 value_structure: A `Structure` object representing the structure of the 109 components of this optional. 110 111 Returns: 112 An `Optional` that has no value. 113 """ 114 return _OptionalImpl(gen_dataset_ops.optional_none(), value_structure) 115 116 117class _OptionalImpl(Optional): 118 """Concrete implementation of `tf.data.experimental.Optional`. 119 120 NOTE(mrry): This implementation is kept private, to avoid defining 121 `Optional.__init__()` in the public API. 122 """ 123 124 def __init__(self, variant_tensor, value_structure): 125 self._variant_tensor = variant_tensor 126 self._value_structure = value_structure 127 128 def has_value(self, name=None): 129 return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name) 130 131 def get_value(self, name=None): 132 # TODO(b/110122868): Consolidate the restructuring logic with similar logic 133 # in `Iterator.get_next()` and `StructuredFunctionWrapper`. 134 with ops.name_scope(name, "OptionalGetValue", 135 [self._variant_tensor]) as scope: 136 # pylint: disable=protected-access 137 return self._value_structure._from_tensor_list( 138 gen_dataset_ops.optional_get_value( 139 self._variant_tensor, 140 name=scope, 141 output_types=self._value_structure._flat_types, 142 output_shapes=self._value_structure._flat_shapes)) 143 144 @property 145 def value_structure(self): 146 return self._value_structure 147 148 149@tf_export("data.experimental.OptionalStructure") 150class OptionalStructure(structure.Structure): 151 """Represents an optional potentially containing a structured value.""" 152 153 def __init__(self, value_structure): 154 self._value_structure = value_structure 155 156 @property 157 def _flat_shapes(self): 158 return [tensor_shape.scalar()] 159 160 @property 161 def _flat_types(self): 162 return [dtypes.variant] 163 164 def is_compatible_with(self, other): 165 # pylint: disable=protected-access 166 return (isinstance(other, OptionalStructure) and 167 self._value_structure.is_compatible_with(other._value_structure)) 168 169 def _to_tensor_list(self, value): 170 return [value._variant_tensor] # pylint: disable=protected-access 171 172 def _to_batched_tensor_list(self, value): 173 raise NotImplementedError( 174 "Unbatching for `tf.data.experimental.Optional` objects.") 175 176 def _from_tensor_list(self, flat_value): 177 if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or 178 not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())): 179 raise ValueError( 180 "OptionalStructure corresponds to a single tf.variant scalar.") 181 return self._from_compatible_tensor_list(flat_value) 182 183 def _from_compatible_tensor_list(self, flat_value): 184 # pylint: disable=protected-access 185 return _OptionalImpl(flat_value[0], self._value_structure) 186 187 @staticmethod 188 def from_value(value): 189 return OptionalStructure(value.value_structure) 190 191 def _to_legacy_output_types(self): 192 return self 193 194 def _to_legacy_output_shapes(self): 195 return self 196 197 def _to_legacy_output_classes(self): 198 return self 199 200 def _batch(self, batch_size): 201 raise NotImplementedError( 202 "Batching for `tf.data.experimental.Optional` objects.") 203 204 def _unbatch(self): 205 raise NotImplementedError( 206 "Unbatching for `tf.data.experimental.Optional` objects.") 207 208 209# pylint: disable=protected-access 210structure.Structure._register_custom_converter(Optional, 211 OptionalStructure.from_value) 212# pylint: enable=protected-access 213