1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Implementation notes:
//
// tensor.cc uses a few templated classes and structs to facilitate the
// implementation of the Tensor class.
//
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
//
// * Helper<T>: provides various routines given type T. The routines
//   include encoding and decoding T[] into/from a string or Cord (usually
//   TensorProto::tensor_content) and estimating memory usage.
29
30 #include "tensorflow/core/framework/tensor.h"
31
32 #include "tensorflow/core/framework/allocation_description.pb.h"
33 #include "tensorflow/core/framework/log_memory.h"
34 #include "tensorflow/core/framework/resource_handle.pb.h"
35 #include "tensorflow/core/framework/tensor.pb.h"
36 #include "tensorflow/core/framework/tensor_description.pb.h"
37 #include "tensorflow/core/framework/type_traits.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/framework/variant.h"
40 #include "tensorflow/core/framework/variant_encode_decode.h"
41 #include "tensorflow/core/framework/variant_op_registry.h"
42 #include "tensorflow/core/framework/variant_tensor_data.h"
43 #include "tensorflow/core/lib/core/coding.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/gtl/inlined_vector.h"
47 #include "tensorflow/core/lib/gtl/stl_util.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/macros.h"
52 #include "tensorflow/core/platform/protobuf.h"
53 #include "tensorflow/core/platform/tensor_coding.h"
54 #include "tensorflow/core/platform/types.h"
55
56 namespace tensorflow {
57
// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto (FromProtoField<Variant> below calls
// DecodeUnaryVariant on every element it reads).
//
// The second argument is the type-name key the decoder is registered under;
// presumably it must match the name recorded when a Tensor-in-Variant is
// encoded — confirm against variant_op_registry.h.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
66
67 namespace {
68
69 // An un-templated base class for Buffer.
70 class BufferBase : public TensorBuffer {
71 public:
BufferBase(Allocator * alloc,void * data_ptr)72 explicit BufferBase(Allocator* alloc, void* data_ptr)
73 : TensorBuffer(data_ptr), alloc_(alloc) {}
74
root_buffer()75 TensorBuffer* root_buffer() override { return this; }
FillAllocationDescription(AllocationDescription * proto) const76 void FillAllocationDescription(AllocationDescription* proto) const override {
77 void* data_ptr = data();
78 int64 rb = size();
79 proto->set_requested_bytes(rb);
80 proto->set_allocator_name(alloc_->Name());
81 proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
82 if (alloc_->TracksAllocationSizes()) {
83 int64 ab = alloc_->AllocatedSize(data_ptr);
84 proto->set_allocated_bytes(ab);
85 int64 id = alloc_->AllocationId(data_ptr);
86 if (id > 0) {
87 proto->set_allocation_id(id);
88 }
89 if (RefCountIsOne()) {
90 proto->set_has_single_reference(true);
91 }
92 }
93 }
94
95 protected:
RecordDeallocation()96 void RecordDeallocation() {
97 LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
98 alloc_->Name());
99 }
100
101 Allocator* const alloc_;
102 };
103
104 // Typed ref-counted buffer: T[n].
105 template <typename T>
106 class Buffer : public BufferBase {
107 public:
108 Buffer(Allocator* a, int64 n);
109 Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);
110
size() const111 size_t size() const override { return sizeof(T) * elem_; }
112
113 private:
114 T* data_;
115 int64 elem_;
116
117 ~Buffer() override;
118
119 TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
120 };
121
LogUnexpectedSize(int64 actual,int64 expected)122 void LogUnexpectedSize(int64 actual, int64 expected) {
123 LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
124 }
125
126 // A set of helper functions depending on T.
127 template <typename T>
128 struct Helper {
129 // By default, we assume T is a simple type (float, int32, etc.)
130 static_assert(is_simple_type<T>::value, "T is not a simple type.");
131 typedef protobuf::RepeatedField<T> RepeatedFieldType;
132
133 // Encoder of simple type T to a string. We do a copy.
134 template <typename Destination>
Encodetensorflow::__anon2ce1c7b90111::Helper135 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
136 DCHECK_EQ(in->size(), sizeof(T) * n);
137 port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
138 out);
139 }
140
141 // Decoder of simple type T. Copy the bytes from "in" into the
142 // tensor buffer.
143 template <typename Source>
Decodetensorflow::__anon2ce1c7b90111::Helper144 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
145 if (in.size() != sizeof(T) * n) {
146 LogUnexpectedSize(in.size(), sizeof(T) * n);
147 return nullptr;
148 }
149 Buffer<T>* buf = new Buffer<T>(a, n);
150 char* data = buf->template base<char>();
151 if (data == nullptr) {
152 buf->Unref();
153 return nullptr;
154 }
155 port::CopyToArray(in, data);
156 return buf;
157 }
158
159 // Memory usage.
TotalBytestensorflow::__anon2ce1c7b90111::Helper160 static int64 TotalBytes(TensorBuffer* in, int64 n) {
161 DCHECK_EQ(in->size(), sizeof(T) * n);
162 return in->size();
163 }
164 };
165
166 // Helper specialization for string (the only non-simple type we
167 // support).
168 template <>
169 struct Helper<string> {
170 // Proto message uses RepeatedFieldType to hold repeated T.
171 typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
172
173 // Encodes "n" elements of type string stored in "in" into Cord
174 // "out", which is usually the TensorProto::tensor_content.
175 template <typename Destination>
Encodetensorflow::__anon2ce1c7b90111::Helper176 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
177 port::EncodeStringList(in->base<const string>(), n, out);
178 }
179
180 // Decodes "n" elements of type string from "in" and constructs a
181 // buffer out of it. Returns nullptr if the decoding fails. "in" is
182 // usually the TensorProto::tensor_content.
183 template <typename Source>
Decodetensorflow::__anon2ce1c7b90111::Helper184 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
185 Buffer<string>* buf = new Buffer<string>(a, n);
186 string* strings = buf->template base<string>();
187 if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
188 buf->Unref();
189 return nullptr;
190 }
191 return buf;
192 }
193
194 // Returns the estimated memory usage of "n" elements of type T
195 // stored in buffer "in".
TotalBytestensorflow::__anon2ce1c7b90111::Helper196 static int64 TotalBytes(TensorBuffer* in, int n) {
197 int64 tot = in->size();
198 DCHECK_EQ(tot, sizeof(string) * n);
199 const string* p = in->base<const string>();
200 for (int i = 0; i < n; ++i, ++p) tot += p->size();
201 return tot;
202 }
203 };
204
205 template <>
206 struct Helper<ResourceHandle> {
207 // Proto message uses RepeatedFieldType to hold repeated T.
208 typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
209
210 // Encodes "n" elements of type ResourceHandle stored in "in" into destination
211 // "out", which is usually the TensorProto::tensor_content.
212 template <typename Destination>
Encodetensorflow::__anon2ce1c7b90111::Helper213 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
214 EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
215 port::NewStringListEncoder(out));
216 }
217
218 // Decodes "n" elements of type string from "in" and constructs a
219 // buffer out of it. Returns nullptr if the decoding fails. "in" is
220 // usually the TensorProto::tensor_content.
221 template <typename Source>
Decodetensorflow::__anon2ce1c7b90111::Helper222 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
223 auto* buf = new Buffer<ResourceHandle>(a, n);
224 ResourceHandle* ps = buf->template base<ResourceHandle>();
225 if (ps == nullptr ||
226 !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
227 buf->Unref();
228 return nullptr;
229 }
230 return buf;
231 }
232
233 // Returns the estimated memory usage of "n" elements of type T
234 // stored in buffer "in".
TotalBytestensorflow::__anon2ce1c7b90111::Helper235 static int64 TotalBytes(TensorBuffer* in, int n) {
236 return n * sizeof(ResourceHandle);
237 }
238 };
239
240 template <>
241 struct Helper<Variant> {
242 // Encodes "n" elements of type Variant stored in "in" into destination
243 // "out", which is usually the TensorProto::tensor_content.
244 template <typename Destination>
Encodetensorflow::__anon2ce1c7b90111::Helper245 static void Encode(TensorBuffer* in, int64 n, Destination* out) {
246 EncodeVariantList(in->base<const Variant>(), n,
247 port::NewStringListEncoder(out));
248 }
249
250 // Decodes "n" elements of type Variant from "in" and constructs a
251 // buffer out of it. Returns nullptr if the decoding fails. "in" is
252 // usually the TensorProto::tensor_content.
253 template <typename Source>
Decodetensorflow::__anon2ce1c7b90111::Helper254 static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
255 auto* buf = new Buffer<Variant>(a, n);
256 Variant* ps = buf->template base<Variant>();
257 if (ps == nullptr ||
258 !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
259 buf->Unref();
260 return nullptr;
261 }
262 return buf;
263 }
264
265 // Returns the estimated memory usage of "n" elements of type T
266 // stored in buffer "in".
TotalBytestensorflow::__anon2ce1c7b90111::Helper267 static int64 TotalBytes(TensorBuffer* in, int n) {
268 return n * sizeof(Variant);
269 }
270 };
271
272 template <typename T>
273 struct ProtoHelper {};
274
275 // For a C++ type "T" (float, double, int32, etc.), the repeated field
276 // "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
277 // int32, string, etc) in the TensorProto is used for serializing the
278 // tensor of type "T".
279 #define PROTO_TRAITS(T, F, N) \
280 template <> \
281 struct ProtoHelper<T> { \
282 typedef Helper<F>::RepeatedFieldType FieldType; \
283 static FieldType::const_iterator Begin(const TensorProto& proto) { \
284 return proto.N##_val().begin(); \
285 } \
286 static size_t NumElements(const TensorProto& proto) { \
287 return proto.N##_val().size(); \
288 } \
289 static void Fill(const T* data, size_t n, TensorProto* proto) { \
290 typename ProtoHelper<T>::FieldType copy(data, data + n); \
291 proto->mutable_##N##_val()->Swap(©); \
292 } \
293 };
294 PROTO_TRAITS(float, float, float);
295 PROTO_TRAITS(double, double, double);
296 PROTO_TRAITS(int32, int32, int);
297 PROTO_TRAITS(uint8, int32, int);
298 PROTO_TRAITS(uint16, int32, int);
299 PROTO_TRAITS(uint32, uint32, uint32);
300 PROTO_TRAITS(int16, int32, int);
301 PROTO_TRAITS(int8, int32, int);
302 PROTO_TRAITS(bool, bool, bool);
303 PROTO_TRAITS(string, string, string);
304 PROTO_TRAITS(qint8, int32, int);
305 PROTO_TRAITS(quint8, int32, int);
306 PROTO_TRAITS(qint16, int32, int);
307 PROTO_TRAITS(quint16, int32, int);
308 #undef PROTO_TRAITS
309
310 template <>
311 struct ProtoHelper<int64> {
Begintensorflow::__anon2ce1c7b90111::ProtoHelper312 static const int64* Begin(const TensorProto& proto) {
313 return reinterpret_cast<const int64*>(proto.int64_val().begin());
314 }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper315 static size_t NumElements(const TensorProto& proto) {
316 return proto.int64_val().size();
317 }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper318 static void Fill(const int64* data, size_t n, TensorProto* proto) {
319 protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
320 proto->mutable_int64_val()->Swap(©);
321 }
322 };
323
324 template <>
325 struct ProtoHelper<uint64> {
Begintensorflow::__anon2ce1c7b90111::ProtoHelper326 static const uint64* Begin(const TensorProto& proto) {
327 return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
328 }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper329 static size_t NumElements(const TensorProto& proto) {
330 return proto.uint64_val().size();
331 }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper332 static void Fill(const uint64* data, size_t n, TensorProto* proto) {
333 protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
334 proto->mutable_uint64_val()->Swap(©);
335 }
336 };
337
338 template <>
339 struct ProtoHelper<ResourceHandle> {
Begintensorflow::__anon2ce1c7b90111::ProtoHelper340 static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
341 const TensorProto& proto) {
342 return proto.resource_handle_val().begin();
343 }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper344 static size_t NumElements(const TensorProto& proto) {
345 return proto.resource_handle_val().size();
346 }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper347 static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
348 auto* handles = proto->mutable_resource_handle_val();
349 handles->Clear();
350 for (size_t i = 0; i < n; i++) {
351 data[i].AsProto(handles->Add());
352 }
353 }
354 };
355
356 template <>
357 struct ProtoHelper<Variant> {
358 static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
Begintensorflow::__anon2ce1c7b90111::ProtoHelper359 Begin(const TensorProto& proto) {
360 return proto.variant_val().begin();
361 }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper362 static size_t NumElements(const TensorProto& proto) {
363 return proto.variant_val().size();
364 }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper365 static void Fill(const Variant* data, size_t n, TensorProto* proto) {
366 auto* variant_values = proto->mutable_variant_val();
367 variant_values->Clear();
368 for (size_t i = 0; i < n; ++i) {
369 VariantTensorData tmp;
370 data[i].Encode(&tmp);
371 tmp.ToProto(variant_values->Add());
372 }
373 }
374 };
375
376 template <>
377 struct ProtoHelper<complex64> {
378 typedef Helper<float>::RepeatedFieldType FieldType;
Begintensorflow::__anon2ce1c7b90111::ProtoHelper379 static const complex64* Begin(const TensorProto& proto) {
380 return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
381 }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper382 static size_t NumElements(const TensorProto& proto) {
383 return proto.scomplex_val().size() / 2;
384 }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper385 static void Fill(const complex64* data, size_t n, TensorProto* proto) {
386 const float* p = reinterpret_cast<const float*>(data);
387 FieldType copy(p, p + n * 2);
388 proto->mutable_scomplex_val()->Swap(©);
389 }
390 };
391
392 template <>
393 struct ProtoHelper<complex128> {
394 typedef Helper<double>::RepeatedFieldType FieldType;
Begintensorflow::__anon2ce1c7b90111::ProtoHelper395 static const complex128* Begin(const TensorProto& proto) {
396 return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
397 }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper398 static size_t NumElements(const TensorProto& proto) {
399 return proto.dcomplex_val().size() / 2;
400 }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper401 static void Fill(const complex128* data, size_t n, TensorProto* proto) {
402 const double* p = reinterpret_cast<const double*>(data);
403 FieldType copy(p, p + n * 2);
404 proto->mutable_dcomplex_val()->Swap(©);
405 }
406 };
407
408 template <>
409 struct ProtoHelper<qint32> {
410 typedef Helper<int32>::RepeatedFieldType FieldType;
Begintensorflow::__anon2ce1c7b90111::ProtoHelper411 static const qint32* Begin(const TensorProto& proto) {
412 return reinterpret_cast<const qint32*>(proto.int_val().data());
413 }
NumElementstensorflow::__anon2ce1c7b90111::ProtoHelper414 static size_t NumElements(const TensorProto& proto) {
415 return proto.int_val().size();
416 }
Filltensorflow::__anon2ce1c7b90111::ProtoHelper417 static void Fill(const qint32* data, size_t n, TensorProto* proto) {
418 const int32* p = reinterpret_cast<const int32*>(data);
419 FieldType copy(p, p + n);
420 proto->mutable_int_val()->Swap(©);
421 }
422 };
423
424 template <>
425 struct ProtoHelper<bfloat16> {
Filltensorflow::__anon2ce1c7b90111::ProtoHelper426 static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
427 proto->mutable_half_val()->Reserve(n);
428 for (size_t i = 0; i < n; ++i) {
429 proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
430 }
431 }
432 };
433
434 template <>
435 struct ProtoHelper<Eigen::half> {
Filltensorflow::__anon2ce1c7b90111::ProtoHelper436 static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
437 proto->mutable_half_val()->Reserve(n);
438 for (size_t i = 0; i < n; ++i) {
439 proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
440 }
441 }
442 };
443
444 template <typename T>
Buffer(Allocator * a,int64 n)445 Buffer<T>::Buffer(Allocator* a, int64 n)
446 : BufferBase(a, a->Allocate<T>(n)), elem_(n) {}
447
448 template <typename T>
Buffer(Allocator * a,int64 n,const AllocationAttributes & allocation_attr)449 Buffer<T>::Buffer(Allocator* a, int64 n,
450 const AllocationAttributes& allocation_attr)
451 : BufferBase(a, a->Allocate<T>(n, allocation_attr)), elem_(n) {}
452
453 template <typename T>
~Buffer()454 Buffer<T>::~Buffer() {
455 if (data()) {
456 if (LogMemory::IsEnabled()) {
457 RecordDeallocation();
458 }
459 alloc_->Deallocate<T>(static_cast<T*>(data()), elem_);
460 }
461 }
462
463 // Allocates a T[n] buffer. Fills in the buffer with repeated values
464 // in "in". If "in" has less values than "n", fills the rest of T[n]
465 // with the last value. If "in" has no values, fills T[n] with the
466 // default value for T.
467 //
468 // This routine is using the typed fields (float_val, etc.) in the
469 // tensor proto as opposed to the untyped binary representation
470 // (tensor_content). This is used when we expect the TensorProto is
471 // used by a client program which may not know how to encode a tensor
472 // in the compact binary representation.
473 template <typename T>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)474 TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
475 CHECK_GT(n, 0);
476 Buffer<T>* buf = new Buffer<T>(a, n);
477 T* data = buf->template base<T>();
478 if (data == nullptr) {
479 buf->Unref();
480 return nullptr;
481 }
482
483 const int64 in_n = ProtoHelper<T>::NumElements(in);
484 if (in_n <= 0) {
485 std::fill_n(data, n, T());
486 } else {
487 auto begin = ProtoHelper<T>::Begin(in);
488 if (n <= in_n) {
489 std::copy_n(begin, n, data);
490 } else {
491 std::copy_n(begin, in_n, data);
492 const T& last = *(data + in_n - 1);
493 std::fill_n(data + in_n, n - in_n, last);
494 }
495 }
496
497 return buf;
498 }
499
500 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)501 TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
502 int64 n) {
503 CHECK_GT(n, 0);
504 Buffer<Variant>* buf = new Buffer<Variant>(a, n);
505 Variant* data = buf->template base<Variant>();
506 if (data == nullptr) {
507 buf->Unref();
508 return nullptr;
509 }
510 const int64 in_n = ProtoHelper<Variant>::NumElements(in);
511 if (in_n <= 0) {
512 std::fill_n(data, n, Variant());
513 } else {
514 for (int64 i = 0; i < in_n; ++i) {
515 data[i] = in.variant_val(i);
516 if (!DecodeUnaryVariant(&data[i])) {
517 LOG(ERROR) << "Could not decode variant with type_name: \""
518 << data[i].TypeName()
519 << "\". Perhaps you forgot to register a "
520 "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
521 buf->Unref();
522 return nullptr;
523 }
524 }
525 for (int64 i = in_n; i < n; ++i) {
526 data[i] = Variant();
527 }
528 }
529 return buf;
530 }
531
532 // fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
533 // identical to uint16 but with data stored in half_val instead of int_val (ie.,
534 // we don't use ProtoHelper<uint16>).
535 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)536 TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
537 int64 n) {
538 CHECK_GT(n, 0);
539 Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
540 uint16* data = buf->template base<uint16>();
541 if (data == nullptr) {
542 buf->Unref();
543 return nullptr;
544 }
545 const int64 in_n = in.half_val().size();
546 auto begin = in.half_val().begin();
547 if (n <= in_n) {
548 std::copy_n(begin, n, data);
549 } else if (in_n > 0) {
550 std::copy_n(begin, in_n, data);
551 const uint16 last = *(data + in_n - 1);
552 std::fill_n(data + in_n, n - in_n, last);
553 } else {
554 std::fill_n(data, n, 0);
555 }
556 return buf;
557 }
558
559 template <>
FromProtoField(Allocator * a,const TensorProto & in,int64 n)560 TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
561 int64 n) {
562 CHECK_GT(n, 0);
563 Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
564 uint16* data = buf->template base<uint16>();
565 if (data == nullptr) {
566 buf->Unref();
567 return nullptr;
568 }
569 const int64 in_n = in.half_val().size();
570 auto begin = in.half_val().begin();
571 if (n <= in_n) {
572 std::copy_n(begin, n, data);
573 } else if (in_n > 0) {
574 std::copy_n(begin, in_n, data);
575 const uint16 last = *(data + in_n - 1);
576 std::fill_n(data + in_n, n - in_n, last);
577 } else {
578 std::fill_n(data, n, 0);
579 }
580 return buf;
581 }
582
583 // Copies T[n] stored in the buffer "in" into the repeated field in
584 // "out" corresponding to type T.
585 template <typename T>
ToProtoField(const TensorBuffer & in,int64 n,TensorProto * out)586 void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
587 const T* data = in.base<const T>();
588 // NOTE: T may not the same as
589 // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
590 // ProtoHelper<T>::FieldType::value_type==int32. If performance is
591 // critical, we can specialize T=float and do memcpy directly.
592 ProtoHelper<T>::Fill(data, n, out);
593 }
594
RefIfNonNull(core::RefCounted * buf)595 void RefIfNonNull(core::RefCounted* buf) {
596 if (buf) buf->Ref();
597 }
598
UnrefIfNonNull(core::RefCounted * buf)599 void UnrefIfNonNull(core::RefCounted* buf) {
600 if (buf) buf->Unref();
601 }
602
603 } // end namespace
604
Tensor()605 Tensor::Tensor() : Tensor(DT_FLOAT) {}
606
Tensor(DataType type)607 Tensor::Tensor(DataType type) : shape_({0}), buf_(nullptr) { set_dtype(type); }
608
Tensor(DataType type,const TensorShape & shape,TensorBuffer * buf)609 Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
610 : shape_(shape), buf_(buf) {
611 set_dtype(type);
612 RefIfNonNull(buf);
613 }
614
IsInitialized() const615 bool Tensor::IsInitialized() const {
616 return (buf_ != nullptr && buf_->data() != nullptr) ||
617 shape_.num_elements() == 0;
618 }
619
CheckType(DataType expected_dtype) const620 void Tensor::CheckType(DataType expected_dtype) const {
621 CHECK_EQ(dtype(), expected_dtype) << " "
622 << DataTypeString(expected_dtype) << " expected, got "
623 << DataTypeString(dtype());
624 }
625
CheckTypeAndIsAligned(DataType expected_dtype) const626 void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
627 CHECK_EQ(dtype(), expected_dtype) << " "
628 << DataTypeString(expected_dtype) << " expected, got "
629 << DataTypeString(dtype());
630 CHECK(IsAligned()) << "ptr = " << base<void>();
631 }
632
CheckIsAlignedAndSingleElement() const633 void Tensor::CheckIsAlignedAndSingleElement() const {
634 CHECK(IsAligned()) << "Aligned and single element";
635 CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
636 }
637
~Tensor()638 Tensor::~Tensor() { UnrefIfNonNull(buf_); }
639
CopyFromInternal(const Tensor & other,const TensorShape & shape)640 void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
641 CHECK_EQ(shape.num_elements(), other.NumElements());
642 // Data type will be overwritten if this == &other, since dtype is part of
643 // shape.
644 DataType other_dtype = other.dtype();
645 shape_ = shape;
646 set_dtype(other_dtype);
647 if (buf_ != other.buf_) {
648 UnrefIfNonNull(buf_);
649 buf_ = other.buf_;
650 RefIfNonNull(buf_);
651 }
652 }
653
BitcastFrom(const Tensor & other,DataType dtype,const TensorShape & shape)654 Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
655 const TensorShape& shape) {
656 int in_size = DataTypeSize(other.dtype());
657 int out_size = DataTypeSize(dtype);
658 if (in_size == 0) {
659 return errors::InvalidArgument("other tensor has zero-sized data type");
660 }
661 if (out_size == 0) {
662 return errors::InvalidArgument("specified output type is zero-sized");
663 }
664 if (shape.num_elements() * out_size !=
665 other.shape().num_elements() * in_size) {
666 return errors::InvalidArgument(
667 "input and output shapes/data type sizes are not compatible");
668 }
669 shape_ = shape;
670 shape_.set_data_type(dtype);
671 if (buf_ != other.buf_) {
672 UnrefIfNonNull(buf_);
673 buf_ = other.buf_;
674 RefIfNonNull(buf_);
675 }
676 return Status::OK();
677 }
678
679 // Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
680 // For the latter case, we have to make sure that the refcount is
681 // one both for the SubBuffer _and_ the underlying TensorBuffer.
RefCountIsOne() const682 bool Tensor::RefCountIsOne() const {
683 return buf_ != nullptr && buf_->RefCountIsOne() &&
684 buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
685 }
686
// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
//
// SINGLE_ARG re-packs an argument that may contain commas so it can be
// forwarded to CASE as a single macro argument.
#define SINGLE_ARG(...) __VA_ARGS__
// CASE emits one `case` label for TYPE's enum value. Inside its own scope T
// is a typedef for TYPE, then STMTS runs, followed by a break.
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
// CASES_WITH_DEFAULT switches over every element type listed below. INVALID
// runs for DT_INVALID; DEFAULT runs for any value not listed.
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(string, SINGLE_ARG(STMTS))                            \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

