1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
17 
18 #include <random>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/tensor_util.h"
23 #include "tensorflow/core/framework/types.pb.h"
24 #include "tensorflow/core/framework/variant.h"
25 #include "tensorflow/core/framework/variant_op_registry.h"
26 #include "tensorflow/core/framework/versions.pb.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/lib/io/path.h"
29 #include "tensorflow/core/lib/io/table_builder.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
35 
36 namespace tensorflow {
37 
38 namespace {
39 
40 // Prepend the current test case's working temporary directory to <prefix>
Prefix(const string & prefix)41 string Prefix(const string& prefix) {
42   return strings::StrCat(testing::TmpDir(), "/", prefix);
43 }
44 
45 // Construct a data input directory by prepending the test data root
46 // directory to <prefix>
TestdataPrefix(const string & prefix)47 string TestdataPrefix(const string& prefix) {
48   return strings::StrCat(testing::TensorFlowSrcRoot(),
49                          "/core/util/tensor_bundle/testdata/", prefix);
50 }
51 
52 template <typename T>
Constant(T v,TensorShape shape)53 Tensor Constant(T v, TensorShape shape) {
54   Tensor ret(DataTypeToEnum<T>::value, shape);
55   ret.flat<T>().setConstant(v);
56   return ret;
57 }
58 
59 template <typename T>
Constant_2x3(T v)60 Tensor Constant_2x3(T v) {
61   return Constant(v, TensorShape({2, 3}));
62 }
63 
ByteSwap(Tensor t)64 Tensor ByteSwap(Tensor t) {
65   Tensor ret = tensor::DeepCopy(t);
66   TF_EXPECT_OK(ByteSwapTensor(&ret));
67   return ret;
68 }
69 
70 // Assert that <reader> has a tensor under <key> matching <expected_val> in
71 // terms of both shape, dtype, and value
72 template <typename T>
Expect(BundleReader * reader,const string & key,const Tensor & expected_val)73 void Expect(BundleReader* reader, const string& key,
74             const Tensor& expected_val) {
75   // Tests for Contains().
76   EXPECT_TRUE(reader->Contains(key));
77   // Tests for LookupDtypeAndShape().
78   DataType dtype;
79   TensorShape shape;
80   TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
81   EXPECT_EQ(expected_val.dtype(), dtype);
82   EXPECT_EQ(expected_val.shape(), shape);
83   // Tests for Lookup(), checking tensor contents.
84   Tensor val(expected_val.dtype(), shape);
85   TF_ASSERT_OK(reader->Lookup(key, &val));
86   test::ExpectTensorEqual<T>(val, expected_val);
87 }
88 
89 template <class T>
ExpectVariant(BundleReader * reader,const string & key,const Tensor & expected_t)90 void ExpectVariant(BundleReader* reader, const string& key,
91                    const Tensor& expected_t) {
92   // Tests for Contains().
93   EXPECT_TRUE(reader->Contains(key));
94   // Tests for LookupDtypeAndShape().
95   DataType dtype;
96   TensorShape shape;
97   TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
98   // Tests for Lookup(), checking tensor contents.
99   EXPECT_EQ(expected_t.dtype(), dtype);
100   EXPECT_EQ(expected_t.shape(), shape);
101   Tensor actual_t(dtype, shape);
102   TF_ASSERT_OK(reader->Lookup(key, &actual_t));
103   for (int i = 0; i < expected_t.NumElements(); i++) {
104     Variant actual_var = actual_t.flat<Variant>()(i);
105     Variant expected_var = expected_t.flat<Variant>()(i);
106     EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName());
107     auto* actual_val = actual_var.get<T>();
108     auto* expected_val = expected_var.get<T>();
109     EXPECT_EQ(*expected_val, *actual_val);
110   }
111 }
112 
113 template <typename T>
ExpectNext(BundleReader * reader,const Tensor & expected_val)114 void ExpectNext(BundleReader* reader, const Tensor& expected_val) {
115   EXPECT_TRUE(reader->Valid());
116   reader->Next();
117   TF_ASSERT_OK(reader->status());
118   Tensor val;
119   TF_ASSERT_OK(reader->ReadCurrent(&val));
120   test::ExpectTensorEqual<T>(val, expected_val);
121 }
122 
AllTensorKeys(BundleReader * reader)123 std::vector<string> AllTensorKeys(BundleReader* reader) {
124   std::vector<string> ret;
125   reader->Seek(kHeaderEntryKey);
126   reader->Next();
127   for (; reader->Valid(); reader->Next()) {
128     ret.emplace_back(reader->key());
129   }
130   return ret;
131 }
132 
133 // Writes out the metadata file of a bundle again, with the endianness marker
134 // bit flipped.
FlipEndiannessBit(const string & prefix)135 Status FlipEndiannessBit(const string& prefix) {
136   Env* env = Env::Default();
137   const string metadata_tmp_path = Prefix("some_tmp_path");
138   std::unique_ptr<WritableFile> metadata_file;
139   TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &metadata_file));
140   // We create the builder lazily in case we run into an exception earlier, in
141   // which case we'd forget to call Finish() and TableBuilder's destructor
142   // would complain.
143   std::unique_ptr<table::TableBuilder> builder;
144 
145   // Reads the existing metadata file, and fills the builder.
146   {
147     const string filename = MetaFilename(prefix);
148     uint64 file_size;
149     TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
150     std::unique_ptr<RandomAccessFile> file;
151     TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
152 
153     table::Table* table = nullptr;
154     TF_RETURN_IF_ERROR(
155         table::Table::Open(table::Options(), file.get(), file_size, &table));
156     std::unique_ptr<table::Table> table_deleter(table);
157     std::unique_ptr<table::Iterator> iter(table->NewIterator());
158 
159     // Reads the header entry.
160     iter->Seek(kHeaderEntryKey);
161     CHECK(iter->Valid());
162     BundleHeaderProto header;
163     CHECK(header.ParseFromArray(iter->value().data(), iter->value().size()));
164     // Flips the endianness.
165     if (header.endianness() == BundleHeaderProto::LITTLE) {
166       header.set_endianness(BundleHeaderProto::BIG);
167     } else {
168       header.set_endianness(BundleHeaderProto::LITTLE);
169     }
170     builder.reset(
171         new table::TableBuilder(table::Options(), metadata_file.get()));
172     builder->Add(iter->key(), header.SerializeAsString());
173     iter->Next();
174 
175     // Adds the non-header entries unmodified.
176     for (; iter->Valid(); iter->Next())
177       builder->Add(iter->key(), iter->value());
178   }
179   TF_RETURN_IF_ERROR(builder->Finish());
180   TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix)));
181   return metadata_file->Close();
182 }
183 
184 template <typename T>
TestBasic()185 void TestBasic() {
186   {
187     BundleWriter writer(Env::Default(), Prefix("foo"));
188     TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3(T(3))));
189     TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3(T(0))));
190     TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3(T(2))));
191     TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3(T(1))));
192     TF_ASSERT_OK(writer.Finish());
193   }
194   {
195     BundleReader reader(Env::Default(), Prefix("foo"));
196     TF_ASSERT_OK(reader.status());
197     EXPECT_EQ(
198         AllTensorKeys(&reader),
199         std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
200     Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
201     Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
202     Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
203     Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
204   }
205   {
206     BundleReader reader(Env::Default(), Prefix("foo"));
207     TF_ASSERT_OK(reader.status());
208     ExpectNext<T>(&reader, Constant_2x3(T(0)));
209     ExpectNext<T>(&reader, Constant_2x3(T(1)));
210     ExpectNext<T>(&reader, Constant_2x3(T(2)));
211     ExpectNext<T>(&reader, Constant_2x3(T(3)));
212     EXPECT_TRUE(reader.Valid());
213     reader.Next();
214     EXPECT_FALSE(reader.Valid());
215   }
216   {
217     BundleWriter writer(Env::Default(), Prefix("bar"));
218     TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3(T(3))));
219     TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3(T(0))));
220     TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3(T(2))));
221     TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3(T(1))));
222     TF_ASSERT_OK(writer.Finish());
223   }
224   {
225     BundleReader reader(Env::Default(), Prefix("bar"));
226     TF_ASSERT_OK(reader.status());
227     EXPECT_EQ(
228         AllTensorKeys(&reader),
229         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
230     Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
231     Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
232     Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
233     Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
234   }
235   {
236     BundleReader reader(Env::Default(), Prefix("bar"));
237     TF_ASSERT_OK(reader.status());
238     ExpectNext<T>(&reader, Constant_2x3(T(0)));
239     ExpectNext<T>(&reader, Constant_2x3(T(1)));
240     ExpectNext<T>(&reader, Constant_2x3(T(2)));
241     ExpectNext<T>(&reader, Constant_2x3(T(3)));
242     EXPECT_TRUE(reader.Valid());
243     reader.Next();
244     EXPECT_FALSE(reader.Valid());
245   }
246   TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
247                             Prefix("merged")));
248   {
249     BundleReader reader(Env::Default(), Prefix("merged"));
250     TF_ASSERT_OK(reader.status());
251     EXPECT_EQ(
252         AllTensorKeys(&reader),
253         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
254                              "foo_000", "foo_001", "foo_002", "foo_003"}));
255     Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
256     Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
257     Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
258     Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
259     Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
260     Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
261     Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
262     Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
263   }
264   {
265     BundleReader reader(Env::Default(), Prefix("merged"));
266     TF_ASSERT_OK(reader.status());
267     ExpectNext<T>(&reader, Constant_2x3(T(0)));
268     ExpectNext<T>(&reader, Constant_2x3(T(1)));
269     ExpectNext<T>(&reader, Constant_2x3(T(2)));
270     ExpectNext<T>(&reader, Constant_2x3(T(3)));
271     ExpectNext<T>(&reader, Constant_2x3(T(0)));
272     ExpectNext<T>(&reader, Constant_2x3(T(1)));
273     ExpectNext<T>(&reader, Constant_2x3(T(2)));
274     ExpectNext<T>(&reader, Constant_2x3(T(3)));
275     EXPECT_TRUE(reader.Valid());
276     reader.Next();
277     EXPECT_FALSE(reader.Valid());
278   }
279 }
280 
281 // Type-specific subroutine of SwapBytes test below
282 template <typename T>
TestByteSwap(const T * forward,const T * swapped,int array_len)283 void TestByteSwap(const T* forward, const T* swapped, int array_len) {
284   auto bytes_per_elem = sizeof(T);
285 
286   // Convert the entire array at once
287   std::unique_ptr<T[]> forward_copy(new T[array_len]);
288   std::memcpy(forward_copy.get(), forward, array_len * bytes_per_elem);
289   TF_EXPECT_OK(ByteSwapArray(reinterpret_cast<char*>(forward_copy.get()),
290                              bytes_per_elem, array_len));
291   for (int i = 0; i < array_len; i++) {
292     EXPECT_EQ(forward_copy.get()[i], swapped[i]);
293   }
294 
295   // Then the array wrapped in a tensor
296   auto shape = TensorShape({array_len});
297   auto dtype = DataTypeToEnum<T>::value;
298   Tensor forward_tensor(dtype, shape);
299   Tensor swapped_tensor(dtype, shape);
300   std::memcpy(const_cast<char*>(forward_tensor.tensor_data().data()), forward,
301               array_len * bytes_per_elem);
302   std::memcpy(const_cast<char*>(swapped_tensor.tensor_data().data()), swapped,
303               array_len * bytes_per_elem);
304   TF_EXPECT_OK(ByteSwapTensor(&forward_tensor));
305   test::ExpectTensorEqual<T>(forward_tensor, swapped_tensor);
306 }
307 
308 // Unit test of the byte-swapping operations that TensorBundle uses.
TEST(TensorBundleTest,SwapBytes)309 TEST(TensorBundleTest, SwapBytes) {
310   // A bug in the compiler on MacOS causes ByteSwap() and FlipEndiannessBit()
311   // to be removed from the executable if they are only called from templated
312   // functions. As a workaround, we make some dummy calls here.
313   // TODO(frreiss): Remove this workaround when the compiler bug is fixed.
314   ByteSwap(Constant_2x3<int>(42));
315   EXPECT_NE(Status::OK(), FlipEndiannessBit(Prefix("not_a_valid_prefix")));
316 
317   // Test patterns, manually swapped so that we aren't relying on the
318   // correctness of our own byte-swapping macros when testing those macros.
319   // At least one of the entries in each list has the sign bit set when
320   // interpreted as a signed int.
321   const int arr_len_16 = 4;
322   const uint16_t forward_16[] = {0x1de5, 0xd017, 0xf1ea, 0xc0a1};
323   const uint16_t swapped_16[] = {0xe51d, 0x17d0, 0xeaf1, 0xa1c0};
324   const int arr_len_32 = 2;
325   const uint32_t forward_32[] = {0x0ddba115, 0xf01dab1e};
326   const uint32_t swapped_32[] = {0x15a1db0d, 0x1eab1df0};
327   const int arr_len_64 = 2;
328   const uint64_t forward_64[] = {0xf005ba11caba1000, 0x5ca1ab1ecab005e5};
329   const uint64_t swapped_64[] = {0x0010baca11ba05f0, 0xe505b0ca1eaba15c};
330 
331   // 16-bit types
332   TestByteSwap(forward_16, swapped_16, arr_len_16);
333   TestByteSwap(reinterpret_cast<const int16_t*>(forward_16),
334                reinterpret_cast<const int16_t*>(swapped_16), arr_len_16);
335   TestByteSwap(reinterpret_cast<const bfloat16*>(forward_16),
336                reinterpret_cast<const bfloat16*>(swapped_16), arr_len_16);
337 
338   // 32-bit types
339   TestByteSwap(forward_32, swapped_32, arr_len_32);
340   TestByteSwap(reinterpret_cast<const int32_t*>(forward_32),
341                reinterpret_cast<const int32_t*>(swapped_32), arr_len_32);
342   TestByteSwap(reinterpret_cast<const float*>(forward_32),
343                reinterpret_cast<const float*>(swapped_32), arr_len_32);
344 
345   // 64-bit types
346   // Cast to uint64*/int64* to make DataTypeToEnum<T> happy
347   TestByteSwap(reinterpret_cast<const uint64*>(forward_64),
348                reinterpret_cast<const uint64*>(swapped_64), arr_len_64);
349   TestByteSwap(reinterpret_cast<const int64*>(forward_64),
350                reinterpret_cast<const int64*>(swapped_64), arr_len_64);
351   TestByteSwap(reinterpret_cast<const double*>(forward_64),
352                reinterpret_cast<const double*>(swapped_64), arr_len_64);
353 
354   // Complex types.
355   // Logic for complex number handling is only in ByteSwapTensor, so don't test
356   // ByteSwapArray
357   const float* forward_float = reinterpret_cast<const float*>(forward_32);
358   const float* swapped_float = reinterpret_cast<const float*>(swapped_32);
359   const double* forward_double = reinterpret_cast<const double*>(forward_64);
360   const double* swapped_double = reinterpret_cast<const double*>(swapped_64);
361   Tensor forward_complex64 = Constant_2x3<complex64>(
362       std::complex<float>(forward_float[0], forward_float[1]));
363   Tensor swapped_complex64 = Constant_2x3<complex64>(
364       std::complex<float>(swapped_float[0], swapped_float[1]));
365   Tensor forward_complex128 = Constant_2x3<complex128>(
366       std::complex<double>(forward_double[0], forward_double[1]));
367   Tensor swapped_complex128 = Constant_2x3<complex128>(
368       std::complex<double>(swapped_double[0], swapped_double[1]));
369 
370   TF_EXPECT_OK(ByteSwapTensor(&forward_complex64));
371   test::ExpectTensorEqual<complex64>(forward_complex64, swapped_complex64);
372 
373   TF_EXPECT_OK(ByteSwapTensor(&forward_complex128));
374   test::ExpectTensorEqual<complex128>(forward_complex128, swapped_complex128);
375 }
376 
377 // Basic test of alternate-endianness support. Generates a bundle in
378 // the opposite of the current system's endianness and attempts to
379 // read the bundle back in. Does not exercise sharding or access to
380 // nonaligned tensors. Does cover the major access types exercised
381 // in TestBasic.
382 template <typename T>
TestEndianness()383 void TestEndianness() {
384   {
385     // Write out a TensorBundle in the opposite of this host's endianness.
386     BundleWriter writer(Env::Default(), Prefix("foo"));
387     TF_EXPECT_OK(writer.Add("foo_003", ByteSwap(Constant_2x3<T>(T(3)))));
388     TF_EXPECT_OK(writer.Add("foo_000", ByteSwap(Constant_2x3<T>(T(0)))));
389     TF_EXPECT_OK(writer.Add("foo_002", ByteSwap(Constant_2x3<T>(T(2)))));
390     TF_EXPECT_OK(writer.Add("foo_001", ByteSwap(Constant_2x3<T>(T(1)))));
391     TF_ASSERT_OK(writer.Finish());
392     TF_ASSERT_OK(FlipEndiannessBit(Prefix("foo")));
393   }
394   {
395     BundleReader reader(Env::Default(), Prefix("foo"));
396     TF_ASSERT_OK(reader.status());
397     EXPECT_EQ(
398         AllTensorKeys(&reader),
399         std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
400     Expect<T>(&reader, "foo_000", Constant_2x3<T>(T(0)));
401     Expect<T>(&reader, "foo_001", Constant_2x3<T>(T(1)));
402     Expect<T>(&reader, "foo_002", Constant_2x3<T>(T(2)));
403     Expect<T>(&reader, "foo_003", Constant_2x3<T>(T(3)));
404   }
405   {
406     BundleReader reader(Env::Default(), Prefix("foo"));
407     TF_ASSERT_OK(reader.status());
408     ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
409     ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
410     ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
411     ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
412     EXPECT_TRUE(reader.Valid());
413     reader.Next();
414     EXPECT_FALSE(reader.Valid());
415   }
416   {
417     BundleWriter writer(Env::Default(), Prefix("bar"));
418     TF_EXPECT_OK(writer.Add("bar_003", ByteSwap(Constant_2x3<T>(T(3)))));
419     TF_EXPECT_OK(writer.Add("bar_000", ByteSwap(Constant_2x3<T>(T(0)))));
420     TF_EXPECT_OK(writer.Add("bar_002", ByteSwap(Constant_2x3<T>(T(2)))));
421     TF_EXPECT_OK(writer.Add("bar_001", ByteSwap(Constant_2x3<T>(T(1)))));
422     TF_ASSERT_OK(writer.Finish());
423     TF_ASSERT_OK(FlipEndiannessBit(Prefix("bar")));
424   }
425   {
426     BundleReader reader(Env::Default(), Prefix("bar"));
427     TF_ASSERT_OK(reader.status());
428     EXPECT_EQ(
429         AllTensorKeys(&reader),
430         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
431     Expect<T>(&reader, "bar_003", Constant_2x3<T>(T(3)));
432     Expect<T>(&reader, "bar_002", Constant_2x3<T>(T(2)));
433     Expect<T>(&reader, "bar_001", Constant_2x3<T>(T(1)));
434     Expect<T>(&reader, "bar_000", Constant_2x3<T>(T(0)));
435   }
436   {
437     BundleReader reader(Env::Default(), Prefix("bar"));
438     TF_ASSERT_OK(reader.status());
439     ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
440     ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
441     ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
442     ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
443     EXPECT_TRUE(reader.Valid());
444     reader.Next();
445     EXPECT_FALSE(reader.Valid());
446   }
447   TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
448                             Prefix("merged")));
449   {
450     BundleReader reader(Env::Default(), Prefix("merged"));
451     TF_ASSERT_OK(reader.status());
452     EXPECT_EQ(
453         AllTensorKeys(&reader),
454         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
455                              "foo_000", "foo_001", "foo_002", "foo_003"}));
456     Expect<T>(&reader, "bar_000", Constant_2x3<T>(T(0)));
457     Expect<T>(&reader, "bar_001", Constant_2x3<T>(T(1)));
458     Expect<T>(&reader, "bar_002", Constant_2x3<T>(T(2)));
459     Expect<T>(&reader, "bar_003", Constant_2x3<T>(T(3)));
460     Expect<T>(&reader, "foo_000", Constant_2x3<T>(T(0)));
461     Expect<T>(&reader, "foo_001", Constant_2x3<T>(T(1)));
462     Expect<T>(&reader, "foo_002", Constant_2x3<T>(T(2)));
463     Expect<T>(&reader, "foo_003", Constant_2x3<T>(T(3)));
464   }
465   {
466     BundleReader reader(Env::Default(), Prefix("merged"));
467     TF_ASSERT_OK(reader.status());
468     ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
469     ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
470     ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
471     ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
472     ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
473     ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
474     ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
475     ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
476     EXPECT_TRUE(reader.Valid());
477     reader.Next();
478     EXPECT_FALSE(reader.Valid());
479   }
480 }
481 
482 template <typename T>
TestNonStandardShapes()483 void TestNonStandardShapes() {
484   {
485     BundleWriter writer(Env::Default(), Prefix("nonstandard"));
486     TF_EXPECT_OK(writer.Add("scalar", Constant(T(0), TensorShape())));
487     TF_EXPECT_OK(
488         writer.Add("non_standard0", Constant(T(0), TensorShape({0, 1618}))));
489     TF_EXPECT_OK(
490         writer.Add("non_standard1", Constant(T(0), TensorShape({16, 0, 18}))));
491     TF_ASSERT_OK(writer.Finish());
492   }
493   {
494     BundleReader reader(Env::Default(), Prefix("nonstandard"));
495     TF_ASSERT_OK(reader.status());
496     Expect<T>(&reader, "scalar", Constant(T(0), TensorShape()));
497     Expect<T>(&reader, "non_standard0", Constant(T(0), TensorShape({0, 1618})));
498     Expect<T>(&reader, "non_standard1",
499               Constant(T(0), TensorShape({16, 0, 18})));
500   }
501 }
502 
503 // Writes a bundle to disk with a bad "version"; checks for "expected_error".
VersionTest(const VersionDef & version,StringPiece expected_error)504 void VersionTest(const VersionDef& version, StringPiece expected_error) {
505   const string path = Prefix("version_test");
506   {
507     // Prepare an empty bundle with the given version information.
508     BundleHeaderProto header;
509     *header.mutable_version() = version;
510 
511     // Write the metadata file to disk.
512     std::unique_ptr<WritableFile> file;
513     TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file));
514     table::TableBuilder builder(table::Options(), file.get());
515     builder.Add(kHeaderEntryKey, header.SerializeAsString());
516     TF_ASSERT_OK(builder.Finish());
517   }
518   // Read it back in and verify that we get the expected error.
519   BundleReader reader(Env::Default(), path);
520   EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
521   EXPECT_TRUE(
522       absl::StartsWith(reader.status().error_message(), expected_error));
523 }
524 
525 }  // namespace
526 
// Runs the basic write/read/merge round trip for every supported element
// type. Each TestBasic<T>() call rewrites its bundles from scratch, so the
// calls are independent of one another.
TEST(TensorBundleTest, Basic) {
  // Floating-point types.
  TestBasic<float>();
  TestBasic<double>();
  TestBasic<bfloat16>();
  // Integer types.
  TestBasic<int8>();
  TestBasic<int16>();
  TestBasic<int32>();
  TestBasic<int64>();
  TestBasic<uint8>();
  // Boolean.
  TestBasic<bool>();
  // Complex types.
  TestBasic<complex64>();
  TestBasic<complex128>();
  // Quantized types.
  TestBasic<qint8>();
  TestBasic<quint8>();
  TestBasic<qint32>();
}
543 
// Runs the opposite-endianness round trip for every supported element type.
// Each TestEndianness<T>() call rewrites its bundles from scratch, so the
// calls are independent of one another.
TEST(TensorBundleTest, Endianness) {
  // Floating-point types.
  TestEndianness<float>();
  TestEndianness<double>();
  TestEndianness<bfloat16>();
  // Integer types.
  TestEndianness<int8>();
  TestEndianness<int16>();
  TestEndianness<int32>();
  TestEndianness<int64>();
  TestEndianness<uint8>();
  // Boolean.
  TestEndianness<bool>();
  // Complex types.
  TestEndianness<complex64>();
  TestEndianness<complex128>();
  // Quantized types.
  TestEndianness<qint8>();
  TestEndianness<quint8>();
  TestEndianness<qint32>();
}
560 
TEST(TensorBundleTest,PartitionedVariables)561 TEST(TensorBundleTest, PartitionedVariables) {
562   const TensorShape kFullShape({5, 10});
563   // Adds two slices.
564   // First slice: column 0, all zeros.
565   // Second slice: column 1 to rest, all ones.
566   TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
567   TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
568   {
569     BundleWriter writer(Env::Default(), Prefix("foo"));
570 
571     TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
572                                  Constant<float>(0., TensorShape({5, 1}))));
573     TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
574                                  Constant<float>(1., TensorShape({5, 9}))));
575     TF_ASSERT_OK(writer.Finish());
576   }
577   // Reads in full.
578   {
579     BundleReader reader(Env::Default(), Prefix("foo"));
580     TF_ASSERT_OK(reader.status());
581 
582     Tensor expected_val(DT_FLOAT, kFullShape);
583     test::FillFn<float>(&expected_val, [](int offset) -> float {
584       if (offset % 10 == 0) {
585         return 0;  // First column zeros.
586       }
587       return 1;  // Other columns ones.
588     });
589 
590     Tensor val(DT_FLOAT, kFullShape);
591     TF_ASSERT_OK(reader.Lookup("foo", &val));
592     test::ExpectTensorEqual<float>(val, expected_val);
593   }
594   // Reads all slices.
595   {
596     BundleReader reader(Env::Default(), Prefix("foo"));
597     TF_ASSERT_OK(reader.status());
598 
599     std::vector<TensorSlice> slices;
600     TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));
601 
602     EXPECT_EQ(2, slices.size());
603     EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
604     EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
605   }
606   // Reads a slice consisting of first two columns, "cutting" both slices.
607   {
608     BundleReader reader(Env::Default(), Prefix("foo"));
609     TF_ASSERT_OK(reader.status());
610 
611     // First two columns, "cutting" both slices.
612     const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:0,2");
613     Tensor expected_val(DT_FLOAT, TensorShape({5, 2}));
614     test::FillFn<float>(&expected_val, [](int offset) -> float {
615       if (offset % 2 == 0) {
616         return 0;  // First column zeros.
617       }
618       return 1;  // Other columns ones.
619     });
620 
621     Tensor val(DT_FLOAT, TensorShape({5, 2}));
622     TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
623     test::ExpectTensorEqual<float>(val, expected_val);
624   }
625   // Reads a slice consisting of columns 2-4, "cutting" the second slice only.
626   {
627     BundleReader reader(Env::Default(), Prefix("foo"));
628     TF_ASSERT_OK(reader.status());
629 
630     const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:2,2");
631     Tensor val(DT_FLOAT, TensorShape({5, 2}));
632     TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
633     test::ExpectTensorEqual<float>(val,
634                                    Constant<float>(1., TensorShape({5, 2})));
635   }
636 }
637 
// Stores the same full 5x10 tensor under two keys, once with an abbreviated
// slice spec ("-:-") and once with explicit extents ("0,5:0,10"), and checks
// that every pairing of stored spec and requested spec reads back the full
// tensor.
TEST(TensorBundleTest, EquivalentSliceTest) {
  const TensorShape kFullShape({5, 10});
  const Tensor kExpected(Constant<float>(1., kFullShape));
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
                                 TensorSlice::ParseOrDie("-:-"), kExpected));
    TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
                                 TensorSlice::ParseOrDie("0,5:0,10"),
                                 kExpected));
    TF_ASSERT_OK(writer.Finish());
  }
  struct LookupCase {
    const char* key;
    const char* slice_spec;
  };
  const LookupCase cases[] = {
      // Slices match exactly and are fully abbreviated.
      {"no_extents", "-:-"},
      // Slices match exactly and are fully specified.
      {"both_extents", "0,5:0,10"},
      // Stored slice has no extents, spec has extents.
      {"no_extents", "0,5:0,10"},
      // Stored slice has both extents, spec has no extents.
      {"both_extents", "-:-"},
  };
  for (const LookupCase& c : cases) {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    const TensorSlice slice = TensorSlice::ParseOrDie(c.slice_spec);
    Tensor val(DT_FLOAT, TensorShape(kFullShape));
    TF_ASSERT_OK(reader.LookupSlice(c.key, slice, &val));
    test::ExpectTensorEqual<float>(val, kExpected);
  }
}
687 
// Runs the rank-0 / zero-sized-dimension round trip for every supported
// element type. Each TestNonStandardShapes<T>() call rewrites its bundle from
// scratch, so the calls are independent of one another.
TEST(TensorBundleTest, NonStandardShapes) {
  // Floating-point types.
  TestNonStandardShapes<float>();
  TestNonStandardShapes<double>();
  TestNonStandardShapes<bfloat16>();
  // Integer types.
  TestNonStandardShapes<int8>();
  TestNonStandardShapes<int16>();
  TestNonStandardShapes<int32>();
  TestNonStandardShapes<int64>();
  TestNonStandardShapes<uint8>();
  // Boolean.
  TestNonStandardShapes<bool>();
  // Complex types.
  TestNonStandardShapes<complex64>();
  TestNonStandardShapes<complex128>();
  // Quantized types.
  TestNonStandardShapes<qint8>();
  TestNonStandardShapes<quint8>();
  TestNonStandardShapes<qint32>();
}
704 
// Reads a checked-in bundle produced by an earlier version of the code, which
// stored string lengths as varint32s (the current code uses varint64s).
TEST(TensorBundleTest, StringTensorsOldFormat) {
  BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
  TF_ASSERT_OK(reader.status());
  const std::vector<string> expected_keys = {"floats", "scalar",
                                             "string_tensor", "strs"};
  EXPECT_EQ(AllTensorKeys(&reader), expected_keys);

  // A one-element string tensor holding a default (empty) string.
  const Tensor empty_string_tensor(DT_STRING, TensorShape({1}));
  Expect<tstring>(&reader, "string_tensor", empty_string_tensor);
  Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
  const Tensor strs =
      test::AsTensor<tstring>({"hello", "", "x01", string(1 << 10, 'c')});
  Expect<tstring>(&reader, "strs", strs);
  Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
}
721 
TEST(TensorBundleTest,StringTensors)722 TEST(TensorBundleTest, StringTensors) {
723   constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
724   Tensor long_string_tensor(DT_STRING, TensorShape({1}));
725 
726   {
727     BundleWriter writer(Env::Default(), Prefix("foo"));
728     TF_EXPECT_OK(writer.Add("string_tensor",
729                             Tensor(DT_STRING, TensorShape({1}))));  // Empty.
730     TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<tstring>({"hello"})));
731     TF_EXPECT_OK(writer.Add(
732         "strs",
733         test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')})));
734 
735     // Requires a 64-bit length.
736     tstring* backing_string = long_string_tensor.flat<tstring>().data();
737     backing_string->resize_uninitialized(kLongLength);
738     std::char_traits<char>::assign(backing_string->data(), kLongLength, 'd');
739     TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
740 
741     // Mixes in some floats.
742     TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
743     TF_ASSERT_OK(writer.Finish());
744   }
745   {
746     BundleReader reader(Env::Default(), Prefix("foo"));
747     TF_ASSERT_OK(reader.status());
748     EXPECT_EQ(AllTensorKeys(&reader),
749               std::vector<string>({"floats", "long_scalar", "scalar",
750                                    "string_tensor", "strs"}));
751 
752     Expect<tstring>(&reader, "string_tensor",
753                     Tensor(DT_STRING, TensorShape({1})));
754     Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
755     Expect<tstring>(
756         &reader, "strs",
757         test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')}));
758 
759     Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
760 
761     // We don't use the Expect function so we can re-use the
762     // `long_string_tensor` buffer for reading out long_scalar to keep memory
763     // usage reasonable.
764     EXPECT_TRUE(reader.Contains("long_scalar"));
765     DataType dtype;
766     TensorShape shape;
767     TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
768     EXPECT_EQ(DT_STRING, dtype);
769     EXPECT_EQ(TensorShape({1}), shape);
770 
771     // Fill the string differently so that we can be sure the new one is read
772     // in. Because fragmentation in tc-malloc and we have such a big tensor
773     // of 4GB, therefore it is not ideal to free the buffer right now.
774     // The rationale is to make allocation/free close to each other.
775     tstring* backing_string = long_string_tensor.flat<tstring>().data();
776     std::char_traits<char>::assign(backing_string->data(), kLongLength, 'e');
777 
778     // Read long_scalar and check it contains kLongLength 'd's.
779     TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
780     ASSERT_EQ(backing_string, long_string_tensor.flat<tstring>().data());
781     EXPECT_EQ(kLongLength, backing_string->length());
782     for (size_t i = 0; i < kLongLength; i++) {
783       // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
784       // compiler optimizations.
785       if ((*backing_string)[i] != 'd') {
786         FAIL() << "long_scalar is not full of 'd's as expected.";
787         break;
788       }
789     }
790   }
791 }
792 
793 class VariantObject {
794  public:
VariantObject()795   VariantObject() {}
VariantObject(const string & metadata,int64 value)796   VariantObject(const string& metadata, int64 value)
797       : metadata_(metadata), value_(value) {}
798 
TypeName() const799   string TypeName() const { return "TEST VariantObject"; }
Encode(VariantTensorData * data) const800   void Encode(VariantTensorData* data) const {
801     data->set_type_name(TypeName());
802     data->set_metadata(metadata_);
803     Tensor val_t = Tensor(DT_INT64, TensorShape({}));
804     val_t.scalar<int64>()() = value_;
805     *(data->add_tensors()) = val_t;
806   }
Decode(const VariantTensorData & data)807   bool Decode(const VariantTensorData& data) {
808     EXPECT_EQ(data.type_name(), TypeName());
809     data.get_metadata(&metadata_);
810     EXPECT_EQ(data.tensors_size(), 1);
811     value_ = data.tensors(0).scalar<int64>()();
812     return true;
813   }
operator ==(const VariantObject other) const814   bool operator==(const VariantObject other) const {
815     return metadata_ == other.metadata_ && value_ == other.value_;
816   }
817   string metadata_;
818   int64 value_;
819 };
820 
821 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject");
822 
TEST(TensorBundleTest, VariantTensors) {
  // A two-element DT_VARIANT tensor that is written out and then expected
  // back unchanged.
  const Tensor variants = test::AsTensor<Variant>(
      {VariantObject("test", 10), VariantObject("test1", 20)});

  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_EXPECT_OK(writer.Add("variant_tensor", variants));
    TF_ASSERT_OK(writer.Finish());
  }

  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    ExpectVariant<VariantObject>(&reader, "variant_tensor", variants);
  }
}
841 
TEST(TensorBundleTest, DirectoryStructure) {
  Env* const env = Env::Default();

  // Write one single-tensor bundle per worker prefix ("tensor0", "tensor1").
  const std::vector<string> kBundlePrefixes = {Prefix("worker0"),
                                               Prefix("worker1")};
  int worker = 0;
  for (const string& bundle_prefix : kBundlePrefixes) {
    BundleWriter writer(env, bundle_prefix);
    TF_EXPECT_OK(writer.Add(strings::StrCat("tensor", worker),
                            Constant_2x3<float>(0.)));
    TF_ASSERT_OK(writer.Finish());
    ++worker;
  }

  // Expects each name in `expected_files` to exist in the directory
  // component of `bundle_prefix`.
  auto CheckDirFiles = [env](const string& bundle_prefix,
                             gtl::ArraySlice<string> expected_files) {
    const StringPiece dir = io::Dirname(bundle_prefix);
    for (const string& name : expected_files) {
      TF_EXPECT_OK(env->FileExists(io::JoinPath(dir, name)));
    }
  };

  // Each bundle is <name>.index plus a single data shard.
  CheckDirFiles(kBundlePrefixes[0],
                {"worker0.index", "worker0.data-00000-of-00001"});
  CheckDirFiles(kBundlePrefixes[1],
                {"worker1.index", "worker1.data-00000-of-00001"});

  // "Merging" a lone bundle into a new prefix is effectively a rename.
  const string another_prefix = Prefix("another");
  TF_ASSERT_OK(MergeBundles(env, {kBundlePrefixes[0]}, another_prefix));
  CheckDirFiles(another_prefix,
                {"another.index", "another.data-00000-of-00001"});

  // A genuine two-bundle merge yields one index and two data shards.
  const string merged_prefix = Prefix("merged");
  TF_ASSERT_OK(
      MergeBundles(env, {another_prefix, kBundlePrefixes[1]}, merged_prefix));
  CheckDirFiles(merged_prefix, {"merged.index", "merged.data-00000-of-00002",
                                "merged.data-00001-of-00002"});
}
887 
TEST(TensorBundleTest,Error)888 TEST(TensorBundleTest, Error) {
889   {  // Dup keys.
890     BundleWriter writer(Env::Default(), Prefix("dup"));
891     TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
892     EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok());
893     EXPECT_TRUE(absl::StrContains(writer.status().ToString(), "duplicate key"));
894     EXPECT_FALSE(writer.Finish().ok());
895   }
896   {  // Double finish
897     BundleWriter writer(Env::Default(), Prefix("bad"));
898     EXPECT_TRUE(writer.Finish().ok());
899     EXPECT_FALSE(writer.Finish().ok());
900   }
901   {  // Not found.
902     BundleReader reader(Env::Default(), Prefix("nonexist"));
903     EXPECT_TRUE(absl::StrContains(reader.status().ToString(), "Not found"));
904   }
905 }
906 
TEST(TensorBundleTest,Checksum)907 TEST(TensorBundleTest, Checksum) {
908   // Randomly flips a byte in [pos_lhs, end of data file), or exactly byte
909   // pos_lhs if exact_pos == True.
910   auto FlipByte = [](const string& prefix, int pos_lhs,
911                      bool exact_pos = false) {
912     DCHECK_GE(pos_lhs, 0);
913     const string& datafile = DataFilename(Prefix(prefix), 0, 1);
914     string data;
915     TF_ASSERT_OK(ReadFileToString(Env::Default(), datafile, &data));
916 
917     int byte_pos = 0;
918     if (!exact_pos) {
919       std::mt19937 rng;
920       std::uniform_int_distribution<int> dist(pos_lhs, data.size() - 1);
921       byte_pos = dist(rng);
922     } else {
923       byte_pos = pos_lhs;
924     }
925     data[byte_pos] = ~data[byte_pos];
926     TF_ASSERT_OK(WriteStringToFile(Env::Default(), datafile, data));
927   };
928   // The lookup should fail with a checksum-related message.
929   auto ExpectLookupFails = [](const string& prefix, const string& key,
930                               const string& expected_msg, Tensor& val) {
931     BundleReader reader(Env::Default(), Prefix(prefix));
932     Status status = reader.Lookup(key, &val);
933     EXPECT_TRUE(errors::IsDataLoss(status));
934     EXPECT_TRUE(absl::StrContains(status.ToString(), expected_msg));
935   };
936 
937   // Corrupts a float tensor.
938   {
939     BundleWriter writer(Env::Default(), Prefix("singleton"));
940     TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
941     TF_ASSERT_OK(writer.Finish());
942 
943     FlipByte("singleton", 0 /* corrupts any byte */);
944     Tensor val(DT_FLOAT, TensorShape({2, 3}));
945     ExpectLookupFails("singleton", "foo",
946                       "Checksum does not match" /* expected fail msg */, val);
947   }
948   // Corrupts a string tensor.
949   {
950     auto WriteStrings = []() {
951       BundleWriter writer(Env::Default(), Prefix("strings"));
952       TF_EXPECT_OK(
953           writer.Add("foo", test::AsTensor<tstring>({"hello", "world"})));
954       TF_ASSERT_OK(writer.Finish());
955     };
956     // Corrupts the first two bytes, which are the varint32-encoded lengths
957     // of the two string elements.  Should hit mismatch on length cksum.
958     for (int i = 0; i < 2; ++i) {
959       WriteStrings();
960       FlipByte("strings", i, true /* corrupts exactly byte i */);
961       Tensor val(DT_STRING, TensorShape({2}));
962       ExpectLookupFails(
963           "strings", "foo",
964           "length checksum does not match" /* expected fail msg */, val);
965     }
966     // Corrupts the string bytes, should hit an overall cksum mismatch.
967     WriteStrings();
968     FlipByte("strings", 2 /* corrupts starting from byte 2 */);
969     Tensor val(DT_STRING, TensorShape({2}));
970     ExpectLookupFails("strings", "foo",
971                       "Checksum does not match" /* expected fail msg */, val);
972   }
973 }
974 
TEST(TensorBundleTest, TruncatedTensorContents) {
  Env* const env = Env::Default();
  const string prefix = Prefix("end");

  BundleWriter writer(env, prefix);
  TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
  TF_ASSERT_OK(writer.Finish());

  // Drop the last byte of the only data shard so that reading "key" runs
  // into end-of-file.
  const string datafile = DataFilename(prefix, 0, 1);
  string contents;
  TF_ASSERT_OK(ReadFileToString(env, datafile, &contents));
  ASSERT_FALSE(contents.empty());
  const StringPiece truncated(contents.data(), contents.size() - 1);
  TF_ASSERT_OK(WriteStringToFile(env, datafile, truncated));

  BundleReader reader(env, prefix);
  TF_ASSERT_OK(reader.status());
  Tensor val(DT_FLOAT, TensorShape({2, 3}));
  const Status lookup_status = reader.Lookup("key", &val);
  EXPECT_TRUE(errors::IsOutOfRange(lookup_status));
}
994 
TEST(TensorBundleTest, HeaderEntry) {
  // Produce a one-tensor bundle to inspect.
  {
    BundleWriter writer(Env::Default(), Prefix("b"));
    TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Seek to the header entry and parse it.
  BundleHeaderProto header;
  {
    BundleReader reader(Env::Default(), Prefix("b"));
    TF_ASSERT_OK(reader.status());
    reader.Seek(kHeaderEntryKey);
    ASSERT_TRUE(reader.Valid());
    const auto serialized = reader.value();
    ASSERT_TRUE(
        ParseProtoUnlimited(&header, serialized.data(), serialized.size()));
  }

  EXPECT_EQ(1, header.num_shards());

  // The recorded endianness must match the host's.
  const auto host_endianness = port::kLittleEndian ? BundleHeaderProto::LITTLE
                                                   : BundleHeaderProto::BIG;
  EXPECT_EQ(host_endianness, header.endianness());

  const VersionDef& version = header.version();
  EXPECT_GT(kTensorBundleVersion, 0);
  EXPECT_EQ(kTensorBundleVersion, version.producer());
  EXPECT_EQ(kTensorBundleMinConsumer, version.min_consumer());
}
1026 
TEST(TensorBundleTest, VersionTest) {
  const int next_version = kTensorBundleVersion + 1;

  // Expected error: the checkpoint's min_consumer is newer than this binary.
  {
    VersionDef versions;
    versions.set_producer(next_version);
    versions.set_min_consumer(next_version);
    const string expected_error = strings::StrCat(
        "Checkpoint min consumer version ", next_version,
        " above current version ", kTensorBundleVersion, " for TensorFlow");
    VersionTest(versions, expected_error);
  }
  // Expected error: the producer predates the oldest supported producer.
  {
    const int stale_producer = kTensorBundleMinProducer - 1;
    VersionDef versions;
    versions.set_producer(stale_producer);
    const string expected_error = strings::StrCat(
        "Checkpoint producer version ", stale_producer, " below min producer ",
        kTensorBundleMinProducer, " supported by TensorFlow");
    VersionTest(versions, expected_error);
  }
  // Expected error: this binary's version is listed as a bad consumer.
  {
    VersionDef versions;
    versions.set_producer(next_version);
    versions.add_bad_consumers(kTensorBundleVersion);
    const string expected_error = strings::StrCat(
        "Checkpoint disallows consumer version ", kTensorBundleVersion,
        ".  Please upgrade TensorFlow: this version is likely buggy.");
    VersionTest(versions, expected_error);
  }
}
1061 
1062 class TensorBundleAlignmentTest : public ::testing::Test {
1063  protected:
1064   template <typename T>
ExpectAlignment(BundleReader * reader,const string & key,int alignment)1065   void ExpectAlignment(BundleReader* reader, const string& key, int alignment) {
1066     BundleEntryProto full_tensor_entry;
1067     TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry));
1068     EXPECT_EQ(0, full_tensor_entry.offset() % alignment);
1069   }
1070 };
1071 
TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
  constexpr int kAlignment = 42;
  const string prefix = Prefix("foo");

  // Write foo_00<id> = Constant_2x3(id), deliberately out of key order.
  {
    BundleWriter::Options options;
    options.data_alignment = kAlignment;
    BundleWriter writer(Env::Default(), prefix, options);
    for (const int id : {3, 0, 2, 1}) {
      TF_EXPECT_OK(writer.Add(strings::StrCat("foo_00", id),
                              Constant_2x3<float>(id)));
    }
    TF_ASSERT_OK(writer.Finish());
  }

  // Random access: keys come back sorted and each value round-trips.
  {
    BundleReader reader(Env::Default(), prefix);
    TF_ASSERT_OK(reader.status());
    EXPECT_EQ(
        AllTensorKeys(&reader),
        std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
    for (int id = 0; id < 4; ++id) {
      Expect<float>(&reader, strings::StrCat("foo_00", id),
                    Constant_2x3<float>(id));
    }
  }

  // Sequential iteration visits the tensors in key order.
  {
    BundleReader reader(Env::Default(), prefix);
    TF_ASSERT_OK(reader.status());
    for (int id = 0; id < 4; ++id) {
      ExpectNext<float>(&reader, Constant_2x3<float>(id));
    }
    EXPECT_TRUE(reader.Valid());
    reader.Next();
    EXPECT_FALSE(reader.Valid());
  }

  // Every entry honors the requested data alignment.
  {
    BundleReader reader(Env::Default(), prefix);
    TF_ASSERT_OK(reader.status());
    for (int id = 0; id < 4; ++id) {
      ExpectAlignment<float>(&reader, strings::StrCat("foo_00", id),
                             kAlignment);
    }
  }
}
1114 
BM_BundleAlignmentByteOff(::testing::benchmark::State & state,int alignment,int tensor_size)1115 static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state,
1116                                       int alignment, int tensor_size) {
1117   {
1118     BundleWriter::Options opts;
1119     opts.data_alignment = alignment;
1120     BundleWriter writer(Env::Default(), Prefix("foo"), opts);
1121     TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
1122     TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
1123     TF_CHECK_OK(writer.Finish());
1124   }
1125   BundleReader reader(Env::Default(), Prefix("foo"));
1126   TF_CHECK_OK(reader.status());
1127   for (auto s : state) {
1128     Tensor t;
1129     TF_CHECK_OK(reader.Lookup("big", &t));
1130   }
1131 }
1132 
1133 #define BM_BundleAlignment(ALIGN, SIZE)            \
1134   static void BM_BundleAlignment_##ALIGN##_##SIZE( \
1135       ::testing::benchmark::State& state) {        \
1136     BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \
1137   }                                                \
1138   BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
1139 
1140 BM_BundleAlignment(1, 512);
1141 BM_BundleAlignment(1, 4096);
1142 BM_BundleAlignment(1, 1048576);
1143 BM_BundleAlignment(4096, 512);
1144 BM_BundleAlignment(4096, 4096);
1145 BM_BundleAlignment(4096, 1048576);
1146 
1147 }  // namespace tensorflow
1148