1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_FRAMEWORK_VARIANT_H_ 17 #define TENSORFLOW_FRAMEWORK_VARIANT_H_ 18 19 #include <functional> 20 #include <iostream> 21 #include <memory> 22 #include <type_traits> 23 #include <unordered_map> 24 #include <utility> 25 26 #include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove 27 #include "tensorflow/core/framework/type_index.h" 28 #include "tensorflow/core/framework/variant_tensor_data.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/strings/strcat.h" 31 #include "tensorflow/core/platform/mutex.h" 32 33 namespace tensorflow { 34 35 template <typename T> 36 string TypeNameVariant(const T& value); 37 38 template <typename T> 39 string DebugStringVariant(const T& value); 40 41 template <typename T> 42 void EncodeVariant(const T& value, VariantTensorData* data); 43 44 template <typename T> 45 bool DecodeVariant(const VariantTensorData& data, T* value); 46 47 template <typename T> 48 void EncodeVariant(const T& value, string* buf); 49 50 template <typename T> 51 bool DecodeVariant(const string& buf, T* value); 52 53 // This is an implementation of a type-erased container that can store an 54 // object of any type. 
// The implementation is very similar to std::any, but has
// restrictions on the types of objects that can be stored, and eschews some of
// the fancier constructors available for std::any. An object of
// tensorflow::Variant is intended to be used as the value that will be stored
// in a tensorflow::Tensor object when its type is DT_VARIANT.
//
// tensorflow::Variant can store an object of a class that satisfies the
// following constraints:
//
// * The class is CopyConstructible.
// * The class has a default constructor.
// * It's either a protocol buffer, a tensorflow::Tensor, or defines the
//   following functions:
//
//   string TypeName() const;
//   void Encode(VariantTensorData* data) const;
//   void Decode(const VariantTensorData& data);
//
// Simple POD types can omit the Encode/Decode functions; helper methods
// provide them.
//
// Here are some typical usage patterns:
//
//   Variant x = 10;
//   EXPECT_EQ(*x.get<int>(), 10);
//
//   Tensor t(DT_FLOAT, TensorShape({}));
//   t.flat<float>()(0) = 42.0f;
//   Variant x = t;
//   EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f);
//
// Accessing the stored object:
//
// The get<T> function is the main mechanism to access the object
// stored in the container. It is type-safe: calling get<T> when the stored
// object's type is not T returns nullptr. A raw pointer to the stored object
// can be obtained by calling get<void>().
//
// Serializing/deserializing Variant object:
//
// The Variant class delegates serializing and deserializing operations to the
// contained object. Helper functions to do these operations are provided for
// POD data types, tensorflow::Tensor, and protocol buffer objects. However,
// other classes have to provide Encode/Decode functions to handle
// serialization.
99 // 100 // Objects stored in a Variant object often contain references to other 101 // tensorflow::Tensors of primitive types (Eg., a list of tensorflow::Tensors). 102 // To efficiently support those use cases, a structure is imposed on the 103 // serialization format. Namely, classes should serialize their contents into a 104 // VariantTensorData object: 105 // 106 // struct VariantTensorData { 107 // string type_name; 108 // string metadata; 109 // std::vector<Tensor> tensors; 110 // }; 111 // 112 // Objects with references to other Tensors can simply store those tensors in 113 // the `tensors` field, and serialize other metadata content in to the 114 // `metadata` field. 115 // 116 // Serialization example: 117 // 118 // Foo f = Foo {...}; 119 // Variant x = f; 120 // string serialized_f; 121 // x.Encode(&serialized_f); 122 // 123 // Variant y = Foo(); // default constructed Foo. 124 // y.Decode(&serialized_f); 125 // EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>()); 126 // 127 // 128 // A Variant storing serialized Variant data (a value of type 129 // VariantTensorDataProto) has different behavior from a standard Variant. 130 // Namely, its TypeName matches the TypeName of the original Variant; 131 // and its non-const get method performs lazy deserialization. 132 // 133 // Decode and copy example: 134 // 135 // Foo f = Foo {...}; 136 // Variant x = f; 137 // 138 // VariantTensorData serialized_data_f; 139 // VariantTensorDataProto serialized_proto_f; 140 // x.Encode(&serialized_data_f); 141 // serialized_data_f.ToProto(&serialized_proto_f); 142 // 143 // Variant y_type_unknown = serialized_proto_f; // Store serialized Variant. 144 // 145 // EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo. 146 // EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(), 147 // y_type_unknown.TypeId()); 148 // // Decode and get y_type_unknown; compare to value in x. 
149 // Foo f_decoded; 150 // EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded)); 151 // EXPECT_EQ(f_decoded, f); 152 // 153 class Variant { 154 public: 155 constexpr Variant() noexcept = default; 156 Variant(const Variant & other)157 Variant(const Variant& other) 158 : value_(other.is_empty() ? std::unique_ptr<ValueInterface>() 159 : other.value_->Clone()) {} 160 161 Variant(Variant&& other) noexcept = default; 162 163 // Make sure that the type is CopyConstructible and not a tensorflow::Variant 164 // object itself. We want the copy constructor to be chosen for the 165 // tensorflow::Variant case. 166 template <typename T, typename VT = typename std::decay<T>::type, 167 typename std::enable_if<!std::is_same<Variant, VT>::value && 168 std::is_copy_constructible<VT>::value, 169 void>::type* = nullptr> Variant(T && value)170 Variant(T&& value) // NOLINT 171 : value_(new Value<VT>(in_place, std::forward<T>(value))) {} 172 173 Variant& operator=(const Variant& rhs) { 174 Variant(rhs).swap(*this); 175 return *this; 176 } 177 178 Variant& operator=(Variant&& rhs) noexcept { 179 Variant(std::move(rhs)).swap(*this); 180 return *this; 181 } 182 is_empty()183 bool is_empty() const { return value_ == nullptr; } 184 clear()185 void clear() noexcept { value_.reset(); } 186 swap(Variant & other)187 void swap(Variant& other) noexcept { value_.swap(other.value_); } 188 189 // Note, unlike TypeName(), TypeId() does not return the TypeIndex 190 // of the original type when a TensorValueDataProto is stored as the 191 // value. In this case, it returns the TypeIndex of TensorValueDataProto. 
TypeId()192 TypeIndex TypeId() const { 193 const TypeIndex VoidTypeIndex = MakeTypeIndex<void>(); 194 if (is_empty()) { 195 return VoidTypeIndex; 196 } 197 return value_->TypeId(); 198 } 199 DebugString()200 string DebugString() const { 201 return strings::StrCat("Variant<type: ", TypeName(), 202 " value: ", value_->DebugString(), ">"); 203 } 204 205 // Returns a pointer to the stored value if it is type T, or nullptr 206 // otherwise. 207 template <typename T> get()208 T* get() { 209 const TypeIndex TTypeIndex = MakeTypeIndex<T>(); 210 if (is_empty() || (TTypeIndex != TypeId())) return nullptr; 211 return std::addressof(static_cast<Variant::Value<T>*>(value_.get())->value); 212 } 213 214 // Returns a pointer to the stored value if it is type T, or nullptr 215 // otherwise. 216 template <typename T> get()217 const T* get() const { 218 const TypeIndex TTypeIndex = MakeTypeIndex<T>(); 219 if (is_empty() || (TTypeIndex != TypeId())) return nullptr; 220 return std::addressof( 221 static_cast<const Variant::Value<T>*>(value_.get())->value); 222 } 223 224 // Returns TypeNameVariant(value). 225 // 226 // In the special case that a serialized Variant is stored (value 227 // is a VariantTensorDataProto), returns value.TypeName(), the 228 // TypeName field stored in the VariantTensorDataProto buffer. TypeName()229 string TypeName() const { 230 if (is_empty()) { 231 return ""; 232 } 233 return value_->TypeName(); 234 } 235 236 // Serialize the contents of the stored object into `data`. Encode(VariantTensorData * data)237 void Encode(VariantTensorData* data) const { 238 if (!is_empty()) { 239 value_->Encode(data); 240 } 241 } 242 243 // Deserialize `data` and update the stored object. Decode(const VariantTensorData & data)244 bool Decode(const VariantTensorData& data) { 245 if (!is_empty()) { 246 return value_->Decode(data); 247 } 248 return true; 249 } 250 251 // Helper methods to directly serialize/deserialize from strings. 
Encode(string * buf)252 void Encode(string* buf) const { 253 if (!is_empty()) { 254 value_->Encode(buf); 255 } 256 } Decode(const string & buf)257 bool Decode(const string& buf) { 258 if (!is_empty()) { 259 return value_->Decode(buf); 260 } 261 return true; 262 } 263 264 template <typename T> MaybeDecodeAndCopy(T * out)265 bool MaybeDecodeAndCopy(T* out) const { 266 const T* ret = get<T>(); 267 if (ret != nullptr) { 268 *out = std::move(*ret); 269 return true; 270 }; 271 Variant decoded = T(); 272 if (!TryDecode(&decoded)) return false; 273 T* decoded_ret = decoded.get<T>(); 274 CHECK_NOTNULL(decoded_ret); 275 *out = std::move(*decoded_ret); 276 return true; 277 } 278 279 private: 280 bool TryDecode(Variant* out) const; 281 282 private: 283 struct in_place_t {}; 284 static constexpr in_place_t in_place{}; 285 286 struct ValueInterface { 287 virtual ~ValueInterface() = default; 288 virtual TypeIndex TypeId() const = 0; 289 virtual void* RawPtr() = 0; 290 virtual const void* RawPtr() const = 0; 291 virtual std::unique_ptr<ValueInterface> Clone() const = 0; 292 virtual string TypeName() const = 0; 293 virtual string DebugString() const = 0; 294 virtual void Encode(VariantTensorData* data) const = 0; 295 virtual bool Decode(const VariantTensorData& data) = 0; 296 virtual void Encode(string* buf) const = 0; 297 virtual bool Decode(const string& data) = 0; 298 }; 299 300 template <typename T> 301 struct Value : ValueInterface { 302 template <class... Args> ValueValue303 explicit Value(in_place_t /*tag*/, Args&&... args) 304 : value(std::forward<Args>(args)...) 
{} 305 TypeIdValue306 TypeIndex TypeId() const override { 307 const TypeIndex value_type_index = 308 MakeTypeIndex<typename std::decay<T>::type>(); 309 return value_type_index; 310 } 311 RawPtrValue312 void* RawPtr() override { return &value; } 313 RawPtrValue314 const void* RawPtr() const override { return &value; } 315 CloneValue316 std::unique_ptr<ValueInterface> Clone() const override { 317 return std::unique_ptr<ValueInterface>(new Value(in_place, value)); 318 } 319 TypeNameValue320 string TypeName() const override { return TypeNameVariant(value); } 321 DebugStringValue322 string DebugString() const override { return DebugStringVariant(value); } 323 EncodeValue324 void Encode(VariantTensorData* data) const override { 325 EncodeVariant(value, data); 326 } 327 DecodeValue328 bool Decode(const VariantTensorData& data) override { 329 return DecodeVariant(data, &value); 330 } 331 EncodeValue332 void Encode(string* buf) const override { EncodeVariant(value, buf); } 333 DecodeValue334 bool Decode(const string& buf) override { 335 return DecodeVariant(buf, &value); 336 } 337 338 T value; 339 }; 340 341 // value_ can point to any type T as wrapped by a ValueInterface. 342 // The only real requirement is that T is default-constructible. 343 std::unique_ptr<ValueInterface> value_; 344 }; 345 346 template <> 347 void* Variant::get(); 348 349 template <> 350 const void* Variant::get() const; 351 352 } // end namespace tensorflow 353 354 #endif // TENSORFLOW_FRAMEWORK_VARIANT_H_ 355