// CASES is CASES_WITH_DEFAULT with LOG(FATAL) for both DT_INVALID and any
// unlisted type.
#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
732
Tensor(Allocator * a,DataType type,const TensorShape & shape)733 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
734 : shape_(shape), buf_(nullptr) {
735 set_dtype(type);
736 CHECK_NOTNULL(a);
737 if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
738 CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
739 }
740 if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
741 LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
742 *this);
743 }
744 }
745
Tensor(Allocator * a,DataType type,const TensorShape & shape,const AllocationAttributes & allocation_attr)746 Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
747 const AllocationAttributes& allocation_attr)
748 : shape_(shape), buf_(nullptr) {
749 set_dtype(type);
750 CHECK_NOTNULL(a);
751 if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
752 CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
753 }
754 if (!allocation_attr.allocation_will_be_logged && buf_ != nullptr &&
755 buf_->data() != nullptr && LogMemory::IsEnabled()) {
756 LogMemory::RecordTensorAllocation("Unknown (with attributes)",
757 LogMemory::UNKNOWN_STEP_ID, *this);
758 }
759 }
760
Tensor(DataType type,const TensorShape & shape)761 Tensor::Tensor(DataType type, const TensorShape& shape)
762 : Tensor(cpu_allocator(), type, shape) {}
763
FillAllocationDescription(AllocationDescription * proto) const764 void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
765 AllocationDescription* proto) const {
766 proto->set_requested_bytes(size());
767 proto->set_allocator_name("HostScalarTensorBuffer");
768 proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
769 }
770
771 template <typename T>
772 class SubBuffer : public TensorBuffer {
773 public:
774 // This buffer is an alias to buf[delta, delta + n).
SubBuffer(TensorBuffer * buf,int64 delta,int64 n)775 SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
776 : TensorBuffer(buf->base<T>() + delta),
777 root_(buf->root_buffer()),
778 elem_(n) {
779 // Sanity check. The caller should ensure the sub buffer is valid.
780 CHECK_LE(root_->base<T>(), this->base<T>());
781 T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
782 CHECK_LE(this->base<T>(), root_limit);
783 CHECK_LE(this->base<T>() + n, root_limit);
784 // Hold a ref of the underlying root buffer.
785 // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
786 root_->Ref();
787 }
788
size() const789 size_t size() const override { return sizeof(T) * elem_; }
root_buffer()790 TensorBuffer* root_buffer() override { return root_; }
FillAllocationDescription(AllocationDescription * proto) const791 void FillAllocationDescription(AllocationDescription* proto) const override {
792 root_->FillAllocationDescription(proto);
793 }
794
795 private:
796 TensorBuffer* root_;
797 T* data_;
798 int64 elem_;
799
~SubBuffer()800 ~SubBuffer() override { root_->Unref(); }
801
802 TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
803 };
804
Slice(int64 start,int64 limit) const805 Tensor Tensor::Slice(int64 start, int64 limit) const {
806 CHECK_GE(dims(), 1);
807 CHECK_LE(0, start);
808 CHECK_LE(start, limit);
809 int64 dim0_size = shape_.dim_size(0);
810 CHECK_LE(limit, dim0_size);
811 if ((start == 0) && (limit == dim0_size)) {
812 return *this;
813 }
814 Tensor ret;
815 ret.shape_ = shape_;
816 ret.set_dtype(dtype());
817 ret.buf_ = nullptr;
818 if (dim0_size > 0) {
819 const int64 elems_per_dim0 = NumElements() / dim0_size;
820 const int64 delta = start * elems_per_dim0;
821 dim0_size = limit - start;
822 ret.shape_.set_dim(0, dim0_size);
823 const int64 num_elems = dim0_size * elems_per_dim0;
824 if (buf_) {
825 DataType dt = dtype();
826 CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
827 }
828 }
829 return ret;
830 }
831
SubSlice(int64 index) const832 Tensor Tensor::SubSlice(int64 index) const {
833 CHECK_GE(dims(), 1); // Crash ok.
834 CHECK_LE(0, index); // Crash ok.
835 int64 dim0_size = shape_.dim_size(0);
836 CHECK_LE(index, dim0_size); // Crash ok.
837 Tensor ret;
838 ret.shape_ = shape_;
839 ret.shape_.RemoveDim(0);
840 ret.set_dtype(dtype());
841 ret.buf_ = nullptr;
842 if (dim0_size > 0) {
843 const int64 elems_per_dim0 = NumElements() / dim0_size;
844 const int64 delta = index * elems_per_dim0;
845 const int64 num_elems = elems_per_dim0;
846 if (buf_) {
847 DataType dt = dtype();
848 CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
849 }
850 }
851 return ret;
852 }
853
FromProto(const TensorProto & proto)854 bool Tensor::FromProto(const TensorProto& proto) {
855 return FromProto(cpu_allocator(), proto);
856 }
857
FromProto(Allocator * a,const TensorProto & proto)858 bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
859 CHECK_NOTNULL(a);
860 TensorBuffer* p = nullptr;
861 if (!TensorShape::IsValid(proto.tensor_shape())) return false;
862 if (proto.dtype() == DT_INVALID) return false;
863 TensorShape shape(proto.tensor_shape());
864 const int64 N = shape.num_elements();
865 if (N > 0 && proto.dtype()) {
866 bool dtype_error = false;
867 if (!proto.tensor_content().empty()) {
868 const auto& content = proto.tensor_content();
869 CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
870 dtype_error = true, dtype_error = true);
871 } else {
872 CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
873 dtype_error = true, dtype_error = true);
874 }
875 if (dtype_error || p == nullptr) return false;
876 }
877 shape_ = shape;
878 set_dtype(proto.dtype());
879 UnrefIfNonNull(buf_);
880 buf_ = p;
881 // TODO(misard) add tracking of which kernels and steps are calling
882 // FromProto.
883 if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
884 LogMemory::RecordTensorAllocation("Unknown (from Proto)",
885 LogMemory::UNKNOWN_STEP_ID, *this);
886 }
887 return true;
888 }
889
AsProtoField(TensorProto * proto) const890 void Tensor::AsProtoField(TensorProto* proto) const {
891 proto->Clear();
892 shape_.AsProto(proto->mutable_tensor_shape());
893 proto->set_dtype(dtype());
894 if (buf_) {
895 CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
896 }
897 }
898
AsProtoTensorContent(TensorProto * proto) const899 void Tensor::AsProtoTensorContent(TensorProto* proto) const {
900 proto->Clear();
901 proto->set_dtype(dtype());
902 shape_.AsProto(proto->mutable_tensor_shape());
903 if (buf_) {
904 CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
905 proto->mutable_tensor_content()));
906 }
907 }
908
TotalBytes() const909 size_t Tensor::TotalBytes() const {
910 if (shape_.num_elements() == 0) return 0;
911 CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
912 CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
913 return 0; // Makes compiler happy.
914 }
915
AllocatedBytes() const916 size_t Tensor::AllocatedBytes() const {
917 TensorDescription tensor_description;
918 FillDescription(&tensor_description);
919 if (tensor_description.has_allocation_description() &&
920 tensor_description.allocation_description().allocated_bytes() > 0) {
921 return tensor_description.allocation_description().allocated_bytes();
922 } else {
923 // Fall back to TotalBytes() if the allocator doesn't have its size.
924 return TotalBytes();
925 }
926 }
927
CanUseDMA() const928 bool Tensor::CanUseDMA() const {
929 CASES(dtype(), return is_simple_type<T>::value);
930 return false; // Makes compiler happy.
931 }
932
933 #undef CASES
934 #undef CASE
935
namespace {

// PrintOneElement(x, print_v2) turns a single tensor element into a value
// that strings::StrAppend() accepts. Only the string overload looks at
// `print_v2`; the other overloads ignore it.
//
// StrCat and StrAppend don't support Eigen::half directly at the moment, and
// we would like to keep them compatible with their absl counterparts, for ease
// of migration. We could rely on errors::internal::PrepareForStrCat() but the
// logic is so simple we can just replicate it here, where it is close to its
// usage and easy to change later. And there's the extra benefit of not
// accessing an 'internal' namespace.
//
// Anything AlphaNum can already format is passed through unchanged.
inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
                                                bool print_v2) {
  return a;
}
// Strings are C-escaped; in the v2 format they are also wrapped in double
// quotes.
inline string PrintOneElement(const string& a, bool print_v2) {
  if (print_v2) {
    return "\"" + str_util::CEscape(a) + "\"";
  } else {
    return str_util::CEscape(a);
  }
}
// Eigen::half is widened to float so StrAppend can format it.
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
  return static_cast<float>(h);
}

