/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation notes:
//
// Tensor.cc uses a few templated classes and structs to facilitate
// implementation of the Tensor class.
//
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
//
// * Helper<T>: provides various routines given type T. The routines
//   include running the constructors and destructors of T[], and
//   encoding and decoding T[] into/from a Cord, etc.
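//
// For example (illustrative only): a DT_FLOAT Tensor with 4 elements owns a
// Buffer<float> backing float[4], and Helper<float>::Encode copies those
// 16 bytes into a destination string such as TensorProto::tensor_content.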

#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/platform/variant_coding.h"

namespace tensorflow {

// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");

namespace {

// An un-templated base class for Buffer.
class BufferBase : public TensorBuffer {
 public:
  explicit BufferBase(Allocator* alloc) : alloc_(alloc) {}

  TensorBuffer* root_buffer() override { return this; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    void* data_ptr = data();
    int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(alloc_->Name());
    proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
    if (alloc_->TracksAllocationSizes()) {
      int64 ab = alloc_->AllocatedSize(data_ptr);
      proto->set_allocated_bytes(ab);
      int64 id = alloc_->AllocationId(data_ptr);
      if (id > 0) {
        proto->set_allocation_id(id);
      }
      if (RefCountIsOne()) {
        proto->set_has_single_reference(true);
      }
    }
  }

 protected:
  void RecordDeallocation() {
    LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
                                        alloc_->Name());
  }

  Allocator* const alloc_;
};

// Typed ref-counted buffer: T[n].
template <typename T>
class Buffer : public BufferBase {
 public:
  Buffer(Allocator* a, int64 n);
  Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);

  void* data() const override { return data_; }
  size_t size() const override { return sizeof(T) * elem_; }

 private:
  T* data_;
  int64 elem_;

  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};

void LogUnexpectedSize(int64 actual, int64 expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}

// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.)
  static_assert(is_simple_type<T>::value, "T is not a simple type.");
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // Encoder of simple type T to a string. We do a copy.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copy the bytes from "in" into the
  // tensor buffer.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    if (in.size() != sizeof(T) * n) {
      LogUnexpectedSize(in.size(), sizeof(T) * n);
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    char* data = buf->template base<char>();
    if (data == nullptr) {
      buf->Unref();
      return nullptr;
    }
    port::CopyToArray(in, data);
    return buf;
  }

  // Memory usage.
  static int64 TotalBytes(TensorBuffer* in, int64 n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};
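
// For example, Helper<int32>::Decode expects "in" to hold exactly
// sizeof(int32) * n bytes (12 bytes for n == 3); any other size is logged
// via LogUnexpectedSize and results in a nullptr return.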

// Helper specialization for string (a non-simple type).
template <>
struct Helper<string> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type string stored in "in" into Cord
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    port::EncodeStringList(in->base<const string>(), n, out);
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    Buffer<string>* buf = new Buffer<string>(a, n);
    string* strings = buf->template base<string>();
    if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    int64 tot = in->size();
    DCHECK_EQ(tot, sizeof(string) * n);
    const string* p = in->base<const string>();
    for (int i = 0; i < n; ++i, ++p) tot += p->size();
    return tot;
  }
};

template <>
struct Helper<ResourceHandle> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type ResourceHandle stored in "in" into
  // destination "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    port::EncodeResourceHandleList(in->base<const ResourceHandle>(), n, out);
  }

  // Decodes "n" elements of type ResourceHandle from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    auto* buf = new Buffer<ResourceHandle>(a, n);
    ResourceHandle* ps = buf->template base<ResourceHandle>();
    if (ps == nullptr || !port::DecodeResourceHandleList(in, ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(ResourceHandle);
  }
};

template <>
struct Helper<Variant> {
  // Encodes "n" elements of type Variant stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    port::EncodeVariantList(in->base<const Variant>(), n, out);
  }

  // Decodes "n" elements of type Variant from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    auto* buf = new Buffer<Variant>(a, n);
    Variant* ps = buf->template base<Variant>();
    if (ps == nullptr || !port::DecodeVariantList(in, ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(Variant);
  }
};

template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
// int32, string, etc) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(string, string, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
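
// As an illustration, PROTO_TRAITS(int16, int32, int) above expands to a
// specialization that reads int16 data out of the repeated int32 field
// proto.int_val() and writes it back via proto.mutable_int_val(); several
// C++ types share one proto field when their "N" arguments match.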

template <>
struct ProtoHelper<int64> {
  static const int64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const int64*>(proto.int64_val().begin());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int64_val().size();
  }
  static void Fill(const int64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
    proto->mutable_int64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<uint64> {
  static const uint64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.uint64_val().size();
  }
  static void Fill(const uint64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
    proto->mutable_uint64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<ResourceHandle> {
  static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.resource_handle_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.resource_handle_val().size();
  }
  static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
    auto* handles = proto->mutable_resource_handle_val();
    handles->Clear();
    for (size_t i = 0; i < n; i++) {
      data[i].AsProto(handles->Add());
    }
  }
};

template <>
struct ProtoHelper<Variant> {
  static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
  Begin(const TensorProto& proto) {
    return proto.variant_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.variant_val().size();
  }
  static void Fill(const Variant* data, size_t n, TensorProto* proto) {
    auto* variant_values = proto->mutable_variant_val();
    variant_values->Clear();
    for (size_t i = 0; i < n; ++i) {
      VariantTensorData tmp;
      data[i].Encode(&tmp);
      tmp.ToProto(variant_values->Add());
    }
  }
};

template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<complex128> {
  typedef Helper<double>::RepeatedFieldType FieldType;
  static const complex128* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.dcomplex_val().size() / 2;
  }
  static void Fill(const complex128* data, size_t n, TensorProto* proto) {
    const double* p = reinterpret_cast<const double*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_dcomplex_val()->Swap(&copy);
  }
};
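
// Note: as the specializations above show, complex64/complex128 values are
// stored in scomplex_val/dcomplex_val as interleaved (real, imag) pairs,
// which is why NumElements divides the field size by two and Fill writes
// n * 2 scalar values.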

template <>
struct ProtoHelper<qint32> {
  typedef Helper<int32>::RepeatedFieldType FieldType;
  static const qint32* Begin(const TensorProto& proto) {
    return reinterpret_cast<const qint32*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const qint32* data, size_t n, TensorProto* proto) {
    const int32* p = reinterpret_cast<const int32*>(data);
    FieldType copy(p, p + n);
    proto->mutable_int_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<bfloat16> {
  static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
    }
  }
};

template <>
struct ProtoHelper<Eigen::half> {
  static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
    }
  }
};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n)
    : BufferBase(a), data_(a->Allocate<T>(n)), elem_(n) {}

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a), data_(a->Allocate<T>(n, allocation_attr)), elem_(n) {}

template <typename T>
Buffer<T>::~Buffer() {
  if (data_) {
    if (LogMemory::IsEnabled()) {
      RecordDeallocation();
    }
    alloc_->Deallocate<T>(data_, elem_);
  }
}

// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in". If "in" has fewer values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine uses the typed fields (float_val, etc.) in the tensor
// proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto is
// used by a client program which may not know how to encode a tensor
// in the compact binary representation.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }

  const int64 in_n = ProtoHelper<T>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, T());
  } else {
    auto begin = ProtoHelper<T>::Begin(in);
    if (n <= in_n) {
      std::copy_n(begin, n, data);
    } else {
      std::copy_n(begin, in_n, data);
      const T& last = *(data + in_n - 1);
      std::fill_n(data + in_n, n - in_n, last);
    }
  }

  return buf;
}
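
// For example, decoding a proto whose float_val field holds {1, 2} into a
// 4-element DT_FLOAT tensor produces {1, 2, 2, 2}, and an empty field
// produces {0, 0, 0, 0}.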

template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64 n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    for (int64 i = 0; i < in_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\". Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
        buf->Unref();
        return nullptr;
      }
    }
    for (int64 i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}

// fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
// identically to uint16 but with data stored in half_val instead of int_val
// (i.e., we don't use ProtoHelper<uint16>).
template <>
TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
                                          int64 n) {
  CHECK_GT(n, 0);
  Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

template <>
TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
                                       int64 n) {
  CHECK_GT(n, 0);
  Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32. If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}

void RefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Ref();
}

void UnrefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Unref();
}

}  // end namespace

Tensor::Tensor() : Tensor(DT_FLOAT) {}

Tensor::Tensor(DataType type) : shape_({0}), buf_(nullptr) { set_dtype(type); }

Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}

bool Tensor::IsInitialized() const {
  return (buf_ != nullptr && buf_->data() != nullptr) ||
         shape_.num_elements() == 0;
}

void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype);
}

void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype);
  CHECK(IsAligned()) << "CheckTypeAndIsAligned";
}

void Tensor::CheckIsAlignedAndSingleElement() const {
  CHECK(IsAligned()) << "Aligned and single element";
  CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
}

Tensor::~Tensor() { UnrefIfNonNull(buf_); }

void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
  CHECK_EQ(shape.num_elements(), other.NumElements());
  // Data type will be overwritten if this == &other, since dtype is part of
  // shape.
  DataType other_dtype = other.dtype();
  shape_ = shape;
  set_dtype(other_dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
}

void Tensor::UnsafeCopyFromInternal(const Tensor& other, DataType dtype,
                                    const TensorShape& shape) {
  int in_size = DataTypeSize(other.dtype());
  int out_size = DataTypeSize(dtype);
  CHECK_NE(in_size, 0);
  CHECK_NE(out_size, 0);
  CHECK_EQ(shape.num_elements() * out_size,
           other.shape().num_elements() * in_size);
  shape_ = shape;
  shape_.set_data_type(dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
}

// Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
// For the latter case, we have to make sure that the refcount is
// one both for the SubBuffer _and_ the underlying TensorBuffer.
bool Tensor::RefCountIsOne() const {
  return buf_ != nullptr && buf_->RefCountIsOne() &&
         buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
}

// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(string, SINGLE_ARG(STMTS))                            \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
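
// For instance, CASES(dtype(), buf_ = new Buffer<T>(a, n)) below expands to a
// switch over every supported DataType enum value; in each case T is
// typedef'd to the matching C++ type before the statement runs, and an unset
// or unknown type hits one of the LOG(FATAL) branches.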

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
  if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                      *this);
  }
}

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
               const AllocationAttributes& allocation_attr)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
  }
  if (!allocation_attr.allocation_will_be_logged && buf_ != nullptr &&
      buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
}

Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(cpu_allocator(), type, shape) {}

template <typename T>
class SubBuffer : public TensorBuffer {
 public:
  // This buffer is an alias to buf[delta, delta + n).
  SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
      : root_(buf->root_buffer()), data_(buf->base<T>() + delta), elem_(n) {
    // Sanity check. The caller should ensure the sub buffer is valid.
    CHECK_LE(root_->base<T>(), this->base<T>());
    T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
    CHECK_LE(this->base<T>(), root_limit);
    CHECK_LE(this->base<T>() + n, root_limit);
    // Hold a ref of the underlying root buffer.
    // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
    root_->Ref();
  }

  void* data() const override { return data_; }
  size_t size() const override { return sizeof(T) * elem_; }
  TensorBuffer* root_buffer() override { return root_; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    root_->FillAllocationDescription(proto);
  }

 private:
  TensorBuffer* root_;
  T* data_;
  int64 elem_;

  ~SubBuffer() override { root_->Unref(); }

  TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
};

Tensor Tensor::Slice(int64 start, int64 limit) const {
  CHECK_GE(dims(), 1);
  CHECK_LE(0, start);
  CHECK_LE(start, limit);
  int64 dim0_size = shape_.dim_size(0);
  CHECK_LE(limit, dim0_size);
  if ((start == 0) && (limit == dim0_size)) {
    return *this;
  }
  Tensor ret;
  ret.shape_ = shape_;
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64 elems_per_dim0 = NumElements() / dim0_size;
    const int64 delta = start * elems_per_dim0;
    dim0_size = limit - start;
    ret.shape_.set_dim(0, dim0_size);
    const int64 num_elems = dim0_size * elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
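
// For example, slicing a [10, 20] tensor with Slice(2, 5) returns a [3, 20]
// tensor whose SubBuffer aliases rows 2..4 of the original buffer (no data is
// copied, and the root buffer's refcount is incremented).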

bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(cpu_allocator(), proto);
}

bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64 N = shape.num_elements();
  if (N > 0 && proto.dtype()) {
    bool dtype_error = false;
    if (!proto.tensor_content().empty()) {
      const auto& content = proto.tensor_content();
      CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
                         dtype_error = true, dtype_error = true);
    } else {
      CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
                         dtype_error = true, dtype_error = true);
    }
    if (dtype_error || p == nullptr) return false;
  }
  shape_ = shape;
  set_dtype(proto.dtype());
  UnrefIfNonNull(buf_);
  buf_ = p;
  // TODO(misard) add tracking of which kernels and steps are calling
  // FromProto.
  if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
  return true;
}

void Tensor::AsProtoField(TensorProto* proto) const {
  proto->Clear();
  shape_.AsProto(proto->mutable_tensor_shape());
  proto->set_dtype(dtype());
  if (buf_) {
    CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
  }
}

void Tensor::AsProtoTensorContent(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(dtype());
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
                                     proto->mutable_tensor_content()));
  }
}

size_t Tensor::TotalBytes() const {
  if (shape_.num_elements() == 0) return 0;
  CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
  CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
  return 0;  // Makes compiler happy.
}

size_t Tensor::AllocatedBytes() const {
  TensorDescription tensor_description;
  FillDescription(&tensor_description);
  if (tensor_description.has_allocation_description() &&
      tensor_description.allocation_description().allocated_bytes() > 0) {
    return tensor_description.allocation_description().allocated_bytes();
  } else {
    // Fall back to TotalBytes() if the allocator doesn't have its size.
    return TotalBytes();
  }
}

bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple_type<T>::value);
  return false;  // Makes compiler happy.
}

#undef CASES
#undef CASE

namespace {
// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64 limit, int shape_size, const T* data, int64* data_index,
                 string* result) {
  if (*data_index >= limit) return;
  int64 element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64 i = 0; i < element_count; i++) {
      if (*data_index >= limit) return;
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, data[(*data_index)++]);
    }
    return;
  }
  // Loop every element of one dim.
  for (int64 i = 0; i < element_count; i++) {
    bool flag = false;
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // As for each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}

template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
                      const TensorShape& tensor_shape, const char* data) {
  string ret;
  const T* array = reinterpret_cast<const T*>(data);

  const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
  if (shape.empty()) {
    for (int64 i = 0; i < limit; ++i) {
      if (i > 0) strings::StrAppend(&ret, " ");
      strings::StrAppend(&ret, array[i]);
    }
    if (num_elts > limit) strings::StrAppend(&ret, "...");
    return ret;
  }
  int64 data_index = 0;
  const int shape_size = tensor_shape.dims();
  PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);

  if (num_elts > limit) strings::StrAppend(&ret, "...");
  return ret;
}
}  // namespace

string Tensor::SummarizeValue(int64 max_entries) const {
  const int64 num_elts = NumElements();
  size_t limit = std::min(max_entries, num_elts);
  if ((limit > 0) && (buf_ == nullptr)) {
    return strings::StrCat("uninitialized Tensor of ", num_elts,
                           " elements of type ", dtype());
  }
  const char* data = limit > 0 ? tensor_data().data() : nullptr;
  switch (dtype()) {
    case DT_HALF:
      return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
      break;
    case DT_FLOAT:
      return SummarizeArray<float>(limit, num_elts, shape_, data);
      break;
    case DT_DOUBLE:
      return SummarizeArray<double>(limit, num_elts, shape_, data);
      break;
    case DT_UINT32:
      return SummarizeArray<uint32>(limit, num_elts, shape_, data);
      break;
    case DT_INT32:
      return SummarizeArray<int32>(limit, num_elts, shape_, data);
      break;
    case DT_UINT8:
    case DT_QUINT8:
      return SummarizeArray<uint8>(limit, num_elts, shape_, data);
      break;
    case DT_UINT16:
    case DT_QUINT16:
      return SummarizeArray<uint16>(limit, num_elts, shape_, data);
      break;
    case DT_INT16:
    case DT_QINT16:
      return SummarizeArray<int16>(limit, num_elts, shape_, data);
      break;
    case DT_INT8:
    case DT_QINT8:
      return SummarizeArray<int8>(limit, num_elts, shape_, data);
      break;
    case DT_UINT64:
      return SummarizeArray<uint64>(limit, num_elts, shape_, data);
      break;
    case DT_INT64:
      return SummarizeArray<int64>(limit, num_elts, shape_, data);
      break;
    case DT_BOOL:
      // TODO(tucker): Is it better to emit "True False..."? This
      // will emit "1 0..." which is more compact.
      return SummarizeArray<bool>(limit, num_elts, shape_, data);
      break;
    default: {
      // All irregular cases
      string ret;
      // TODO(irving): Don't call flat every time around this
      // loop.
      for (size_t i = 0; i < limit; ++i) {
        if (i > 0) strings::StrAppend(&ret, " ");
        switch (dtype()) {
          case DT_STRING:
            strings::StrAppend(&ret, str_util::CEscape(flat<string>()(i)));
            break;
          case DT_VARIANT: {
            const Variant& v = flat<Variant>()(i);
            strings::StrAppend(&ret, v.DebugString());
          } break;
          default:
            // TODO(zhifengc, josh11b): Pretty-print other types (bool,
            // complex64, quantized).
            strings::StrAppend(&ret, "?");
        }
      }
      if (max_entries < num_elts) strings::StrAppend(&ret, "...");
      return ret;
    }
  }
}

StringPiece Tensor::tensor_data() const {
  if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
  return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
}

bool Tensor::SharesBufferWith(const Tensor& b) const {
  return buf_ != nullptr && b.buf_ != nullptr &&
         buf_->root_buffer() == b.buf_->root_buffer();
}

string Tensor::DebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(),
                         " values: ", SummarizeValue(3), ">");
}

void Tensor::FillDescription(TensorDescription* description) const {
  description->set_dtype(dtype());
  shape().AsProto(description->mutable_shape());
  if (buf_ != nullptr && buf_->data() != nullptr) {
    buf_->FillAllocationDescription(
        description->mutable_allocation_description());
  }
}

gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
    gtl::ArraySlice<int64> orig, int64 num_out_dims) {
  gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
  int64 offset = orig.size() - num_out_dims;
  for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
    const int64 in_dim = out_dim + offset;
    out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
  }
  for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
    out_dims[0] *= orig[in_dim];
  }
  return out_dims;
}
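
// For example, ComputeFlatInnerDims({2, 3, 5, 7}, 2) collapses the leading
// dimensions into the first output entry and returns {30, 7}.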

gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
    gtl::ArraySlice<int64> orig, int64 num_out_dims) {
  gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
  for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
    out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
  }
  for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
    out_dims[num_out_dims - 1] *= orig[in_dim];
  }
  return out_dims;
}
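
// Conversely, ComputeFlatOuterDims({2, 3, 5, 7}, 2) collapses the trailing
// dimensions into the last output entry and returns {2, 105}.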

}  // namespace tensorflow