1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implementation notes:
17 //
18 // Tensor.cc uses a few templated classes and structs to facilitate
19 // implementation of the Tensor class.
20 //
21 // * Buffer<T>: provides the implementation for a typed array T[n].
22 //   The array is allocated by the given allocator. It runs T's
23 //   default constructors and destructors when T is not a simple type
24 //   (e.g., string.), and skips them otherwise.
25 //
// * Helper<T>: provides various routines given type T.  The routines
//   include running the constructor and destructor of T[], encoding
//   and decoding T[] into/from a Cord, etc.
29 
30 #include "tensorflow/core/framework/tensor.h"
31 
32 #include "tensorflow/core/framework/allocation_description.pb.h"
33 #include "tensorflow/core/framework/log_memory.h"
34 #include "tensorflow/core/framework/resource_handle.pb.h"
35 #include "tensorflow/core/framework/tensor.pb.h"
36 #include "tensorflow/core/framework/tensor_description.pb.h"
37 #include "tensorflow/core/framework/type_traits.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/framework/variant.h"
40 #include "tensorflow/core/framework/variant_encode_decode.h"
41 #include "tensorflow/core/framework/variant_op_registry.h"
42 #include "tensorflow/core/framework/variant_tensor_data.h"
43 #include "tensorflow/core/lib/core/coding.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/lib/gtl/inlined_vector.h"
46 #include "tensorflow/core/lib/gtl/stl_util.h"
47 #include "tensorflow/core/lib/strings/str_util.h"
48 #include "tensorflow/core/lib/strings/strcat.h"
49 #include "tensorflow/core/platform/logging.h"
50 #include "tensorflow/core/platform/macros.h"
51 #include "tensorflow/core/platform/protobuf.h"
52 #include "tensorflow/core/platform/tensor_coding.h"
53 #include "tensorflow/core/platform/types.h"
54 #include "tensorflow/core/platform/variant_coding.h"
55 
56 namespace tensorflow {
57 
58 // Allow Tensors to be stored inside Variants with automatic
59 // encoding/decoding when those Variants are themselves being decoded
60 // in a Tensor's FromProto.
61 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
62 
63 namespace {
64 
65 // An un-templated base class for Buffer.
66 class BufferBase : public TensorBuffer {
67  public:
BufferBase(Allocator * alloc)68   explicit BufferBase(Allocator* alloc) : alloc_(alloc) {}
69 
root_buffer()70   TensorBuffer* root_buffer() override { return this; }
FillAllocationDescription(AllocationDescription * proto) const71   void FillAllocationDescription(AllocationDescription* proto) const override {
72     void* data_ptr = data();
73     int64 rb = size();
74     proto->set_requested_bytes(rb);
75     proto->set_allocator_name(alloc_->Name());
76     proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
77     if (alloc_->TracksAllocationSizes()) {
78       int64 ab = alloc_->AllocatedSize(data_ptr);
79       proto->set_allocated_bytes(ab);
80       int64 id = alloc_->AllocationId(data_ptr);
81       if (id > 0) {
82         proto->set_allocation_id(id);
83       }
84       if (RefCountIsOne()) {
85         proto->set_has_single_reference(true);
86       }
87     }
88   }
89 
90  protected:
RecordDeallocation()91   void RecordDeallocation() {
92     LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
93                                         alloc_->Name());
94   }
95 
96   Allocator* const alloc_;
97 };
98 
// Typed ref-counted buffer: T[n].
//
// The array is allocated (and later deallocated) through the supplied
// Allocator.  The destructor is private, so a Buffer can only be
// destroyed by dropping its last reference (Unref()).
template <typename T>
class Buffer : public BufferBase {
 public:
  Buffer(Allocator* a, int64 n);
  Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);

  void* data() const override { return data_; }
  size_t size() const override { return sizeof(T) * elem_; }

 private:
  T* data_;     // owned; allocated via alloc_ (may be nullptr on failure)
  int64 elem_;  // number of T elements in data_

  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};