// Appends the legacy (v1) summary of `data` to `result`, printing from the
// left-most dim to the right-most dim recursively.
//
// Elements are consumed sequentially starting at `*data_index`, which is
// advanced by one per printed element. Printing stops once `*data_index`
// reaches `limit`. Every entry of a non-innermost dim is wrapped in "[" ... "]"
// with no separator between sibling entries. Innermost elements are
// separated by single spaces.
//
// dim_index:  dim being printed (0 on the outermost call).
// shape:      dim sizes of the tensor; `shape_size` is the number of dims.
// limit:      total number of elements to print.
// data_index: in/out running position in `data`.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64 limit, int shape_size, const T* data, int64* data_index,
                 string* result) {
  if (*data_index >= limit) return;
  int64 element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64 i = 0; i < element_count; i++) {
      if (*data_index >= limit) {
        // If not enough elements have been printed, append "...".
        // `i < element_count` always holds inside this loop, so only the
        // dim_index test matters. For a rank-1 tensor (dim_index == 0),
        // SummarizeArray appends the trailing "..." itself.
        if (dim_index != 0 && i < element_count) {
          strings::StrAppend(result, "...");
        }
        return;
      }
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
    }
    return;
  }
  // Loop every element of one dim.
  for (int64 i = 0; i < element_count; i++) {
    bool flag = false;
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // As for each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    // *data_index never decreases, so this condition is equivalent to `flag`.
    // A "]" is emitted exactly when this iteration emitted a "[". Resetting
    // `flag` has no further effect because it is redeclared each iteration.
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}

