/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {
Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  DeepCopy(other, &tmp);
  return tmp;
}

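// Deep-copies 'input' into the preallocated tensor 'output'. 'output' must
// already have the same dtype and number of elements as 'input'.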
void DeepCopy(const Tensor& input, Tensor* output) {
  if (DataTypeCanUseMemcpy(input.dtype())) {
    if (input.NumElements() > 0) {
      StringPiece input_data = input.tensor_data();

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece output_data = output->tensor_data();
      memcpy(const_cast<char*>(output_data.data()), input_data.data(),
             input_data.size());
    }
  } else if (input.dtype() == DT_STRING) {
    output->unaligned_flat<tstring>() = input.unaligned_flat<tstring>();
  } else {
    CHECK_EQ(DT_VARIANT, input.dtype());
    output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>();
  }
}

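// Concatenates 'tensors' along their zeroth dimension into '*result'. All
// inputs must have rank >= 1 and the same dtype; the trailing dimensions of
// the result are taken from tensors[0].
//
// Example (informal sketch): concatenating tensors of shape [2, 3] and [4, 3]
// produces a tensor of shape [6, 3].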
Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64 total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types.", " Got ",
          DataTypeString(dtype), " and ", DataTypeString(tensors[i].dtype()),
          ".");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient map over the tensor buffer,
  // but we cast the type to get to the underlying buffer to do the
  // copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64 offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    tstring* to_strings =
        reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));

    int64 offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<tstring>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return Status::OK();
}

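// Splits 'tensor' along its zeroth dimension into pieces whose
// zeroth-dimension sizes are given by 'sizes'. The values in 'sizes' must sum
// to tensor.dim_size(0).
//
// Example (informal sketch): splitting a [6, 3] tensor with sizes {2, 4}
// yields tensors of shape [2, 3] and [4, 3].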
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64 total_size = 0;
  for (int64 size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64 offset = 0;
    for (int64 size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<tstring>();

    int64 offset = 0;
    for (int64 size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      tstring* to_strings = reinterpret_cast<tstring*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return Status::OK();
}

namespace internal {
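// Copies the dimension sizes in 'shape' into 'shape_proto'.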
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

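// Attempts to shrink a TensorProto that stores its values as raw bytes in
// tensor_content. It scans backwards for a trailing run of identical
// elements, which can be dropped because missing trailing values are filled
// in from the last stored value when the proto is parsed back into a Tensor.
// If the truncated data is small enough (per min_compression_ratio), the
// remaining values are moved into the typed repeated field and
// tensor_content is cleared.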
template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64 num_tensor_values = shape.num_elements();
  const int64 num_bytes = tensor->tensor_content().size();
  const int64 num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64 last_offset = num_bytes - 1;
  int64 prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
  // starting from the end, to find the last pair of elements that are not
  // identical.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  // Round up to the next whole number of elements of type T.
  const int64 new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to truncated repeated field.
  if (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if (sizeof(T) > 1) {
    // Copy raw bytes to temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp(new_num_values);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    const T* begin = tmp.begin();
    const T* end = tmp.end();
    TypeHelper::AddValues(begin, end, tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64 i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}

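// Returns true if 'a' and 'b' should be treated as distinct when deciding how
// far a repeated-field proto can be truncated. Floating-point types are
// compared by bit pattern, so e.g. +0.0 and -0.0 stay distinct and a NaN
// compares equal to an identical NaN bit pattern.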
template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}

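// Counterpart of CompressTensorContent for protos that already store their
// values in the typed repeated field. It finds the shortest prefix that,
// with its last value implicitly repeated, reproduces the full tensor, then
// either truncates the repeated field or, if that is smaller, rewrites the
// values as raw tensor_content.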
template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64 num_tensor_values = shape.num_elements();
  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated.
  const int64 num_proto_values = TypeHelper::NumValues(*tensor);
  if (num_proto_values != num_tensor_values) {
    // Already compressed or invalid.
    return false;
  }
  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64 last_index = 0;
  for (int64 i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }
  const int64 num_truncated_proto_values = last_index + 1;
  const int64 num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64 num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64 num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp(num_tensor_values);
    TypeHelper::CopyValues(tmp.begin(), *tensor);
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}

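// Typed implementation behind CompressTensorProtoInPlace: skips tensors with
// fewer than 'min_num_elements' values, then dispatches on whether the proto
// currently stores its data as raw tensor_content or as a repeated field.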
template <typename T>
bool CompressTensorProtoInPlaceImpl(int64 min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64 num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
  return true;
}

}  // namespace internal

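// Expands to a switch case that forwards to the typed implementation for the
// given dtype enum value.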
#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break

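// Attempts to shrink the wire representation of '*tensor' in place, returning
// true if the proto was modified. Unsupported dtypes, tensors smaller than
// 'min_num_elements', and cases where the saving would fall below
// 'min_compression_ratio' are left untouched and return false.
//
// Usage sketch (illustrative only, not taken from the original source):
//
//   Tensor t(DT_FLOAT, TensorShape({1024}));
//   t.flat<float>().setZero();
//   TensorProto proto;
//   t.AsProtoTensorContent(&proto);
//   // The all-zero content collapses to a single stored value.
//   bool modified = CompressTensorProtoInPlace(/*min_num_elements=*/64,
//                                              /*min_compression_ratio=*/2.0f,
//                                              &proto);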
bool CompressTensorProtoInPlace(int64 min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}

#undef HANDLE_COMPRESS_CASE

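// Interprets the rank-1 int32 or int64 tensor 'shape' as a TensorShape and
// stores the result in '*out'.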
Status MakeShape(const Tensor& shape, TensorShape* out) {
  if (!TensorShapeUtils::IsVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}

}  // namespace tensor
}  // namespace tensorflow