117 
// Logs (at ERROR) a mismatch between the byte size of a serialized input
// and the size implied by the tensor's element type and count.
void LogUnexpectedSize(int64 actual, int64 expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}
121 
122 // A set of helper functions depending on T.
123 template <typename T>
124 struct Helper {
125   // By default, we assume T is a simple type (float, int32, etc.)
126   static_assert(is_simple_type<T>::value, "T is not a simple type.");
127   typedef protobuf::RepeatedField<T> RepeatedFieldType;
128 
129   // Encoder of simple type T to a string.  We do a copy.
130   template <typename Destination>
Encodetensorflow::__anon2ce1c7b90111::Helper131   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
132     DCHECK_EQ(in->size(), sizeof(T) * n);
133     port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
134                            out);
135   }
136 
137   // Decoder of simple type T. Copy the bytes from "in" into the
138   // tensor buffer.
139   template <typename Source>
Decodetensorflow::__anon2ce1c7b90111::Helper140   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
141     if (in.size() != sizeof(T) * n) {
142       LogUnexpectedSize(in.size(), sizeof(T) * n);
143       return nullptr;
144     }
145     Buffer<T>* buf = new Buffer<T>(a, n);
146     char* data = buf->template base<char>();
147     if (data == nullptr) {
148       buf->Unref();
149       return nullptr;
150     }
151     port::CopyToArray(in, data);
152     return buf;
153   }
154 
155   // Memory usage.
TotalBytestensorflow::__anon2ce1c7b90111::Helper156   static int64 TotalBytes(TensorBuffer* in, int64 n) {
157     DCHECK_EQ(in->size(), sizeof(T) * n);
158     return in->size();
159   }
160 };
161 
162 // Helper specialization for string (the only non-simple type we
163 // support).
164 template <>
165 struct Helper<string> {
166   // Proto message uses RepeatedFieldType to hold repeated T.
167   typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
168 
169   // Encodes "n" elements of type string stored in "in" into Cord
170   // "out", which is usually the TensorProto::tensor_content.
171   template <typename Destination>
Encodetensorflow::__anon2ce1c7b90111::Helper172   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
173     port::EncodeStringList(in->base<const string>(), n, out);
174   }
175 
176   // Decodes "n" elements of type string from "in" and constructs a
177   // buffer out of it. Returns nullptr if the decoding fails. "in" is
178   // usually the TensorProto::tensor_content.
179   template <typename Source>
Decodetensorflow::__anon2ce1c7b90111::Helper180   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
181     Buffer<string>* buf = new Buffer<string>(a, n);
182     string* strings = buf->template base<string>();
183     if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
184       buf->Unref();
185       return nullptr;
186     }
187     return buf;
188   }
189 
190   // Returns the estimated memory usage of "n" elements of type T
191   // stored in buffer "in".
TotalBytestensorflow::__anon2ce1c7b90111::Helper192   static int64 TotalBytes(TensorBuffer* in, int n) {
193     int64 tot = in->size();
194     DCHECK_EQ(tot, sizeof(string) * n);
195     const string* p = in->base<const string>();
196     for (int i = 0; i < n; ++i, ++p) tot += p->size();
197     return tot;
198   }
199 };
200 
201 template <>
202 struct Helper<ResourceHandle> {
203   // Proto message uses RepeatedFieldType to hold repeated T.
204   typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
205 
206   // Encodes "n" elements of type ResourceHandle stored in "in" into destination
207   // "out", which is usually the TensorProto::tensor_content.
208   template <typename Destination>
Encodetensorflow::__anon2ce1c7b90111::Helper209   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
210     port::EncodeResourceHandleList(in->base<const ResourceHandle>(), n, out);
211   }
212 
213   // Decodes "n" elements of type string from "in" and constructs a
214   // buffer out of it. Returns nullptr if the decoding fails. "in" is
215   // usually the TensorProto::tensor_content.
216   template <typename Source>
Decodetensorflow::__anon2ce1c7b90111::Helper217   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
218     auto* buf = new Buffer<ResourceHandle>(a, n);
219     ResourceHandle* ps = buf->template base<ResourceHandle>();
220     if (ps == nullptr || !port::DecodeResourceHandleList(in, ps, n)) {
221       buf->Unref();
222       return nullptr;
223     }
224     return buf;
225   }
226 
227   // Returns the estimated memory usage of "n" elements of type T
228   // stored in buffer "in".
TotalBytestensorflow::__anon2ce1c7b90111::Helper229   static int64 TotalBytes(TensorBuffer* in, int n) {
230     return n * sizeof(ResourceHandle);
231   }
232 };
233 
234 template <>
235 struct Helper<Variant> {
236   // Encodes "n" elements of type Variant stored in "in" into destination
237   // "out", which is usually the TensorProto::tensor_content.
238   template <typename Destination>
Encodetensorflow::__anon2ce1c7b90111::Helper239   static void Encode(TensorBuffer* in, int64 n, Destination* out) {
240     port::EncodeVariantList(in->base<const Variant>(), n, out);
241   }
242 
243   // Decodes "n" elements of type Variant from "in" and constructs a
244   // buffer out of it. Returns nullptr if the decoding fails. "in" is
245   // usually the TensorProto::tensor_content.
246   template <typename Source>
Decodetensorflow::__anon2ce1c7b90111::Helper247   static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
248     auto* buf = new Buffer<Variant>(a, n);
249     Variant* ps = buf->template base<Variant>();
250     if (ps == nullptr || !port::DecodeVariantList(in, ps, n)) {
251       buf->Unref();
252       return nullptr;
253     }
254     return buf;
255   }
256 
257   // Returns the estimated memory usage of "n" elements of type T
258   // stored in buffer "in".
TotalBytestensorflow::__anon2ce1c7b90111::Helper259   static int64 TotalBytes(TensorBuffer* in, int n) {
260     return n * sizeof(Variant);
261   }
262 };
263 
// Per-type traits describing how T is stored in the typed repeated
// fields of TensorProto (float_val, int_val, ...).  The primary template
// is empty; each supported T provides a specialization defining some or
// all of Begin()/NumElements()/Fill().
template <typename T>
struct ProtoHelper {};
266 
// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
// int32, string, etc) in the TensorProto is used for serializing the
// tensor of type "T".
//
// Begin()/NumElements() read the proto's repeated field; Fill() copies n
// elements of T into it.  Note T may be narrower than F (e.g. int8 is
// widened into int_val); FieldType's constructor performs the conversion.
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(string, string, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
301 
302 template <>
303 struct ProtoHelper<int64> {
Begintensorflow::__anon2ce1c7b90111::ProtoHelper304   static const int64* Begin(const TensorProto& proto) {
305     return reinterpret_cast<const int64*>(proto.int64_val().begin());
306   }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper307   static size_t NumElements(const TensorProto& proto) {
308     return proto.int64_val().size();
309   }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper310   static void Fill(const int64* data, size_t n, TensorProto* proto) {
311     protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
312     proto->mutable_int64_val()->Swap(&copy);
313   }
314 };
315 
316 template <>
317 struct ProtoHelper<uint64> {
Begintensorflow::__anon2ce1c7b90111::ProtoHelper318   static const uint64* Begin(const TensorProto& proto) {
319     return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
320   }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper321   static size_t NumElements(const TensorProto& proto) {
322     return proto.uint64_val().size();
323   }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper324   static void Fill(const uint64* data, size_t n, TensorProto* proto) {
325     protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
326     proto->mutable_uint64_val()->Swap(&copy);
327   }
328 };
329 
330 template <>
331 struct ProtoHelper<ResourceHandle> {
Begintensorflow::__anon2ce1c7b90111::ProtoHelper332   static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
333       const TensorProto& proto) {
334     return proto.resource_handle_val().begin();
335   }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper336   static size_t NumElements(const TensorProto& proto) {
337     return proto.resource_handle_val().size();
338   }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper339   static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
340     auto* handles = proto->mutable_resource_handle_val();
341     handles->Clear();
342     for (size_t i = 0; i < n; i++) {
343       data[i].AsProto(handles->Add());
344     }
345   }
346 };
347 
348 template <>
349 struct ProtoHelper<Variant> {
350   static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
Begintensorflow::__anon2ce1c7b90111::ProtoHelper351   Begin(const TensorProto& proto) {
352     return proto.variant_val().begin();
353   }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper354   static size_t NumElements(const TensorProto& proto) {
355     return proto.variant_val().size();
356   }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper357   static void Fill(const Variant* data, size_t n, TensorProto* proto) {
358     auto* variant_values = proto->mutable_variant_val();
359     variant_values->Clear();
360     for (size_t i = 0; i < n; ++i) {
361       VariantTensorData tmp;
362       data[i].Encode(&tmp);
363       tmp.ToProto(variant_values->Add());
364     }
365   }
366 };
367 
// complex64 is serialized as interleaved (real, imag) float pairs in
// scomplex_val: Fill writes 2*n floats and NumElements reads
// scomplex_val.size() / 2.  Relies on complex64 being layout-compatible
// with float[2] (guaranteed for std::complex<float>).
template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};
383 
// complex128 is serialized as interleaved (real, imag) double pairs in
// dcomplex_val; see ProtoHelper<complex64> for the layout reasoning.
template <>
struct ProtoHelper<complex128> {
  typedef Helper<double>::RepeatedFieldType FieldType;
  static const complex128* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.dcomplex_val().size() / 2;
  }
  static void Fill(const complex128* data, size_t n, TensorProto* proto) {
    const double* p = reinterpret_cast<const double*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_dcomplex_val()->Swap(&copy);
  }
};
399 
400 template <>
401 struct ProtoHelper<qint32> {
402   typedef Helper<int32>::RepeatedFieldType FieldType;
Begintensorflow::__anon2ce1c7b90111::ProtoHelper403   static const qint32* Begin(const TensorProto& proto) {
404     return reinterpret_cast<const qint32*>(proto.int_val().data());
405   }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper406   static size_t NumElements(const TensorProto& proto) {
407     return proto.int_val().size();
408   }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper409   static void Fill(const qint32* data, size_t n, TensorProto* proto) {
410     const int32* p = reinterpret_cast<const int32*>(data);
411     FieldType copy(p, p + n);
412     proto->mutable_int_val()->Swap(&copy);
413   }
414 };
415 
// bfloat16 has no dedicated proto field; its raw 16-bit payload (.value)
// is stored in half_val.  Only Fill is defined — decoding is handled by
// the FromProtoField<bfloat16> specialization instead of Begin/NumElements.
template <>
struct ProtoHelper<bfloat16> {
  static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
    }
  }
};
425 
// Eigen::half's raw 16-bit payload (.x) is stored in half_val.  Only
// Fill is defined — decoding is handled by the FromProtoField<Eigen::half>
// specialization instead of Begin/NumElements.
template <>
struct ProtoHelper<Eigen::half> {
  static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
    }
  }
};
435 
// Allocates a T[n] array through "a"; data_ may be nullptr if the
// allocation fails (callers check base<T>() before use).
template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n)
    : BufferBase(a), data_(a->Allocate<T>(n)), elem_(n) {}