// Appends the separator placed between sibling entries of dimension
// `dim_index` in the v2 format. For the innermost dim it is a single space.
// Otherwise it is (num_dims - dim_index - 1) newlines followed by
// (dim_index + 1) spaces of indentation.
void PrintDimSpacing(int dim_index, int num_dims, string* result) {
  if (dim_index == num_dims - 1) {
    strings::StrAppend(result, " ");
    return;
  }
  for (int j = 0; j < num_dims - dim_index - 1; j++) {
    strings::StrAppend(result, "\n");
  }
  for (int j = 0; j <= dim_index; j++) {
    strings::StrAppend(result, " ");
  }
}

// Appends the v2 summary of the sub-array whose first element is at flat
// offset `data_index`, printing from left dim to right dim recursively.
//
// Each dim is wrapped in "[" ... "]". At most the first and last
// `num_elts_at_ends` entries of each dim are printed, with "..." between them
// when entries are elided. Element positions are computed from strides rather
// than a running counter. As a result, `data` must cover every element of the
// tensor, not only the ones that end up printed.
template <typename T>
void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                   int64 num_elts_at_ends, int num_dims, const T* data,
                   int64 data_index, string* result) {
  // We have recursed beyond all the dimensions into a single element
  // of the tensor.
  if (dim_index == num_dims) {
    strings::StrAppend(result, PrintOneElement(data[data_index], true));
    return;
  }

  strings::StrAppend(result, "[");
  int64 element_count = shape[dim_index];
  // Index of the first trailing entry. The value is clamped to at least
  // num_elts_at_ends, so the tail loop never reprints an entry that the head
  // loop already printed.
  int64 start_of_end =
      std::max(num_elts_at_ends, element_count - num_elts_at_ends);

  // Loop every element of one dim.
  // elements_per_iter is the number of flat elements spanned by one entry of
  // this dim (the product of all trailing dim sizes).
  int64 elements_per_iter = 1;
  for (int i = dim_index + 1; i < num_dims; i++) {
    elements_per_iter *= shape[i];
  }
  // Leading entries.
  for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
    if (i > 0) {
      PrintDimSpacing(dim_index, num_dims, result);
    }

    // As for each element, print the sub-dim.
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }
  // Marker for the elided middle entries.
  if (element_count > 2 * num_elts_at_ends) {
    PrintDimSpacing(dim_index, num_dims, result);
    strings::StrAppend(result, "...");
  }
  // Trailing entries.
  for (int64 i = start_of_end; i < element_count; i++) {
    // As for each element, print the sub-dim.
    PrintDimSpacing(dim_index, num_dims, result);
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }

  strings::StrAppend(result, "]");
}

