1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_util.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/framework/variant_encode_decode.h"
24 #include "tensorflow/core/framework/variant_tensor_data.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace tensorflow {
29 namespace {
30 
// DeepCopy of a 0-d (scalar) float tensor produces an independent buffer:
// writes to the source never appear in the copy and vice versa.
TEST(TensorUtil, DeepCopy0d) {
  const TensorShape scalar_shape = TensorShape({});
  Tensor x(DT_FLOAT, scalar_shape);
  x.scalar<float>()() = 10.0;

  // Copy x into y, then overwrite y: x must keep its value.
  Tensor y = tensor::DeepCopy(x);
  y.scalar<float>()() = 20.0;
  EXPECT_EQ(10.0, x.scalar<float>()());

  // Overwrite x: y must keep its value.
  x.scalar<float>()() = 30.0;
  EXPECT_EQ(20.0, y.scalar<float>()());

  // Copy y into z, then overwrite y: z must keep y's earlier value.
  Tensor z = tensor::DeepCopy(y);
  y.scalar<float>()() = 40.0;

  EXPECT_EQ(20.0, z.scalar<float>()());
  EXPECT_EQ(30.0, x.scalar<float>()());
  EXPECT_EQ(40.0, y.scalar<float>()());

  // Every tensor keeps the scalar shape and the float dtype.
  for (const Tensor* t : {&x, &y, &z}) {
    EXPECT_EQ(scalar_shape, t->shape());
    EXPECT_EQ(DT_FLOAT, t->dtype());
  }
}
67 
// Deep-copying a default-constructed Tensor yields a tensor that, per the
// expectations below, is 1-D with zero elements and dtype DT_FLOAT.
TEST(TensorUtil, DeepCopyZeroElements) {
  const Tensor empty;
  const Tensor copy = tensor::DeepCopy(empty);

  EXPECT_EQ(TensorShape({0}), copy.shape());
  EXPECT_EQ(DT_FLOAT, copy.dtype());
  EXPECT_EQ(0, copy.NumElements());
}
75 
// DeepCopy of 1-D float and string tensors produces buffers that are fully
// independent of their source.
TEST(TensorUtil, DeepCopy) {
  Tensor x(DT_FLOAT, TensorShape({1}));
  x.flat<float>()(0) = 10.0;

  // Make y a deep copy of x and then change it.
  Tensor y = tensor::DeepCopy(x);
  y.flat<float>()(0) = 20.0;

  // x doesn't change
  EXPECT_EQ(10.0, x.flat<float>()(0));

  // Change x.
  x.flat<float>()(0) = 30.0;

  // Y doesn't change.
  EXPECT_EQ(20.0, y.flat<float>()(0));

  Tensor z = tensor::DeepCopy(y);

  // Change y.
  y.flat<float>()(0) = 40.0;

  // The final states should all be different.
  EXPECT_EQ(20.0, z.flat<float>()(0));
  EXPECT_EQ(30.0, x.flat<float>()(0));
  EXPECT_EQ(40.0, y.flat<float>()(0));

  // Should have the same shape and type.
  EXPECT_EQ(TensorShape({1}), x.shape());
  EXPECT_EQ(TensorShape({1}), y.shape());
  EXPECT_EQ(TensorShape({1}), z.shape());

  EXPECT_EQ(DT_FLOAT, x.dtype());
  EXPECT_EQ(DT_FLOAT, y.dtype());
  EXPECT_EQ(DT_FLOAT, z.dtype());

  // Test string deep copy: overwriting every element of the copy must leave
  // every element of the source untouched.
  Tensor str1(DT_STRING, TensorShape({2}));
  str1.flat<string>()(0) = "foo1";
  str1.flat<string>()(1) = "foo2";
  Tensor str2 = tensor::DeepCopy(str1);
  EXPECT_EQ(DT_STRING, str2.dtype());
  EXPECT_EQ(TensorShape({2}), str2.shape());
  str2.flat<string>()(0) = "bar1";
  str2.flat<string>()(1) = "bar2";

  // Check both elements of both tensors. Comparing only element 0 would not
  // detect a copy that still aliases part of the source buffer.
  EXPECT_EQ("foo1", str1.flat<string>()(0));
  EXPECT_EQ("foo2", str1.flat<string>()(1));
  EXPECT_EQ("bar1", str2.flat<string>()(0));
  EXPECT_EQ("bar2", str2.flat<string>()(1));
}
121 
// DeepCopy of an int32 slice materializes the sliced elements into a new
// buffer, so later writes to the sliced-from tensor do not reach the copy.
TEST(TensorUtil, DeepCopySlice) {
  Tensor source(DT_INT32, TensorShape({10}));
  source.flat<int32>().setConstant(1);

  // `view` aliases elements [2, 6) of `source`'s buffer.
  Tensor view = source.Slice(2, 6);
  // `copy` is a deep copy of the slice.
  Tensor copy = tensor::DeepCopy(view);

  // Rewrite the shared buffer; `view` sees it, `copy` must not.
  source.flat<int32>().setConstant(2);

  EXPECT_EQ(TensorShape({10}), source.shape());
  EXPECT_EQ(TensorShape({4}), view.shape());
  EXPECT_EQ(TensorShape({4}), copy.shape());
  EXPECT_EQ(DT_INT32, source.dtype());
  EXPECT_EQ(DT_INT32, view.dtype());
  EXPECT_EQ(DT_INT32, copy.dtype());

  const auto source_values = source.flat<int32>();
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ(2, source_values(idx));
  }
  // The slice starts mid-buffer, so it is read through unaligned_flat.
  const auto view_values = view.unaligned_flat<int32>();
  const auto copy_values = copy.flat<int32>();
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ(2, view_values(idx));
    EXPECT_EQ(1, copy_values(idx));
  }
}
151 
// Same as DeepCopySlice, but for DT_STRING elements.
TEST(TensorUtil, DeepCopySliceString) {
  Tensor source(DT_STRING, TensorShape({10}));
  source.flat<string>().setConstant("hello");

  // `view` aliases elements [3, 7) of `source`'s buffer.
  Tensor view = source.Slice(3, 7);
  // `copy` is a deep copy of the slice.
  Tensor copy = tensor::DeepCopy(view);

  // Rewrite the shared buffer; `view` sees it, `copy` must not.
  source.flat<string>().setConstant("goodbye");

  EXPECT_EQ(TensorShape({10}), source.shape());
  EXPECT_EQ(TensorShape({4}), view.shape());
  EXPECT_EQ(TensorShape({4}), copy.shape());
  EXPECT_EQ(DT_STRING, source.dtype());
  EXPECT_EQ(DT_STRING, view.dtype());
  EXPECT_EQ(DT_STRING, copy.dtype());

  const auto source_values = source.flat<string>();
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ("goodbye", source_values(idx));
  }
  const auto view_values = view.unaligned_flat<string>();
  const auto copy_values = copy.flat<string>();
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ("goodbye", view_values(idx));
    EXPECT_EQ("hello", copy_values(idx));
  }
}
181 
// Same as DeepCopySlice, but for DT_VARIANT elements that each wrap a Tensor.
TEST(TensorUtil, DeepCopySliceVariant) {
  Tensor x(DT_VARIANT, TensorShape({10}));
  x.flat<Variant>().setConstant(Tensor(42.0f));

  // Slice 'x' -- y still refers to the same buffer.
  Tensor y = x.Slice(3, 7);

  // Do a deep copy of y, which is a slice.
  Tensor z = tensor::DeepCopy(y);

  // Set x to be different.
  x.flat<Variant>().setConstant(Tensor("foo"));

  EXPECT_EQ(TensorShape({10}), x.shape());
  EXPECT_EQ(TensorShape({4}), y.shape());
  EXPECT_EQ(TensorShape({4}), z.shape());
  EXPECT_EQ(DT_VARIANT, x.dtype());
  EXPECT_EQ(DT_VARIANT, y.dtype());
  EXPECT_EQ(DT_VARIANT, z.dtype());

  // Each element of x and y should now be a DT_STRING Tensor containing "foo",
  // but each element of z should be a DT_FLOAT tensor containing 42.0.
  // Variant::get<Tensor>() yields a pointer that is null when the element does
  // not hold a Tensor. Assert before dereferencing so that a bad copy fails
  // the test instead of crashing the test binary.
  for (int i = 0; i < 10; ++i) {
    const Tensor* x_elem = x.flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, x_elem);
    EXPECT_EQ("foo", x_elem->scalar<string>()());
  }
  for (int i = 0; i < 4; ++i) {
    const Tensor* y_elem = y.unaligned_flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, y_elem);
    EXPECT_EQ("foo", y_elem->scalar<string>()());

    const Tensor* z_elem = z.flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, z_elem);
    EXPECT_EQ(42.0, z_elem->scalar<float>()());
  }
}
213 
// Concatenates three int32 matrices with 1, 4 and 5 rows (2 columns each) and
// checks that the result is the 10x2 matrix whose element (i, j) is 2 * i + j.
TEST(TensorUtil, Concat) {
  const std::vector<int64> sizes = {1, 4, 5};
  std::vector<Tensor> to_concat;
  to_concat.reserve(sizes.size());
  // Rows emitted so far. This is also the global row index of the next
  // piece's first row, so no separate offset counter is needed. It is int64
  // to match `sizes` and avoid narrowing.
  int64 total_size = 0;
  for (const int64 size : sizes) {
    Tensor tensor(DT_INT32, TensorShape({size, 2}));
    auto piece = tensor.matrix<int32>();
    for (int64 row = 0; row < size; ++row) {
      for (int j = 0; j < 2; ++j) {
        piece(row, j) = static_cast<int32>(2 * (total_size + row) + j);
      }
    }
    to_concat.push_back(tensor);
    total_size += size;
  }

  Tensor concated;
  TF_ASSERT_OK(tensor::Concat(to_concat, &concated));
  ASSERT_EQ(TensorShape({total_size, 2}), concated.shape());
  const auto result = concated.matrix<int32>();
  for (int64 i = 0; i < total_size; ++i) {
    for (int j = 0; j < 2; ++j) {
      EXPECT_EQ(2 * i + j, result(i, j));
    }
  }
}
241 
// Splits a 10x2 int64 matrix, whose element (i, j) is 2 * i + j, into pieces
// of 1, 4 and 5 rows and checks each piece's shape and contents.
TEST(TensorUtil, Split) {
  Tensor to_split(DT_INT64, TensorShape({10, 2}));
  auto input = to_split.matrix<int64>();
  for (int i = 0; i < 10; ++i) {
    for (int j = 0; j < 2; ++j) {
      input(i, j) = 2 * i + j;
    }
  }

  const std::vector<int64> sizes = {1, 4, 5};
  std::vector<Tensor> splits;
  TF_ASSERT_OK(tensor::Split(to_split, sizes, &splits));
  ASSERT_EQ(sizes.size(), splits.size());

  // Row in `to_split` where the current piece starts. It is int64 so that
  // accumulating the int64 `sizes` never narrows.
  int64 offset = 0;
  for (size_t entry = 0; entry < splits.size(); ++entry) {
    const int64 size = sizes[entry];
    const Tensor& split = splits[entry];

    // Check the shape before matrix<>(), which requires a rank-2 tensor.
    ASSERT_EQ(TensorShape({size, 2}), split.shape());
    const auto actual = split.matrix<int64>();
    for (int64 row = 0; row < size; ++row) {
      for (int j = 0; j < 2; ++j) {
        EXPECT_EQ(2 * (offset + row) + j, actual(row, j));
      }
    }

    offset += size;
  }
}
270 
// Splitting a 4x3 string tensor and concatenating the pieces reproduces the
// original values in a buffer that shares no storage with the original.
TEST(TensorUtil, ConcatSplitStrings) {
  const int num_elements = 4 * 3;
  Tensor original(DT_STRING, TensorShape({4, 3}));
  auto original_values = original.flat<string>();
  for (int k = 0; k < num_elements; ++k) {
    original_values(k) = strings::StrCat("foo_", k);
  }

  std::vector<Tensor> pieces;
  TF_ASSERT_OK(tensor::Split(original, {2, 1, 1}, &pieces));
  Tensor rejoined;
  TF_ASSERT_OK(tensor::Concat(pieces, &rejoined));
  ASSERT_EQ(original.shape(), rejoined.shape());

  auto rejoined_values = rejoined.flat<string>();
  for (int k = 0; k < num_elements; ++k) {
    EXPECT_EQ(original_values(k), rejoined_values(k));
  }

  // Overwrite the round-tripped tensor; the original must not observe it.
  for (int k = 0; k < num_elements; ++k) {
    rejoined_values(k) = strings::StrCat("bar_", k);
  }
  for (int k = 0; k < num_elements; ++k) {
    EXPECT_NE(original_values(k), rejoined_values(k));
  }
}
294 
// A 1x3 string tensor proto records its dtype, both dims and one string_val
// entry per element.
TEST(TensorProtoUtil, CreatesStringTensorProto) {
  const std::vector<string> strs = {"a", "b", "c"};
  const std::vector<size_t> dims = {1, 3};
  const string expected =
      "dtype: DT_STRING\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 1\n"
      "  }\n"
      "  dim {\n"
      "    size: 3\n"
      "  }\n"
      "}\n"
      "string_val: \"a\"\n"
      "string_val: \"b\"\n"
      "string_val: \"c\"\n";

  const auto proto = tensor::CreateTensorProto(strs, dims);
  EXPECT_EQ(proto.DebugString(), expected);
}
315 
// A length-2 int32 tensor proto stores its values in int_val.
TEST(TensorProtoUtil, CreatesInt32TensorProto) {
  const std::vector<int32> ints = {1, 2};
  const std::vector<size_t> dims = {2};
  const string expected =
      "dtype: DT_INT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int_val: 1\n"
      "int_val: 2\n";

  const auto proto = tensor::CreateTensorProto(ints, dims);
  EXPECT_EQ(proto.DebugString(), expected);
}
332 
// A length-2 int64 tensor proto stores its values in int64_val.
TEST(TensorProtoUtil, CreatesInt64TensorProto) {
  const std::vector<int64> ints = {1, 2};
  const std::vector<size_t> dims = {2};
  const string expected =
      "dtype: DT_INT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int64_val: 1\n"
      "int64_val: 2\n";

  const auto proto = tensor::CreateTensorProto(ints, dims);
  EXPECT_EQ(proto.DebugString(), expected);
}
349 
// A length-2 uint32 tensor proto stores its values in uint32_val.
TEST(TensorProtoUtil, CreatesUInt32TensorProto) {
  const std::vector<uint32> ints = {1, 2};
  const std::vector<size_t> dims = {2};
  const string expected =
      "dtype: DT_UINT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint32_val: 1\n"
      "uint32_val: 2\n";

  const auto proto = tensor::CreateTensorProto(ints, dims);
  EXPECT_EQ(proto.DebugString(), expected);
}
366 
// A length-2 uint64 tensor proto stores its values in uint64_val.
TEST(TensorProtoUtil, CreatesUInt64TensorProto) {
  const std::vector<uint64> ints = {1, 2};
  const std::vector<size_t> dims = {2};
  const string expected =
      "dtype: DT_UINT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint64_val: 1\n"
      "uint64_val: 2\n";

  const auto proto = tensor::CreateTensorProto(ints, dims);
  EXPECT_EQ(proto.DebugString(), expected);
}
383 
// A length-2 float tensor proto stores its values in float_val.
TEST(TensorProtoUtil, CreatesFloatTensorProto) {
  const std::vector<float> floats = {1.1f, 2.2f};
  const std::vector<size_t> dims = {2};
  const string expected =
      "dtype: DT_FLOAT\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "float_val: 1.1\n"
      "float_val: 2.2\n";

  const auto proto = tensor::CreateTensorProto(floats, dims);
  EXPECT_EQ(proto.DebugString(), expected);
}
400 
// A length-2 double tensor proto stores its values in double_val.
TEST(TensorProtoUtil, CreatesDoubleTensorProto) {
  const std::vector<double> doubles = {1.1, 2.2};
  const std::vector<size_t> dims = {2};
  const string expected =
      "dtype: DT_DOUBLE\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "double_val: 1.1\n"
      "double_val: 2.2\n";

  const auto proto = tensor::CreateTensorProto(doubles, dims);
  EXPECT_EQ(proto.DebugString(), expected);
}
417 
// A length-2 bool tensor proto stores its values in bool_val.
TEST(TensorProtoUtil, CreatesBoolTensorProto) {
  const std::vector<bool> flags = {true, false};
  const std::vector<size_t> dims = {2};
  const string expected =
      "dtype: DT_BOOL\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "bool_val: true\n"
      "bool_val: false\n";

  const auto proto = tensor::CreateTensorProto(flags, dims);
  EXPECT_EQ(proto.DebugString(), expected);
}
434 
// For several dtypes, a proto of kLength value-initialized elements is
// expected to be rejected (return false) by the single-argument
// CompressTensorProtoInPlace overload.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceTooSmall) {
  const int kLength = 63;
  // Takes the proto by value so each call compresses a fresh local.
  auto expect_not_compressed = [](TensorProto proto) {
    EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&proto));
  };

  expect_not_compressed(
      tensor::CreateTensorProto(std::vector<float>(kLength), {kLength}));
  expect_not_compressed(
      tensor::CreateTensorProto(std::vector<int>(kLength), {kLength}));
  expect_not_compressed(
      tensor::CreateTensorProto(std::vector<uint8>(kLength), {kLength}));
  expect_not_compressed(
      tensor::CreateTensorProto(std::vector<bool>(kLength), {kLength}));
  expect_not_compressed(
      tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), {kLength}));
  expect_not_compressed(tensor::CreateTensorProto(
      std::vector<std::complex<float>>(kLength), {kLength}));
}
456 
TEST(TensorProtoUtil,CompressTensorProtoInPlaceAllEqual)457 TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
458   const int kLength = 64;
459   TensorProto tensor_proto =
460       tensor::CreateTensorProto(std::vector<float>(kLength), {kLength});
461   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
462   EXPECT_EQ(tensor::internal::TensorProtoHelper<float>::NumValues(tensor_proto),
463             1);
464 
465   tensor_proto =
466       tensor::CreateTensorProto(std::vector<int>(kLength), {kLength});
467   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
468   EXPECT_EQ(tensor::internal::TensorProtoHelper<int>::NumValues(tensor_proto),
469             1);
470 
471   tensor_proto =
472       tensor::CreateTensorProto(std::vector<uint8>(kLength), {kLength});
473   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
474   EXPECT_EQ(tensor::internal::TensorProtoHelper<uint8>::NumValues(tensor_proto),
475             1);
476   tensor_proto =
477       tensor::CreateTensorProto(std::vector<bool>(kLength), {kLength});
478   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
479   EXPECT_EQ(tensor::internal::TensorProtoHelper<bool>::NumValues(tensor_proto),
480             1);
481 
482   tensor_proto =
483       tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), {kLength});
484   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
485   EXPECT_EQ(
486       tensor::internal::TensorProtoHelper<Eigen::half>::NumValues(tensor_proto),
487       1);
488 
489   tensor_proto = tensor::CreateTensorProto(
490       std::vector<std::complex<float>>(kLength), {kLength});
491   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
492   EXPECT_EQ(tensor::internal::TensorProtoHelper<std::complex<float>>::NumValues(
493                 tensor_proto),
494             1);
495 }
496 
497 template <typename T>
VectorWithConstantTail(int size,int tail_length)498 std::vector<T> VectorWithConstantTail(int size, int tail_length) {
499   CHECK_LE(tail_length, size);
500   std::vector<T> v(size, T(0));
501   for (int i = 0; i < size - tail_length; ++i) {
502     v[i] = T(i + 1);
503   }
504   return v;
505 }
506 
507 template <typename T>
CreateAsProtoTensorContent(int size,int tail_length)508 TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
509   auto values = VectorWithConstantTail<T>(size, tail_length);
510   Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
511   std::copy(values.begin(), values.end(), tensor.flat<T>().data());
512   TensorProto tensor_proto;
513   tensor.AsProtoTensorContent(&tensor_proto);
514   return tensor_proto;
515 }
516 
517 template <typename T>
CreateAsProtoField(int size,int tail_length)518 TensorProto CreateAsProtoField(int size, int tail_length) {
519   auto values = VectorWithConstantTail<T>(size, tail_length);
520   Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
521   std::copy(values.begin(), values.end(), tensor.flat<T>().data());
522   TensorProto tensor_proto;
523   tensor.AsProtoField(&tensor_proto);
524   return tensor_proto;
525 }
526 
527 template <typename T>
CompareTensorValues(const TensorProto & x,const TensorProto & y)528 void CompareTensorValues(const TensorProto& x, const TensorProto& y) {
529   Tensor x_t;
530   EXPECT_TRUE(x_t.FromProto(x));
531   Tensor y_t;
532   EXPECT_TRUE(y_t.FromProto(y));
533   test::ExpectTensorEqual<T>(x_t, y_t);
534 }
535 
// Checks CompressTensorProtoInPlace on a `length`-element tensor of type T
// whose last `tail_length` entries are T(0) (see VectorWithConstantTail).
// `as_field` selects whether the input proto is built with AsProtoField or
// with AsProtoTensorContent. The test:
//   * predicts from encoded byte sizes whether compression should happen
//     (`will_compress`) and checks the return value against that prediction;
//   * if compressed, checks that the proto uses the encoding the size
//     estimates favor (tensor_content vs. the repeated value field);
//   * checks that the decoded values are unchanged.
// NOTE(review): the size arithmetic below mirrors the compression rule this
// test assumes. Keep it in sync with the implementation in tensor_util.
template <typename T>
void ConstantTailTest(int64 length, int64 tail_length, bool as_field) {
  using TensorProtoHelper = tensor::internal::TensorProtoHelper<T>;
  // Presumably the element type of the repeated proto field holding T values.
  using FieldType = typename TensorProtoHelper::FieldType;
  const float kMinCompressionRatio = 2.0;
  const int64 kMinSize = 64;
  TensorProto tensor_proto =
      as_field ? CreateAsProtoField<T>(length, tail_length)
               : CreateAsProtoTensorContent<T>(length, tail_length);
  // Untouched copy used for the final value comparison.
  TensorProto original_tensor_proto = tensor_proto;
  // Estimated byte size of the input encoding. In the field encoding, complex
  // types take two FieldType entries per element (presumably real and
  // imaginary parts).
  int64 original_size =
      length * (as_field ? (is_complex<T>::value ? 2 : 1) * sizeof(FieldType)
                         : sizeof(T));
  // Estimated byte size if all values are packed into tensor_content.
  int64 size_as_tensor_content = length * sizeof(T);
  // Estimated byte size of the field encoding with the constant tail stored
  // once: the (length - tail_length) head values plus one tail value, capped
  // at `length` (when tail_length == 0 there is no tail to collapse).
  int64 size_as_field = std::min(length, (length - tail_length + 1)) *
                        (is_complex<T>::value ? 2 : 1) * sizeof(FieldType);
  // Compression is expected only when the smaller candidate encoding is at
  // most original_size / kMinCompressionRatio.
  bool will_compress = std::min(size_as_tensor_content, size_as_field) <=
                       static_cast<int64>(original_size / kMinCompressionRatio);

  EXPECT_EQ(tensor::CompressTensorProtoInPlace(kMinSize, kMinCompressionRatio,
                                               &tensor_proto),
            will_compress);
  if (will_compress) {
    if (size_as_tensor_content < size_as_field) {
      // tensor_content is expected to win: no values in the repeated field.
      EXPECT_EQ(TensorProtoHelper::NumValues(tensor_proto), 0);
      EXPECT_FALSE(tensor_proto.tensor_content().empty());
    } else {
      // The repeated field is expected to win, with the tail collapsed.
      EXPECT_LE(TensorProtoHelper::NumValues(tensor_proto),
                (length - tail_length + 1));
      EXPECT_TRUE(tensor_proto.tensor_content().empty());
    }
  }
  // Whichever encoding was chosen, the decoded values must be unchanged.
  CompareTensorValues<T>(tensor_proto, original_tensor_proto);
}
570 
// Runs ConstantTailTest for every listed dtype, for both proto encodings and
// for a range of constant-tail lengths (none, short, about half, and all of
// the tensor).
TEST(TensorProtoUtil, CompressTensorProtoConstantTail) {
  const int kLength = 64;
  // Runs one (tail_length, as_field) configuration over all dtypes below.
  auto run_all_types = [&](int tail_length, bool as_field) {
    ConstantTailTest<float>(kLength, tail_length, as_field);
    ConstantTailTest<double>(kLength, tail_length, as_field);
    ConstantTailTest<complex64>(kLength, tail_length, as_field);
    ConstantTailTest<complex128>(kLength, tail_length, as_field);
    ConstantTailTest<int32>(kLength, tail_length, as_field);
    ConstantTailTest<uint32>(kLength, tail_length, as_field);
    ConstantTailTest<int64>(kLength, tail_length, as_field);
    ConstantTailTest<uint64>(kLength, tail_length, as_field);
    ConstantTailTest<int8>(kLength, tail_length, as_field);
    ConstantTailTest<uint8>(kLength, tail_length, as_field);
    ConstantTailTest<int16>(kLength, tail_length, as_field);
    ConstantTailTest<uint16>(kLength, tail_length, as_field);
    ConstantTailTest<Eigen::half>(kLength, tail_length, as_field);
    ConstantTailTest<bfloat16>(kLength, tail_length, as_field);
  };

  for (const bool as_field : {true, false}) {
    for (const int tail_length : {0, 1, 2, 32, 33, 63, 64}) {
      run_all_types(tail_length, as_field);
    }
  }
}
592 
593 }  // namespace
594 }  // namespace tensorflow
595