439 
// Same as above, but forwards "allocation_attr" to the allocator.
template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a), data_(a->Allocate<T>(n, allocation_attr)), elem_(n) {}
444 
445 template <typename T>
~Buffer()446 Buffer<T>::~Buffer() {
447   if (data_) {
448     if (LogMemory::IsEnabled()) {
449       RecordDeallocation();
450     }
451     alloc_->Deallocate<T>(data_, elem_);
452   }
453 }
454 
455 // Allocates a T[n] buffer. Fills in the buffer with repeated values
456 // in "in".  If "in" has less values than "n", fills the rest of T[n]
457 // with the last value. If "in" has no values, fills T[n] with the
458 // default value for T.
459 //
460 // This routine is using the typed fields (float_val, etc.) in the
461 // tensor proto as opposed to the untyped binary representation
462 // (tensor_content). This is used when we expect the TensorProto is
463 // used by a client program which may not know how to encode a tensor
464 // in the compact binary representation.
465 template <typename T>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)466 TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
467   CHECK_GT(n, 0);
468   Buffer<T>* buf = new Buffer<T>(a, n);
469   T* data = buf->template base<T>();
470   if (data == nullptr) {
471     buf->Unref();
472     return nullptr;
473   }
474 
475   const int64 in_n = ProtoHelper<T>::NumElements(in);
476   if (in_n <= 0) {
477     std::fill_n(data, n, T());
478   } else {
479     auto begin = ProtoHelper<T>::Begin(in);
480     if (n <= in_n) {
481       std::copy_n(begin, n, data);
482     } else {
483       std::copy_n(begin, in_n, data);
484       const T& last = *(data + in_n - 1);
485       std::fill_n(data + in_n, n - in_n, last);
486     }
487   }
488 
489   return buf;
490 }
491 
// Variant specialization: each element is parsed from variant_val and
// then decoded in place via the registered unary-variant decoder.
// Returns nullptr on allocation or decode failure.
template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64 n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    // Allocation failed.
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    for (int64 i = 0; i < in_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\".  Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
        buf->Unref();
        return nullptr;
      }
    }
    // NOTE: unlike the generic FromProtoField, any extra elements are
    // default-initialized rather than copies of the last value.
    for (int64 i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}
