1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_FRAMEWORK_VARIANT_H_
17 #define TENSORFLOW_FRAMEWORK_VARIANT_H_
18 
19 #include <functional>
20 #include <iostream>
21 #include <memory>
22 #include <type_traits>
23 #include <unordered_map>
24 #include <utility>
25 
26 #include "tensorflow/core/framework/tensor.pb.h"  // TODO(b/62899350): Remove
27 #include "tensorflow/core/framework/type_index.h"
28 #include "tensorflow/core/framework/variant_tensor_data.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/mutex.h"
32 
33 namespace tensorflow {
34 
35 template <typename T>
36 string TypeNameVariant(const T& value);
37 
38 template <typename T>
39 string DebugStringVariant(const T& value);
40 
41 template <typename T>
42 void EncodeVariant(const T& value, VariantTensorData* data);
43 
44 template <typename T>
45 bool DecodeVariant(const VariantTensorData& data, T* value);
46 
47 template <typename T>
48 void EncodeVariant(const T& value, string* buf);
49 
50 template <typename T>
51 bool DecodeVariant(const string& buf, T* value);
52 
53 // This is an implementation of a type-erased container that can store an
54 // object of any type. The implementation is very similar to std::any, but has
55 // restrictions on the types of objects that can be stored, and eschews some of
56 // the fancier constructors available for std::any. An object of
57 // tensorflow::Variant is intended to be used as the value that will be stored
58 // in a tensorflow::Tensor object when its type is DT_VARIANT.
59 //
60 // tensorflow::Variant can store an object of a class that satisfies the
61 // following constraints:
62 //
63 // * The class is CopyConstructible.
64 // * The class has a default constructor.
65 // * It's either a protocol buffer, a tensorflow::Tensor, or defines the
66 // following functions:
67 //
68 //   string TypeName() const;
69 //   void Encode(VariantTensorData* data) const;
70 //   void Decode(const VariantTensorData& data);
71 //
// Simple POD types may omit the Encode/Decode functions; helper functions
// provide them.
74 // Here are some typical usage patterns:
75 //
76 //   Variant x = 10;
77 //   EXPECT_EQ(*x.get<int>(), 10);
78 //
79 //   Tensor t(DT_FLOAT, TensorShape({}));
80 //   t.flat<float>()(0) = 42.0f;
81 //   Variant x = t;
82 //   EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f);
83 //
84 // Accessing the stored object:
85 //
86 // The get<T> function is the main mechanism to access the object
87 // stored in the container. It is type-safe, that is, calling
88 // get<T> when the stored object's type is not T, returns a
89 // nullptr. A raw pointer to the stored object can be obtained by calling
90 // get<void>().
91 //
92 // Serializing/deserializing Variant object:
93 //
94 // The Variant class delegates serializing and deserializing operations to the
95 // contained object. Helper functions to do these operations are provided for
96 // POD data types, tensorflow::Tensor, and protocol buffer objects. However,
97 // other classes have to provide Encode/Decode functions to handle
98 // serialization.
99 //
// Objects stored in a Variant object often contain references to other
// tensorflow::Tensors of primitive types (e.g., a list of tensorflow::Tensors).
102 // To efficiently support those use cases, a structure is imposed on the
103 // serialization format. Namely, classes should serialize their contents into a
104 // VariantTensorData object:
105 //
106 //   struct VariantTensorData {
107 //     string type_name;
108 //     string metadata;
109 //     std::vector<Tensor> tensors;
110 //   };
111 //
// Objects with references to other Tensors can simply store those tensors in
// the `tensors` field, and serialize other metadata content into the
// `metadata` field.
115 //
116 // Serialization example:
117 //
118 //   Foo f = Foo {...};
119 //   Variant x = f;
120 //   string serialized_f;
121 //   x.Encode(&serialized_f);
122 //
123 //   Variant y = Foo(); // default constructed Foo.
//   y.Decode(serialized_f);
125 //   EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
126 //
127 //
128 // A Variant storing serialized Variant data (a value of type
129 // VariantTensorDataProto) has different behavior from a standard Variant.
130 // Namely, its TypeName matches the TypeName of the original Variant;
131 // and its non-const get method performs lazy deserialization.
132 //
133 // Decode and copy example:
134 //
135 //   Foo f = Foo {...};
136 //   Variant x = f;
137 //
138 //   VariantTensorData serialized_data_f;
139 //   VariantTensorDataProto serialized_proto_f;
140 //   x.Encode(&serialized_data_f);
141 //   serialized_data_f.ToProto(&serialized_proto_f);
142 //
143 //   Variant y_type_unknown = serialized_proto_f;  // Store serialized Variant.
144 //
145 //   EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName());  // Looks like Foo.
146 //   EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(),
147 //             y_type_unknown.TypeId());
148 //   // Decode and get y_type_unknown; compare to value in x.
149 //   Foo f_decoded;
//   EXPECT_TRUE(y_type_unknown.MaybeDecodeAndCopy(&f_decoded));
151 //   EXPECT_EQ(f_decoded, f);
152 //
153 class Variant {
154  public:
155   constexpr Variant() noexcept = default;
156 
Variant(const Variant & other)157   Variant(const Variant& other)
158       : value_(other.is_empty() ? std::unique_ptr<ValueInterface>()
159                                 : other.value_->Clone()) {}
160 
161   Variant(Variant&& other) noexcept = default;
162 
163   // Make sure that the type is CopyConstructible and not a tensorflow::Variant
164   // object itself. We want the copy constructor to be chosen for the
165   // tensorflow::Variant case.
166   template <typename T, typename VT = typename std::decay<T>::type,
167             typename std::enable_if<!std::is_same<Variant, VT>::value &&
168                                         std::is_copy_constructible<VT>::value,
169                                     void>::type* = nullptr>
Variant(T && value)170   Variant(T&& value)  // NOLINT
171       : value_(new Value<VT>(in_place, std::forward<T>(value))) {}
172 
173   Variant& operator=(const Variant& rhs) {
174     Variant(rhs).swap(*this);
175     return *this;
176   }
177 
178   Variant& operator=(Variant&& rhs) noexcept {
179     Variant(std::move(rhs)).swap(*this);
180     return *this;
181   }
182 
is_empty()183   bool is_empty() const { return value_ == nullptr; }
184 
clear()185   void clear() noexcept { value_.reset(); }
186 
swap(Variant & other)187   void swap(Variant& other) noexcept { value_.swap(other.value_); }
188 
189   // Note, unlike TypeName(), TypeId() does not return the TypeIndex
190   // of the original type when a TensorValueDataProto is stored as the
191   // value.  In this case, it returns the TypeIndex of TensorValueDataProto.
TypeId()192   TypeIndex TypeId() const {
193     const TypeIndex VoidTypeIndex = MakeTypeIndex<void>();
194     if (is_empty()) {
195       return VoidTypeIndex;
196     }
197     return value_->TypeId();
198   }
199 
DebugString()200   string DebugString() const {
201     return strings::StrCat("Variant<type: ", TypeName(),
202                            " value: ", value_->DebugString(), ">");
203   }
204 
205   // Returns a pointer to the stored value if it is type T, or nullptr
206   // otherwise.
207   template <typename T>
get()208   T* get() {
209     const TypeIndex TTypeIndex = MakeTypeIndex<T>();
210     if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
211     return std::addressof(static_cast<Variant::Value<T>*>(value_.get())->value);
212   }
213 
214   // Returns a pointer to the stored value if it is type T, or nullptr
215   // otherwise.
216   template <typename T>
get()217   const T* get() const {
218     const TypeIndex TTypeIndex = MakeTypeIndex<T>();
219     if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
220     return std::addressof(
221         static_cast<const Variant::Value<T>*>(value_.get())->value);
222   }
223 
224   // Returns TypeNameVariant(value).
225   //
226   // In the special case that a serialized Variant is stored (value
227   // is a VariantTensorDataProto), returns value.TypeName(), the
228   // TypeName field stored in the VariantTensorDataProto buffer.
TypeName()229   string TypeName() const {
230     if (is_empty()) {
231       return "";
232     }
233     return value_->TypeName();
234   }
235 
236   // Serialize the contents of the stored object into `data`.
Encode(VariantTensorData * data)237   void Encode(VariantTensorData* data) const {
238     if (!is_empty()) {
239       value_->Encode(data);
240     }
241   }
242 
243   // Deserialize `data` and update the stored object.
Decode(const VariantTensorData & data)244   bool Decode(const VariantTensorData& data) {
245     if (!is_empty()) {
246       return value_->Decode(data);
247     }
248     return true;
249   }
250 
251   // Helper methods to directly serialize/deserialize from strings.
Encode(string * buf)252   void Encode(string* buf) const {
253     if (!is_empty()) {
254       value_->Encode(buf);
255     }
256   }
Decode(const string & buf)257   bool Decode(const string& buf) {
258     if (!is_empty()) {
259       return value_->Decode(buf);
260     }
261     return true;
262   }
263 
264   template <typename T>
MaybeDecodeAndCopy(T * out)265   bool MaybeDecodeAndCopy(T* out) const {
266     const T* ret = get<T>();
267     if (ret != nullptr) {
268       *out = std::move(*ret);
269       return true;
270     };
271     Variant decoded = T();
272     if (!TryDecode(&decoded)) return false;
273     T* decoded_ret = decoded.get<T>();
274     CHECK_NOTNULL(decoded_ret);
275     *out = std::move(*decoded_ret);
276     return true;
277   }
278 
279  private:
280   bool TryDecode(Variant* out) const;
281 
282  private:
283   struct in_place_t {};
284   static constexpr in_place_t in_place{};
285 
286   struct ValueInterface {
287     virtual ~ValueInterface() = default;
288     virtual TypeIndex TypeId() const = 0;
289     virtual void* RawPtr() = 0;
290     virtual const void* RawPtr() const = 0;
291     virtual std::unique_ptr<ValueInterface> Clone() const = 0;
292     virtual string TypeName() const = 0;
293     virtual string DebugString() const = 0;
294     virtual void Encode(VariantTensorData* data) const = 0;
295     virtual bool Decode(const VariantTensorData& data) = 0;
296     virtual void Encode(string* buf) const = 0;
297     virtual bool Decode(const string& data) = 0;
298   };
299 
300   template <typename T>
301   struct Value : ValueInterface {
302     template <class... Args>
ValueValue303     explicit Value(in_place_t /*tag*/, Args&&... args)
304         : value(std::forward<Args>(args)...) {}
305 
TypeIdValue306     TypeIndex TypeId() const override {
307       const TypeIndex value_type_index =
308           MakeTypeIndex<typename std::decay<T>::type>();
309       return value_type_index;
310     }
311 
RawPtrValue312     void* RawPtr() override { return &value; }
313 
RawPtrValue314     const void* RawPtr() const override { return &value; }
315 
CloneValue316     std::unique_ptr<ValueInterface> Clone() const override {
317       return std::unique_ptr<ValueInterface>(new Value(in_place, value));
318     }
319 
TypeNameValue320     string TypeName() const override { return TypeNameVariant(value); }
321 
DebugStringValue322     string DebugString() const override { return DebugStringVariant(value); }
323 
EncodeValue324     void Encode(VariantTensorData* data) const override {
325       EncodeVariant(value, data);
326     }
327 
DecodeValue328     bool Decode(const VariantTensorData& data) override {
329       return DecodeVariant(data, &value);
330     }
331 
EncodeValue332     void Encode(string* buf) const override { EncodeVariant(value, buf); }
333 
DecodeValue334     bool Decode(const string& buf) override {
335       return DecodeVariant(buf, &value);
336     }
337 
338     T value;
339   };
340 
341   // value_ can point to any type T as wrapped by a ValueInterface.
342   // The only real requirement is that T is default-constructible.
343   std::unique_ptr<ValueInterface> value_;
344 };
345 
346 template <>
347 void* Variant::get();
348 
349 template <>
350 const void* Variant::get() const;
351 
352 }  // end namespace tensorflow
353 
354 #endif  // TENSORFLOW_FRAMEWORK_VARIANT_H_
355