// Returns a textual summary of `data`, reinterpreted as an array of T laid
// out with shape `tensor_shape`.
//
// limit:    for scalars and for the v1 format, the maximum number of elements
//           printed. For the v2 format it is passed to PrintOneDimV2 as the
//           number of entries kept at each end of every dimension.
// num_elts: total number of elements; decides whether a trailing "..." is
//           appended (scalar and v1 paths).
// data:     start of the element storage. It is never read when limit == 0,
//           so it may be null in that case.
// print_v2: selects the bracketed v2 format over the legacy v1 format.
//
// NOTE(review): in v2 mode `limit` bounds the entries kept per dimension end
// rather than the total, so tensors of rank >= 2 may print more than `limit`
// elements.
template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
                      const TensorShape& tensor_shape, const char* data,
                      const bool print_v2) {
  string ret;
  const T* array = reinterpret_cast<const T*>(data);

  const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
  // Rank-0 tensor: print the element(s) flat, with no brackets.
  if (shape.empty()) {
    for (int64 i = 0; i < limit; ++i) {
      if (i > 0) strings::StrAppend(&ret, " ");
      strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
    }
    if (num_elts > limit) strings::StrAppend(&ret, "...");
    return ret;
  }
  if (print_v2) {
    const int num_dims = tensor_shape.dims();
    PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
  } else {
    int64 data_index = 0;
    const int shape_size = tensor_shape.dims();
    PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);

    if (num_elts > limit) strings::StrAppend(&ret, "...");
  }

  return ret;
}
}  // namespace
1087
SummarizeValue(int64 max_entries,bool print_v2) const1088 string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
1089 const int64 num_elts = NumElements();
1090 if (max_entries < 0) {
1091 max_entries = num_elts;
1092 }
1093 size_t limit = std::min(max_entries, num_elts);
1094 if ((limit > 0) && (buf_ == nullptr)) {
1095 return strings::StrCat("uninitialized Tensor of ", num_elts,
1096 " elements of type ", dtype());
1097 }
1098 const char* data = limit > 0 ? tensor_data().data() : nullptr;
1099 switch (dtype()) {
1100 case DT_HALF:
1101 return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
1102 print_v2);
1103 break;
1104 case DT_FLOAT:
1105 return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
1106 break;
1107 case DT_DOUBLE:
1108 return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
1109 break;
1110 case DT_UINT32:
1111 return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
1112 break;
1113 case DT_INT32:
1114 return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
1115 break;
1116 case DT_UINT8:
1117 case DT_QUINT8:
1118 return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
1119 break;
1120 case DT_UINT16:
1121 case DT_QUINT16:
1122 return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
1123 break;
1124 case DT_INT16:
1125 case DT_QINT16:
1126 return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
1127 break;
1128 case DT_INT8:
1129 case DT_QINT8:
1130 return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
1131 break;
1132 case DT_UINT64:
1133 return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
1134 break;
1135 case DT_INT64:
1136 return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
1137 break;
1138 case DT_BOOL:
1139 // TODO(tucker): Is it better to emit "True False..."? This
1140 // will emit "1 0..." which is more compact.
1141 return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
1142 break;
1143 case DT_STRING:
1144 return SummarizeArray<string>(limit, num_elts, shape_, data, print_v2);
1145 break;
1146 default: {
1147 // All irregular cases
1148 string ret;
1149 if (print_v2) {
1150 strings::StrAppend(&ret, "[");
1151 }
1152 // TODO(irving): Don't call flat every time around this
1153 // loop.
1154 for (size_t i = 0; i < limit; ++i) {
1155 if (i > 0) strings::StrAppend(&ret, " ");
1156 switch (dtype()) {
1157 case DT_VARIANT: {
1158 const Variant& v = flat<Variant>()(i);
1159 strings::StrAppend(&ret, v.DebugString());
1160 } break;
1161 default:
1162 // TODO(zhifengc, josh11b): Pretty-print other types (bool,
1163 // complex64, quantized).
1164 strings::StrAppend(&ret, "?");
1165 }
1166 }
1167 if (max_entries < num_elts) strings::StrAppend(&ret, "...");
1168 if (print_v2) {
1169 strings::StrAppend(&ret, "]");
1170 }
1171 return ret;
1172 }
1173 }
1174 }
1175
tensor_data() const1176 StringPiece Tensor::tensor_data() const {
1177 if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors
1178 return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
1179 }
1180
SharesBufferWith(const Tensor & b) const1181 bool Tensor::SharesBufferWith(const Tensor& b) const {
1182 return buf_ != nullptr && b.buf_ != nullptr &&
1183 buf_->root_buffer() == b.buf_->root_buffer();
1184 }
1185
DebugString(int num_values) const1186 string Tensor::DebugString(int num_values) const {
1187 return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1188 " shape: ", shape().DebugString(),
1189 " values: ", SummarizeValue(num_values), ">");
1190 }
1191
DeviceSafeDebugString() const1192 string Tensor::DeviceSafeDebugString() const {
1193 return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
1194 " shape: ", shape().DebugString(), ">");
1195 }
1196
FillDescription(TensorDescription * description) const1197 void Tensor::FillDescription(TensorDescription* description) const {
1198 description->set_dtype(dtype());
1199 shape().AsProto(description->mutable_shape());
1200 if (buf_ != nullptr && buf_->data() != nullptr) {
1201 buf_->FillAllocationDescription(
1202 description->mutable_allocation_description());
1203 }
1204 }
1205
ComputeFlatInnerDims(gtl::ArraySlice<int64> orig,int64 num_out_dims)1206 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
1207 gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1208 gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1209 int64 offset = orig.size() - num_out_dims;
1210 for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
1211 const int64 in_dim = out_dim + offset;
1212 out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
1213 }
1214 for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
1215 out_dims[0] *= orig[in_dim];
1216 }
1217 return out_dims;
1218 }
1219
ComputeFlatOuterDims(gtl::ArraySlice<int64> orig,int64 num_out_dims)1220 gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
1221 gtl::ArraySlice<int64> orig, int64 num_out_dims) {
1222 gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
1223 for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
1224 out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
1225 }
1226 for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
1227 out_dims[num_out_dims - 1] *= orig[in_dim];
1228 }
1229 return out_dims;
1230 }
1231
1232 } // namespace tensorflow
1233