523 
524 // fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
525 // identical to uint16 but with data stored in half_val instead of int_val (ie.,
526 // we don't use ProtoHelper<uint16>).
527 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)528 TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
529                                           int64 n) {
530   CHECK_GT(n, 0);
531   Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
532   uint16* data = buf->template base<uint16>();
533   if (data == nullptr) {
534     buf->Unref();
535     return nullptr;
536   }
537   const int64 in_n = in.half_val().size();
538   auto begin = in.half_val().begin();
539   if (n <= in_n) {
540     std::copy_n(begin, n, data);
541   } else if (in_n > 0) {
542     std::copy_n(begin, in_n, data);
543     const uint16 last = *(data + in_n - 1);
544     std::fill_n(data + in_n, n - in_n, last);
545   } else {
546     std::fill_n(data, n, 0);
547   }
548   return buf;
549 }
550 
551 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)552 TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
553                                        int64 n) {
554   CHECK_GT(n, 0);
555   Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
556   uint16* data = buf->template base<uint16>();
557   if (data == nullptr) {
558     buf->Unref();
559     return nullptr;
560   }
561   const int64 in_n = in.half_val().size();
562   auto begin = in.half_val().begin();
563   if (n <= in_n) {
564     std::copy_n(begin, n, data);
565   } else if (in_n > 0) {
566     std::copy_n(begin, in_n, data);
567     const uint16 last = *(data + in_n - 1);
568     std::fill_n(data + in_n, n - in_n, last);
569   } else {
570     std::fill_n(data, n, 0);
571   }
572   return buf;
573 }
574 
// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type.  E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32.  If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}
586 
RefIfNonNull(core::RefCounted * buf)587 void RefIfNonNull(core::RefCounted* buf) {
588   if (buf) buf->Ref();
589 }
590 
UnrefIfNonNull(core::RefCounted * buf)591 void UnrefIfNonNull(core::RefCounted* buf) {
592   if (buf) buf->Unref();
593 }
594 
595 }  // end namespace
596 
Tensor()597 Tensor::Tensor() : Tensor(DT_FLOAT) {}
598 
Tensor(DataType type)599 Tensor::Tensor(DataType type) : shape_({0}), buf_(nullptr) { set_dtype(type); }
600 
// Wraps an existing buffer (which may be nullptr) with the given type
// and shape, taking a reference on the buffer.
Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}
606 
IsInitialized() const607 bool Tensor::IsInitialized() const {
608   return (buf_ != nullptr && buf_->data() != nullptr) ||
609          shape_.num_elements() == 0;
610 }
611 
// CHECK-fails unless this tensor's dtype matches "expected_dtype".
void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype);
}
615 
// CHECK-fails unless the dtype matches and the buffer is aligned.
void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype);
  CHECK(IsAligned()) << "CheckTypeAndIsAligned";
}
620 
// CHECK-fails unless the buffer is aligned and the tensor holds exactly
// one element.
void Tensor::CheckIsAlignedAndSingleElement() const {
  CHECK(IsAligned()) << "Aligned and single element";
  CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
}
625 
~Tensor()626 Tensor::~Tensor() { UnrefIfNonNull(buf_); }
627 
// Shares "other"'s buffer with this tensor under a (possibly reshaped)
// "shape".  The element counts must match; the dtype is taken from
// "other".  Safe when this == &other.
void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
  CHECK_EQ(shape.num_elements(), other.NumElements());
  // Data type will be overwritten if this == &other, since dtype is part of
  // shape.
  DataType other_dtype = other.dtype();
  shape_ = shape;
  set_dtype(other_dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
}
641 
// Reinterprets "other"'s buffer as "dtype" with "shape".  The total byte
// count must be preserved, and both dtypes must have a nonzero fixed
// size (DataTypeSize() != 0).
void Tensor::UnsafeCopyFromInternal(const Tensor& other, DataType dtype,
                                    const TensorShape& shape) {
  int in_size = DataTypeSize(other.dtype());
  int out_size = DataTypeSize(dtype);
  CHECK_NE(in_size, 0);   // source dtype must have a known fixed size
  CHECK_NE(out_size, 0);  // target dtype must have a known fixed size
  CHECK_EQ(shape.num_elements() * out_size,
           other.shape().num_elements() * in_size);
  shape_ = shape;
  shape_.set_data_type(dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
}
658 
659 // Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
660 // For the latter case, we have to make sure that the refcount is
661 // one both for the SubBuffer _and_ the underlying TensorBuffer.
RefCountIsOne() const662 bool Tensor::RefCountIsOne() const {
663   return buf_ != nullptr && buf_->RefCountIsOne() &&
664          buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
665 }
666 
// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
// One switch case: binds the C++ type to the local typedef T, runs
// STMTS, then breaks.
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
// Dispatches over every supported dtype; INVALID runs for DT_INVALID,
// DEFAULT for any other enum value without a CASE.
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(string, SINGLE_ARG(STMTS))                            \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

// Common case: an unset or unrecognized type is a fatal error.
#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
712 
// Allocates a buffer of shape.num_elements() elements of "type" through
// "a" (skipped for empty tensors unless the allocator opts in), and logs
// the allocation when memory logging is enabled.
Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
    // CASES picks the Buffer<T> instantiation matching the runtime dtype.
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
  if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                      *this);
  }
}
725 
// Same as above, but forwards "allocation_attr" to the allocator.  The
// allocation is not logged here when the caller declares it will log the
// allocation itself (allocation_attr.allocation_will_be_logged).
Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
               const AllocationAttributes& allocation_attr)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
    // CASES picks the Buffer<T> instantiation matching the runtime dtype.
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
  }
  if (!allocation_attr.allocation_will_be_logged && buf_ != nullptr &&
      buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
}
740 
// Convenience constructor: allocates from the default CPU allocator.
Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(cpu_allocator(), type, shape) {}
743 
744 template <typename T>
745 class SubBuffer : public TensorBuffer {
746  public:
747   // This buffer is an alias to buf[delta, delta + n).
SubBuffer(TensorBuffer * buf,int64 delta,int64 n)748   SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
749       : root_(buf->root_buffer()), data_(buf->base<T>() + delta), elem_(n) {
750     // Sanity check. The caller should ensure the sub buffer is valid.
751     CHECK_LE(root_->base<T>(), this->base<T>());
752     T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
753     CHECK_LE(this->base<T>(), root_limit);
754     CHECK_LE(this->base<T>() + n, root_limit);
755     // Hold a ref of the underlying root buffer.
756     // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
757     root_->Ref();
758   }
759 
data() const760   void* data() const override { return data_; }
size() const761   size_t size() const override { return sizeof(T) * elem_; }
root_buffer()762   TensorBuffer* root_buffer() override { return root_; }
FillAllocationDescription(AllocationDescription * proto) const763   void FillAllocationDescription(AllocationDescription* proto) const override {
764     root_->FillAllocationDescription(proto);
765   }
766 
767  private:
768   TensorBuffer* root_;
769   T* data_;
770   int64 elem_;
771 
~SubBuffer()772   ~SubBuffer() override { root_->Unref(); }
773 
774   TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
775 };
776 
Slice(int64 start,int64 limit) const777 Tensor Tensor::Slice(int64 start, int64 limit) const {
778   CHECK_GE(dims(), 1);
779   CHECK_LE(0, start);
780   CHECK_LE(start, limit);
781   int64 dim0_size = shape_.dim_size(0);
782   CHECK_LE(limit, dim0_size);
783   if ((start == 0) && (limit == dim0_size)) {
784     return *this;
785   }
786   Tensor ret;
787   ret.shape_ = shape_;
788   ret.set_dtype(dtype());
789   ret.buf_ = nullptr;
790   if (dim0_size > 0) {
791     const int64 elems_per_dim0 = NumElements() / dim0_size;
792     const int64 delta = start * elems_per_dim0;
793     dim0_size = limit - start;
794     ret.shape_.set_dim(0, dim0_size);
795     const int64 num_elems = dim0_size * elems_per_dim0;
796     if (buf_) {
797       DataType dt = dtype();
798       CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
799     }
800   }
801   return ret;
802 }
803 
// Parses `proto` into this tensor, allocating from the default CPU
// allocator.  Returns false if the proto cannot be decoded.
bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(cpu_allocator(), proto);
}
807 
// Parses `proto` into this tensor, allocating any buffer from `a`.
// Validates shape and dtype first; on any failure returns false and
// leaves *this unmodified.  On success the previous buffer (if any) is
// unreffed and replaced.
bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64 N = shape.num_elements();
  if (N > 0 && proto.dtype()) {
    bool dtype_error = false;
    if (!proto.tensor_content().empty()) {
      // Dense encoding: element bytes packed into tensor_content.
      const auto& content = proto.tensor_content();
      CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
                         dtype_error = true, dtype_error = true);
    } else {
      // Field-wise encoding: values stored in the typed repeated fields.
      CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
                         dtype_error = true, dtype_error = true);
    }
    if (dtype_error || p == nullptr) return false;
  }
  // All validation passed: commit shape, then dtype, then swap in the
  // new buffer (keep this statement order).
  shape_ = shape;
  set_dtype(proto.dtype());
  UnrefIfNonNull(buf_);
  buf_ = p;
  // TODO(misard) add tracking of which kernels and steps are calling
  // FromProto.
  if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
  return true;
}
839 
AsProtoField(TensorProto * proto) const840 void Tensor::AsProtoField(TensorProto* proto) const {
841   proto->Clear();
842   shape_.AsProto(proto->mutable_tensor_shape());
843   proto->set_dtype(dtype());
844   if (buf_) {
845     CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
846   }
847 }
848 
AsProtoTensorContent(TensorProto * proto) const849 void Tensor::AsProtoTensorContent(TensorProto* proto) const {
850   proto->Clear();
851   proto->set_dtype(dtype());
852   shape_.AsProto(proto->mutable_tensor_shape());
853   if (buf_) {
854     CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
855                                      proto->mutable_tensor_content()));
856   }
857 }
858 
TotalBytes() const859 size_t Tensor::TotalBytes() const {
860   if (shape_.num_elements() == 0) return 0;
861   CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
862   CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
863   return 0;  // Makes compiler happy.
864 }
865 
AllocatedBytes() const866 size_t Tensor::AllocatedBytes() const {
867   TensorDescription tensor_description;
868   FillDescription(&tensor_description);
869   if (tensor_description.has_allocation_description() &&
870       tensor_description.allocation_description().allocated_bytes() > 0) {
871     return tensor_description.allocation_description().allocated_bytes();
872   } else {
873     // Fall back to TotalBytes() if the allocator doesn't have its size.
874     return TotalBytes();
875   }
876 }
877 
// True iff the element type is a simple (byte-copyable) type per
// is_simple_type<T>, i.e. the buffer can be moved with raw memory copies.
bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple_type<T>::value);
  return false;  // Makes compiler happy.
}
882 
883 #undef CASES
884 #undef CASE
885 
886 namespace {
887 // Print from left dim to right dim recursively.
// Print from left dim to right dim recursively.
//
// Appends a bracketed rendering of the tensor contents to `result`.
// `*data_index` is a cursor into the flattened `data` array, shared
// across the recursion; printing stops once it reaches `limit`.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64 limit, int shape_size, const T* data, int64* data_index,
                 string* result) {
  if (*data_index >= limit) return;
  int64 element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64 i = 0; i < element_count; i++) {
      if (*data_index >= limit) return;
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, data[(*data_index)++]);
    }
    return;
  }
  // Loop every element of one dim.
  for (int64 i = 0; i < element_count; i++) {
    bool flag = false;  // true iff an opening bracket was emitted below.
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // As for each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    // Close the bracket iff one was opened: the cursor only advances, so
    // *data_index < limit here implies `flag` was set above.
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}
919 
920 template <typename T>
SummarizeArray(int64 limit,int64 num_elts,const TensorShape & tensor_shape,const char * data)921 string SummarizeArray(int64 limit, int64 num_elts,
922                       const TensorShape& tensor_shape, const char* data) {
923   string ret;
924   const T* array = reinterpret_cast<const T*>(data);
925 
926   const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
927   if (shape.empty()) {
928     for (int64 i = 0; i < limit; ++i) {
929       if (i > 0) strings::StrAppend(&ret, " ");
930       strings::StrAppend(&ret, array[i]);
931     }
932     if (num_elts > limit) strings::StrAppend(&ret, "...");
933     return ret;
934   }
935   int64 data_index = 0;
936   const int shape_size = tensor_shape.dims();
937   PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
938 
939   if (num_elts > limit) strings::StrAppend(&ret, "...");
940   return ret;
941 }
942 }  // namespace
943 
SummarizeValue(int64 max_entries) const944 string Tensor::SummarizeValue(int64 max_entries) const {
945   const int64 num_elts = NumElements();
946   size_t limit = std::min(max_entries, num_elts);
947   if ((limit > 0) && (buf_ == nullptr)) {
948     return strings::StrCat("uninitialized Tensor of ", num_elts,
949                            " elements of type ", dtype());
950   }
951   const char* data = limit > 0 ? tensor_data().data() : nullptr;
952   switch (dtype()) {
953     case DT_HALF:
954       return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
955       break;
956     case DT_FLOAT:
957       return SummarizeArray<float>(limit, num_elts, shape_, data);
958       break;
959     case DT_DOUBLE:
960       return SummarizeArray<double>(limit, num_elts, shape_, data);
961       break;
962     case DT_UINT32:
963       return SummarizeArray<uint32>(limit, num_elts, shape_, data);
964       break;
965     case DT_INT32:
966       return SummarizeArray<int32>(limit, num_elts, shape_, data);
967       break;
968     case DT_UINT8:
969     case DT_QUINT8:
970       return SummarizeArray<uint8>(limit, num_elts, shape_, data);
971       break;
972     case DT_UINT16:
973     case DT_QUINT16:
974       return SummarizeArray<uint16>(limit, num_elts, shape_, data);
975       break;
976     case DT_INT16:
977     case DT_QINT16:
978       return SummarizeArray<int16>(limit, num_elts, shape_, data);
979       break;
980     case DT_INT8:
981     case DT_QINT8:
982       return SummarizeArray<int8>(limit, num_elts, shape_, data);
983       break;
984     case DT_UINT64:
985       return SummarizeArray<uint64>(limit, num_elts, shape_, data);
986       break;
987     case DT_INT64:
988       return SummarizeArray<int64>(limit, num_elts, shape_, data);
989       break;
990     case DT_BOOL:
991       // TODO(tucker): Is it better to emit "True False..."?  This
992       // will emit "1 0..." which is more compact.
993       return SummarizeArray<bool>(limit, num_elts, shape_, data);
994       break;
995     default: {
996       // All irregular cases
997       string ret;
998       // TODO(irving): Don't call flat every time around this
999       // loop.
1000       for (size_t i = 0; i < limit; ++i) {
1001         if (i > 0) strings::StrAppend(&ret, " ");
1002         switch (dtype()) {
1003           case DT_STRING:
1004             strings::StrAppend(&ret, str_util::CEscape(flat<string>()(i)));
1005             break;
1006           case DT_VARIANT: {
1007             const Variant& v = flat<Variant>()(i);
1008             strings::StrAppend(&ret, v.DebugString());
1009           } break;
1010           default:
1011             // TODO(zhifengc, josh11b): Pretty-print other types (bool,
1012             // complex64, quantized).
1013             strings::StrAppend(&ret, "?");
1014         }
1015       }
1016       if (max_entries < num_elts) strings::StrAppend(&ret, "...");
1017       return ret;
1018     }
1019   }
1020 }
1021 
tensor_data() const1022 StringPiece Tensor::tensor_data() const {
1023   if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
1024   return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
1025 }
1026 
SharesBufferWith(const Tensor & b) const1027 bool Tensor::SharesBufferWith(const Tensor& b) const {
1028   return buf_ != nullptr && b.buf_ != nullptr &&
1029          buf_->root_buffer() == b.buf_->root_buffer();
1030 }
1031 
// One-line summary of dtype, shape, and the first few values (at most 3).
string Tensor::DebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(),
                         " values: ", SummarizeValue(3), ">");
}
1037 
// Fills `description` with this tensor's dtype, shape, and — when a
// buffer with actual data exists — its allocation description.
void Tensor::FillDescription(TensorDescription* description) const {
  description->set_dtype(dtype());
  shape().AsProto(description->mutable_shape());
  if (buf_ != nullptr && buf_->data() != nullptr) {
    buf_->FillAllocationDescription(
        description->mutable_allocation_description());
  }
}
1046 
ComputeFlatInnerDims(gtl::ArraySlice<int64> orig,int64 num_out_dims)1047 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
1048     gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1049   gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1050   int64 offset = orig.size() - num_out_dims;
1051   for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
1052     const int64 in_dim = out_dim + offset;
1053     out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
1054   }
1055   for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
1056     out_dims[0] *= orig[in_dim];
1057   }
1058   return out_dims;
1059 }
1060 
ComputeFlatOuterDims(gtl::ArraySlice<int64> orig,int64 num_out_dims)1061 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
1062     gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1063   gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1064   for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
1065     out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
1066   }
1067   for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
1068     out_dims[num_out_dims - 1] *= orig[in_dim];
1069   }
1070   return out_dims;
1071 }
1072 
1073 }  // namespace tensorflow
1074