1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
17
18 #include <random>
19 #include <vector>
20
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/tensor_util.h"
23 #include "tensorflow/core/framework/types.pb.h"
24 #include "tensorflow/core/framework/variant.h"
25 #include "tensorflow/core/framework/variant_op_registry.h"
26 #include "tensorflow/core/framework/versions.pb.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/lib/io/path.h"
29 #include "tensorflow/core/lib/io/table_builder.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
35
36 namespace tensorflow {
37
38 namespace {
39
40 // Prepend the current test case's working temporary directory to <prefix>
Prefix(const string & prefix)41 string Prefix(const string& prefix) {
42 return strings::StrCat(testing::TmpDir(), "/", prefix);
43 }
44
45 // Construct a data input directory by prepending the test data root
46 // directory to <prefix>
TestdataPrefix(const string & prefix)47 string TestdataPrefix(const string& prefix) {
48 return strings::StrCat(testing::TensorFlowSrcRoot(),
49 "/core/util/tensor_bundle/testdata/", prefix);
50 }
51
52 template <typename T>
Constant(T v,TensorShape shape)53 Tensor Constant(T v, TensorShape shape) {
54 Tensor ret(DataTypeToEnum<T>::value, shape);
55 ret.flat<T>().setConstant(v);
56 return ret;
57 }
58
59 template <typename T>
Constant_2x3(T v)60 Tensor Constant_2x3(T v) {
61 return Constant(v, TensorShape({2, 3}));
62 }
63
ByteSwap(Tensor t)64 Tensor ByteSwap(Tensor t) {
65 Tensor ret = tensor::DeepCopy(t);
66 TF_EXPECT_OK(ByteSwapTensor(&ret));
67 return ret;
68 }
69
70 // Assert that <reader> has a tensor under <key> matching <expected_val> in
71 // terms of both shape, dtype, and value
72 template <typename T>
Expect(BundleReader * reader,const string & key,const Tensor & expected_val)73 void Expect(BundleReader* reader, const string& key,
74 const Tensor& expected_val) {
75 // Tests for Contains().
76 EXPECT_TRUE(reader->Contains(key));
77 // Tests for LookupDtypeAndShape().
78 DataType dtype;
79 TensorShape shape;
80 TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
81 EXPECT_EQ(expected_val.dtype(), dtype);
82 EXPECT_EQ(expected_val.shape(), shape);
83 // Tests for Lookup(), checking tensor contents.
84 Tensor val(expected_val.dtype(), shape);
85 TF_ASSERT_OK(reader->Lookup(key, &val));
86 test::ExpectTensorEqual<T>(val, expected_val);
87 }
88
89 template <class T>
ExpectVariant(BundleReader * reader,const string & key,const Tensor & expected_t)90 void ExpectVariant(BundleReader* reader, const string& key,
91 const Tensor& expected_t) {
92 // Tests for Contains().
93 EXPECT_TRUE(reader->Contains(key));
94 // Tests for LookupDtypeAndShape().
95 DataType dtype;
96 TensorShape shape;
97 TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
98 // Tests for Lookup(), checking tensor contents.
99 EXPECT_EQ(expected_t.dtype(), dtype);
100 EXPECT_EQ(expected_t.shape(), shape);
101 Tensor actual_t(dtype, shape);
102 TF_ASSERT_OK(reader->Lookup(key, &actual_t));
103 for (int i = 0; i < expected_t.NumElements(); i++) {
104 Variant actual_var = actual_t.flat<Variant>()(i);
105 Variant expected_var = expected_t.flat<Variant>()(i);
106 EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName());
107 auto* actual_val = actual_var.get<T>();
108 auto* expected_val = expected_var.get<T>();
109 EXPECT_EQ(*expected_val, *actual_val);
110 }
111 }
112
113 template <typename T>
ExpectNext(BundleReader * reader,const Tensor & expected_val)114 void ExpectNext(BundleReader* reader, const Tensor& expected_val) {
115 EXPECT_TRUE(reader->Valid());
116 reader->Next();
117 TF_ASSERT_OK(reader->status());
118 Tensor val;
119 TF_ASSERT_OK(reader->ReadCurrent(&val));
120 test::ExpectTensorEqual<T>(val, expected_val);
121 }
122
AllTensorKeys(BundleReader * reader)123 std::vector<string> AllTensorKeys(BundleReader* reader) {
124 std::vector<string> ret;
125 reader->Seek(kHeaderEntryKey);
126 reader->Next();
127 for (; reader->Valid(); reader->Next()) {
128 ret.emplace_back(reader->key());
129 }
130 return ret;
131 }
132
133 // Writes out the metadata file of a bundle again, with the endianness marker
134 // bit flipped.
FlipEndiannessBit(const string & prefix)135 Status FlipEndiannessBit(const string& prefix) {
136 Env* env = Env::Default();
137 const string metadata_tmp_path = Prefix("some_tmp_path");
138 std::unique_ptr<WritableFile> metadata_file;
139 TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &metadata_file));
140 // We create the builder lazily in case we run into an exception earlier, in
141 // which case we'd forget to call Finish() and TableBuilder's destructor
142 // would complain.
143 std::unique_ptr<table::TableBuilder> builder;
144
145 // Reads the existing metadata file, and fills the builder.
146 {
147 const string filename = MetaFilename(prefix);
148 uint64 file_size;
149 TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
150 std::unique_ptr<RandomAccessFile> file;
151 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
152
153 table::Table* table = nullptr;
154 TF_RETURN_IF_ERROR(
155 table::Table::Open(table::Options(), file.get(), file_size, &table));
156 std::unique_ptr<table::Table> table_deleter(table);
157 std::unique_ptr<table::Iterator> iter(table->NewIterator());
158
159 // Reads the header entry.
160 iter->Seek(kHeaderEntryKey);
161 CHECK(iter->Valid());
162 BundleHeaderProto header;
163 CHECK(header.ParseFromArray(iter->value().data(), iter->value().size()));
164 // Flips the endianness.
165 if (header.endianness() == BundleHeaderProto::LITTLE) {
166 header.set_endianness(BundleHeaderProto::BIG);
167 } else {
168 header.set_endianness(BundleHeaderProto::LITTLE);
169 }
170 builder.reset(
171 new table::TableBuilder(table::Options(), metadata_file.get()));
172 builder->Add(iter->key(), header.SerializeAsString());
173 iter->Next();
174
175 // Adds the non-header entries unmodified.
176 for (; iter->Valid(); iter->Next())
177 builder->Add(iter->key(), iter->value());
178 }
179 TF_RETURN_IF_ERROR(builder->Finish());
180 TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix)));
181 return metadata_file->Close();
182 }
183
184 template <typename T>
TestBasic()185 void TestBasic() {
186 {
187 BundleWriter writer(Env::Default(), Prefix("foo"));
188 TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3(T(3))));
189 TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3(T(0))));
190 TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3(T(2))));
191 TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3(T(1))));
192 TF_ASSERT_OK(writer.Finish());
193 }
194 {
195 BundleReader reader(Env::Default(), Prefix("foo"));
196 TF_ASSERT_OK(reader.status());
197 EXPECT_EQ(
198 AllTensorKeys(&reader),
199 std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
200 Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
201 Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
202 Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
203 Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
204 }
205 {
206 BundleReader reader(Env::Default(), Prefix("foo"));
207 TF_ASSERT_OK(reader.status());
208 ExpectNext<T>(&reader, Constant_2x3(T(0)));
209 ExpectNext<T>(&reader, Constant_2x3(T(1)));
210 ExpectNext<T>(&reader, Constant_2x3(T(2)));
211 ExpectNext<T>(&reader, Constant_2x3(T(3)));
212 EXPECT_TRUE(reader.Valid());
213 reader.Next();
214 EXPECT_FALSE(reader.Valid());
215 }
216 {
217 BundleWriter writer(Env::Default(), Prefix("bar"));
218 TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3(T(3))));
219 TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3(T(0))));
220 TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3(T(2))));
221 TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3(T(1))));
222 TF_ASSERT_OK(writer.Finish());
223 }
224 {
225 BundleReader reader(Env::Default(), Prefix("bar"));
226 TF_ASSERT_OK(reader.status());
227 EXPECT_EQ(
228 AllTensorKeys(&reader),
229 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
230 Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
231 Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
232 Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
233 Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
234 }
235 {
236 BundleReader reader(Env::Default(), Prefix("bar"));
237 TF_ASSERT_OK(reader.status());
238 ExpectNext<T>(&reader, Constant_2x3(T(0)));
239 ExpectNext<T>(&reader, Constant_2x3(T(1)));
240 ExpectNext<T>(&reader, Constant_2x3(T(2)));
241 ExpectNext<T>(&reader, Constant_2x3(T(3)));
242 EXPECT_TRUE(reader.Valid());
243 reader.Next();
244 EXPECT_FALSE(reader.Valid());
245 }
246 TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
247 Prefix("merged")));
248 {
249 BundleReader reader(Env::Default(), Prefix("merged"));
250 TF_ASSERT_OK(reader.status());
251 EXPECT_EQ(
252 AllTensorKeys(&reader),
253 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
254 "foo_000", "foo_001", "foo_002", "foo_003"}));
255 Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
256 Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
257 Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
258 Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
259 Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
260 Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
261 Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
262 Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
263 }
264 {
265 BundleReader reader(Env::Default(), Prefix("merged"));
266 TF_ASSERT_OK(reader.status());
267 ExpectNext<T>(&reader, Constant_2x3(T(0)));
268 ExpectNext<T>(&reader, Constant_2x3(T(1)));
269 ExpectNext<T>(&reader, Constant_2x3(T(2)));
270 ExpectNext<T>(&reader, Constant_2x3(T(3)));
271 ExpectNext<T>(&reader, Constant_2x3(T(0)));
272 ExpectNext<T>(&reader, Constant_2x3(T(1)));
273 ExpectNext<T>(&reader, Constant_2x3(T(2)));
274 ExpectNext<T>(&reader, Constant_2x3(T(3)));
275 EXPECT_TRUE(reader.Valid());
276 reader.Next();
277 EXPECT_FALSE(reader.Valid());
278 }
279 }
280
281 // Type-specific subroutine of SwapBytes test below
282 template <typename T>
TestByteSwap(const T * forward,const T * swapped,int array_len)283 void TestByteSwap(const T* forward, const T* swapped, int array_len) {
284 auto bytes_per_elem = sizeof(T);
285
286 // Convert the entire array at once
287 std::unique_ptr<T[]> forward_copy(new T[array_len]);
288 std::memcpy(forward_copy.get(), forward, array_len * bytes_per_elem);
289 TF_EXPECT_OK(ByteSwapArray(reinterpret_cast<char*>(forward_copy.get()),
290 bytes_per_elem, array_len));
291 for (int i = 0; i < array_len; i++) {
292 EXPECT_EQ(forward_copy.get()[i], swapped[i]);
293 }
294
295 // Then the array wrapped in a tensor
296 auto shape = TensorShape({array_len});
297 auto dtype = DataTypeToEnum<T>::value;
298 Tensor forward_tensor(dtype, shape);
299 Tensor swapped_tensor(dtype, shape);
300 std::memcpy(const_cast<char*>(forward_tensor.tensor_data().data()), forward,
301 array_len * bytes_per_elem);
302 std::memcpy(const_cast<char*>(swapped_tensor.tensor_data().data()), swapped,
303 array_len * bytes_per_elem);
304 TF_EXPECT_OK(ByteSwapTensor(&forward_tensor));
305 test::ExpectTensorEqual<T>(forward_tensor, swapped_tensor);
306 }
307
308 // Unit test of the byte-swapping operations that TensorBundle uses.
TEST(TensorBundleTest,SwapBytes)309 TEST(TensorBundleTest, SwapBytes) {
310 // A bug in the compiler on MacOS causes ByteSwap() and FlipEndiannessBit()
311 // to be removed from the executable if they are only called from templated
312 // functions. As a workaround, we make some dummy calls here.
313 // TODO(frreiss): Remove this workaround when the compiler bug is fixed.
314 ByteSwap(Constant_2x3<int>(42));
315 EXPECT_NE(Status::OK(), FlipEndiannessBit(Prefix("not_a_valid_prefix")));
316
317 // Test patterns, manually swapped so that we aren't relying on the
318 // correctness of our own byte-swapping macros when testing those macros.
319 // At least one of the entries in each list has the sign bit set when
320 // interpreted as a signed int.
321 const int arr_len_16 = 4;
322 const uint16_t forward_16[] = {0x1de5, 0xd017, 0xf1ea, 0xc0a1};
323 const uint16_t swapped_16[] = {0xe51d, 0x17d0, 0xeaf1, 0xa1c0};
324 const int arr_len_32 = 2;
325 const uint32_t forward_32[] = {0x0ddba115, 0xf01dab1e};
326 const uint32_t swapped_32[] = {0x15a1db0d, 0x1eab1df0};
327 const int arr_len_64 = 2;
328 const uint64_t forward_64[] = {0xf005ba11caba1000, 0x5ca1ab1ecab005e5};
329 const uint64_t swapped_64[] = {0x0010baca11ba05f0, 0xe505b0ca1eaba15c};
330
331 // 16-bit types
332 TestByteSwap(forward_16, swapped_16, arr_len_16);
333 TestByteSwap(reinterpret_cast<const int16_t*>(forward_16),
334 reinterpret_cast<const int16_t*>(swapped_16), arr_len_16);
335 TestByteSwap(reinterpret_cast<const bfloat16*>(forward_16),
336 reinterpret_cast<const bfloat16*>(swapped_16), arr_len_16);
337
338 // 32-bit types
339 TestByteSwap(forward_32, swapped_32, arr_len_32);
340 TestByteSwap(reinterpret_cast<const int32_t*>(forward_32),
341 reinterpret_cast<const int32_t*>(swapped_32), arr_len_32);
342 TestByteSwap(reinterpret_cast<const float*>(forward_32),
343 reinterpret_cast<const float*>(swapped_32), arr_len_32);
344
345 // 64-bit types
346 // Cast to uint64*/int64* to make DataTypeToEnum<T> happy
347 TestByteSwap(reinterpret_cast<const uint64*>(forward_64),
348 reinterpret_cast<const uint64*>(swapped_64), arr_len_64);
349 TestByteSwap(reinterpret_cast<const int64*>(forward_64),
350 reinterpret_cast<const int64*>(swapped_64), arr_len_64);
351 TestByteSwap(reinterpret_cast<const double*>(forward_64),
352 reinterpret_cast<const double*>(swapped_64), arr_len_64);
353
354 // Complex types.
355 // Logic for complex number handling is only in ByteSwapTensor, so don't test
356 // ByteSwapArray
357 const float* forward_float = reinterpret_cast<const float*>(forward_32);
358 const float* swapped_float = reinterpret_cast<const float*>(swapped_32);
359 const double* forward_double = reinterpret_cast<const double*>(forward_64);
360 const double* swapped_double = reinterpret_cast<const double*>(swapped_64);
361 Tensor forward_complex64 = Constant_2x3<complex64>(
362 std::complex<float>(forward_float[0], forward_float[1]));
363 Tensor swapped_complex64 = Constant_2x3<complex64>(
364 std::complex<float>(swapped_float[0], swapped_float[1]));
365 Tensor forward_complex128 = Constant_2x3<complex128>(
366 std::complex<double>(forward_double[0], forward_double[1]));
367 Tensor swapped_complex128 = Constant_2x3<complex128>(
368 std::complex<double>(swapped_double[0], swapped_double[1]));
369
370 TF_EXPECT_OK(ByteSwapTensor(&forward_complex64));
371 test::ExpectTensorEqual<complex64>(forward_complex64, swapped_complex64);
372
373 TF_EXPECT_OK(ByteSwapTensor(&forward_complex128));
374 test::ExpectTensorEqual<complex128>(forward_complex128, swapped_complex128);
375 }
376
377 // Basic test of alternate-endianness support. Generates a bundle in
378 // the opposite of the current system's endianness and attempts to
379 // read the bundle back in. Does not exercise sharding or access to
380 // nonaligned tensors. Does cover the major access types exercised
381 // in TestBasic.
382 template <typename T>
TestEndianness()383 void TestEndianness() {
384 {
385 // Write out a TensorBundle in the opposite of this host's endianness.
386 BundleWriter writer(Env::Default(), Prefix("foo"));
387 TF_EXPECT_OK(writer.Add("foo_003", ByteSwap(Constant_2x3<T>(T(3)))));
388 TF_EXPECT_OK(writer.Add("foo_000", ByteSwap(Constant_2x3<T>(T(0)))));
389 TF_EXPECT_OK(writer.Add("foo_002", ByteSwap(Constant_2x3<T>(T(2)))));
390 TF_EXPECT_OK(writer.Add("foo_001", ByteSwap(Constant_2x3<T>(T(1)))));
391 TF_ASSERT_OK(writer.Finish());
392 TF_ASSERT_OK(FlipEndiannessBit(Prefix("foo")));
393 }
394 {
395 BundleReader reader(Env::Default(), Prefix("foo"));
396 TF_ASSERT_OK(reader.status());
397 EXPECT_EQ(
398 AllTensorKeys(&reader),
399 std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
400 Expect<T>(&reader, "foo_000", Constant_2x3<T>(T(0)));
401 Expect<T>(&reader, "foo_001", Constant_2x3<T>(T(1)));
402 Expect<T>(&reader, "foo_002", Constant_2x3<T>(T(2)));
403 Expect<T>(&reader, "foo_003", Constant_2x3<T>(T(3)));
404 }
405 {
406 BundleReader reader(Env::Default(), Prefix("foo"));
407 TF_ASSERT_OK(reader.status());
408 ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
409 ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
410 ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
411 ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
412 EXPECT_TRUE(reader.Valid());
413 reader.Next();
414 EXPECT_FALSE(reader.Valid());
415 }
416 {
417 BundleWriter writer(Env::Default(), Prefix("bar"));
418 TF_EXPECT_OK(writer.Add("bar_003", ByteSwap(Constant_2x3<T>(T(3)))));
419 TF_EXPECT_OK(writer.Add("bar_000", ByteSwap(Constant_2x3<T>(T(0)))));
420 TF_EXPECT_OK(writer.Add("bar_002", ByteSwap(Constant_2x3<T>(T(2)))));
421 TF_EXPECT_OK(writer.Add("bar_001", ByteSwap(Constant_2x3<T>(T(1)))));
422 TF_ASSERT_OK(writer.Finish());
423 TF_ASSERT_OK(FlipEndiannessBit(Prefix("bar")));
424 }
425 {
426 BundleReader reader(Env::Default(), Prefix("bar"));
427 TF_ASSERT_OK(reader.status());
428 EXPECT_EQ(
429 AllTensorKeys(&reader),
430 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
431 Expect<T>(&reader, "bar_003", Constant_2x3<T>(T(3)));
432 Expect<T>(&reader, "bar_002", Constant_2x3<T>(T(2)));
433 Expect<T>(&reader, "bar_001", Constant_2x3<T>(T(1)));
434 Expect<T>(&reader, "bar_000", Constant_2x3<T>(T(0)));
435 }
436 {
437 BundleReader reader(Env::Default(), Prefix("bar"));
438 TF_ASSERT_OK(reader.status());
439 ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
440 ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
441 ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
442 ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
443 EXPECT_TRUE(reader.Valid());
444 reader.Next();
445 EXPECT_FALSE(reader.Valid());
446 }
447 TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
448 Prefix("merged")));
449 {
450 BundleReader reader(Env::Default(), Prefix("merged"));
451 TF_ASSERT_OK(reader.status());
452 EXPECT_EQ(
453 AllTensorKeys(&reader),
454 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
455 "foo_000", "foo_001", "foo_002", "foo_003"}));
456 Expect<T>(&reader, "bar_000", Constant_2x3<T>(T(0)));
457 Expect<T>(&reader, "bar_001", Constant_2x3<T>(T(1)));
458 Expect<T>(&reader, "bar_002", Constant_2x3<T>(T(2)));
459 Expect<T>(&reader, "bar_003", Constant_2x3<T>(T(3)));
460 Expect<T>(&reader, "foo_000", Constant_2x3<T>(T(0)));
461 Expect<T>(&reader, "foo_001", Constant_2x3<T>(T(1)));
462 Expect<T>(&reader, "foo_002", Constant_2x3<T>(T(2)));
463 Expect<T>(&reader, "foo_003", Constant_2x3<T>(T(3)));
464 }
465 {
466 BundleReader reader(Env::Default(), Prefix("merged"));
467 TF_ASSERT_OK(reader.status());
468 ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
469 ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
470 ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
471 ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
472 ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
473 ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
474 ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
475 ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
476 EXPECT_TRUE(reader.Valid());
477 reader.Next();
478 EXPECT_FALSE(reader.Valid());
479 }
480 }
481
482 template <typename T>
TestNonStandardShapes()483 void TestNonStandardShapes() {
484 {
485 BundleWriter writer(Env::Default(), Prefix("nonstandard"));
486 TF_EXPECT_OK(writer.Add("scalar", Constant(T(0), TensorShape())));
487 TF_EXPECT_OK(
488 writer.Add("non_standard0", Constant(T(0), TensorShape({0, 1618}))));
489 TF_EXPECT_OK(
490 writer.Add("non_standard1", Constant(T(0), TensorShape({16, 0, 18}))));
491 TF_ASSERT_OK(writer.Finish());
492 }
493 {
494 BundleReader reader(Env::Default(), Prefix("nonstandard"));
495 TF_ASSERT_OK(reader.status());
496 Expect<T>(&reader, "scalar", Constant(T(0), TensorShape()));
497 Expect<T>(&reader, "non_standard0", Constant(T(0), TensorShape({0, 1618})));
498 Expect<T>(&reader, "non_standard1",
499 Constant(T(0), TensorShape({16, 0, 18})));
500 }
501 }
502
503 // Writes a bundle to disk with a bad "version"; checks for "expected_error".
VersionTest(const VersionDef & version,StringPiece expected_error)504 void VersionTest(const VersionDef& version, StringPiece expected_error) {
505 const string path = Prefix("version_test");
506 {
507 // Prepare an empty bundle with the given version information.
508 BundleHeaderProto header;
509 *header.mutable_version() = version;
510
511 // Write the metadata file to disk.
512 std::unique_ptr<WritableFile> file;
513 TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file));
514 table::TableBuilder builder(table::Options(), file.get());
515 builder.Add(kHeaderEntryKey, header.SerializeAsString());
516 TF_ASSERT_OK(builder.Finish());
517 }
518 // Read it back in and verify that we get the expected error.
519 BundleReader reader(Env::Default(), path);
520 EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
521 EXPECT_TRUE(
522 absl::StartsWith(reader.status().error_message(), expected_error));
523 }
524
525 } // namespace
526
// Runs the TestBasic write/read/merge round trip once per element type below.
// The calls run sequentially and reuse the same bundle prefixes ("foo", "bar",
// "merged"), so each run overwrites the files left by the previous one.
TEST(TensorBundleTest, Basic) {
  TestBasic<float>();
  TestBasic<double>();
  TestBasic<int32>();
  TestBasic<uint8>();
  TestBasic<int16>();
  TestBasic<int8>();
  TestBasic<complex64>();
  TestBasic<complex128>();
  TestBasic<int64>();
  TestBasic<bool>();
  TestBasic<qint32>();
  TestBasic<quint8>();
  TestBasic<qint8>();
  TestBasic<bfloat16>();
}
543
// Runs TestEndianness once per element type below. Each run writes bundles in
// the opposite of the host's byte order and checks that they read back as the
// original values. As in the Basic test, the runs share bundle prefixes and
// execute in the listed order.
TEST(TensorBundleTest, Endianness) {
  TestEndianness<float>();
  TestEndianness<double>();
  TestEndianness<int32>();
  TestEndianness<uint8>();
  TestEndianness<int16>();
  TestEndianness<int8>();
  TestEndianness<complex64>();
  TestEndianness<complex128>();
  TestEndianness<int64>();
  TestEndianness<bool>();
  TestEndianness<qint32>();
  TestEndianness<quint8>();
  TestEndianness<qint8>();
  TestEndianness<bfloat16>();
}
560
// Saves a 5x10 float variable as two slices and reads it back in full, as the
// list of stored slices, and as sub-slices that cut across the stored slices.
// Slice specs use "start,length" per dimension ("-" means the full extent).
TEST(TensorBundleTest, PartitionedVariables) {
  const TensorShape kFullShape({5, 10});
  // Adds two slices.
  // First slice: column 0, all zeros.
  // Second slice: columns 1 through 9, all ones.
  TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
  TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));

    TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
                                 Constant<float>(0., TensorShape({5, 1}))));
    TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
                                 Constant<float>(1., TensorShape({5, 9}))));
    TF_ASSERT_OK(writer.Finish());
  }
  // Reads in full.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    Tensor expected_val(DT_FLOAT, kFullShape);
    test::FillFn<float>(&expected_val, [](int offset) -> float {
      if (offset % 10 == 0) {
        return 0;  // First column zeros.
      }
      return 1;  // Other columns ones.
    });

    Tensor val(DT_FLOAT, kFullShape);
    TF_ASSERT_OK(reader.Lookup("foo", &val));
    test::ExpectTensorEqual<float>(val, expected_val);
  }
  // Reads all slices.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    std::vector<TensorSlice> slices;
    TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));

    // Fatal on mismatch: slices[0] and slices[1] are indexed below, which
    // would be out of bounds if fewer slices came back.
    ASSERT_EQ(size_t{2}, slices.size());
    EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
    EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
  }
  // Reads a slice consisting of the first two columns, "cutting" both slices.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:0,2");
    Tensor expected_val(DT_FLOAT, TensorShape({5, 2}));
    test::FillFn<float>(&expected_val, [](int offset) -> float {
      if (offset % 2 == 0) {
        return 0;  // First column zeros.
      }
      return 1;  // Other columns ones.
    });

    Tensor val(DT_FLOAT, TensorShape({5, 2}));
    TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
    test::ExpectTensorEqual<float>(val, expected_val);
  }
  // Reads a slice consisting of columns 2-3 ("-:2,2": start 2, length 2),
  // which "cuts" only the second slice.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:2,2");
    Tensor val(DT_FLOAT, TensorShape({5, 2}));
    TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
    test::ExpectTensorEqual<float>(val,
                                   Constant<float>(1., TensorShape({5, 2})));
  }
}
637
// A full-extent slice can be written as "-:-" (abbreviated) or "0,5:0,10"
// (fully specified). Checks that either spelling, used on lookup, finds data
// stored under either spelling.
TEST(TensorBundleTest, EquivalentSliceTest) {
  const TensorShape kFullShape({5, 10});
  const Tensor kExpected(Constant<float>(1., kFullShape));
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
                                 TensorSlice::ParseOrDie("-:-"), kExpected));
    TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
                                 TensorSlice::ParseOrDie("0,5:0,10"),
                                 kExpected));
    TF_ASSERT_OK(writer.Finish());
  }

  struct LookupCase {
    const char* key;   // Which stored entry to read.
    const char* spec;  // Slice spec used for the lookup.
  };
  const LookupCase kCases[] = {
      {"no_extents", "-:-"},         // Exact match, fully abbreviated.
      {"both_extents", "0,5:0,10"},  // Exact match, fully specified.
      {"no_extents", "0,5:0,10"},    // Stored without extents, spec has them.
      {"both_extents", "-:-"},       // Stored with extents, spec has none.
  };
  for (const LookupCase& lookup : kCases) {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    const TensorSlice slice = TensorSlice::ParseOrDie(lookup.spec);
    Tensor val(DT_FLOAT, kFullShape);
    TF_ASSERT_OK(reader.LookupSlice(lookup.key, slice, &val));
    test::ExpectTensorEqual<float>(val, kExpected);
  }
}
687
// Runs TestNonStandardShapes (a scalar plus shapes with a zero-sized
// dimension) once per element type below.
TEST(TensorBundleTest, NonStandardShapes) {
  TestNonStandardShapes<float>();
  TestNonStandardShapes<double>();
  TestNonStandardShapes<int32>();
  TestNonStandardShapes<uint8>();
  TestNonStandardShapes<int16>();
  TestNonStandardShapes<int8>();
  TestNonStandardShapes<complex64>();
  TestNonStandardShapes<complex128>();
  TestNonStandardShapes<int64>();
  TestNonStandardShapes<bool>();
  TestNonStandardShapes<qint32>();
  TestNonStandardShapes<quint8>();
  TestNonStandardShapes<qint8>();
  TestNonStandardShapes<bfloat16>();
}
704
// Backward compatibility: reads a checked-in bundle produced by an older
// writer that stored string lengths as varint32 (current code writes
// varint64).
TEST(TensorBundleTest, StringTensorsOldFormat) {
  BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
  TF_ASSERT_OK(reader.status());
  const std::vector<string> expected_keys = {"floats", "scalar",
                                             "string_tensor", "strs"};
  EXPECT_EQ(AllTensorKeys(&reader), expected_keys);

  const Tensor single_empty_string(DT_STRING, TensorShape({1}));
  Expect<tstring>(&reader, "string_tensor", single_empty_string);
  Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));

  const string kilobyte_of_c(1 << 10, 'c');
  Expect<tstring>(&reader, "strs",
                  test::AsTensor<tstring>({"hello", "", "x01", kilobyte_of_c}));
  Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
}
721
// Round-trips string tensors: an empty string, a scalar, a mix including a
// 32 MiB string, and one string whose length (UINT32_MAX + 1 bytes) needs a
// 64-bit length field. Floats are mixed in as well.
TEST(TensorBundleTest, StringTensors) {
  constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
  Tensor long_string_tensor(DT_STRING, TensorShape({1}));

  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_EXPECT_OK(writer.Add("string_tensor",
                            Tensor(DT_STRING, TensorShape({1}))));  // Empty.
    TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<tstring>({"hello"})));
    TF_EXPECT_OK(writer.Add(
        "strs",
        test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')})));

    // Requires a 64-bit length.
    tstring* backing_string = long_string_tensor.flat<tstring>().data();
    backing_string->resize_uninitialized(kLongLength);
    std::char_traits<char>::assign(backing_string->data(), kLongLength, 'd');
    TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));

    // Mixes in some floats.
    TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
    TF_ASSERT_OK(writer.Finish());
  }
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    EXPECT_EQ(AllTensorKeys(&reader),
              std::vector<string>({"floats", "long_scalar", "scalar",
                                   "string_tensor", "strs"}));

    Expect<tstring>(&reader, "string_tensor",
                    Tensor(DT_STRING, TensorShape({1})));
    Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
    Expect<tstring>(
        &reader, "strs",
        test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')}));

    Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));

    // Expect() is not used here, so that the `long_string_tensor` buffer can
    // be reused when reading out long_scalar. This keeps memory usage
    // reasonable.
    EXPECT_TRUE(reader.Contains("long_scalar"));
    DataType dtype;
    TensorShape shape;
    TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
    EXPECT_EQ(DT_STRING, dtype);
    EXPECT_EQ(TensorShape({1}), shape);

    // Refill the buffer with a different character so we can be sure the
    // value was actually read back. The buffer is not freed here because,
    // with tc-malloc fragmentation and a tensor this large (4GB), it is
    // better to keep allocation and free close to each other.
    tstring* backing_string = long_string_tensor.flat<tstring>().data();
    std::char_traits<char>::assign(backing_string->data(), kLongLength, 'e');

    // Read long_scalar and check it contains kLongLength 'd's.
    TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
    ASSERT_EQ(backing_string, long_string_tensor.flat<tstring>().data());
    // Fatal on mismatch: the scan below indexes kLongLength bytes and would
    // run past the end of a shorter string.
    ASSERT_EQ(kLongLength, backing_string->length());
    for (size_t i = 0; i < kLongLength; i++) {
      // A plain comparison instead of ASSERT_EQ('d', c): it is about twice as
      // fast thanks to compiler optimizations. FAIL() returns immediately.
      if ((*backing_string)[i] != 'd') {
        FAIL() << "long_scalar is not full of 'd's as expected.";
      }
    }
  }
}
792
793 class VariantObject {
794 public:
VariantObject()795 VariantObject() {}
VariantObject(const string & metadata,int64 value)796 VariantObject(const string& metadata, int64 value)
797 : metadata_(metadata), value_(value) {}
798
TypeName() const799 string TypeName() const { return "TEST VariantObject"; }
Encode(VariantTensorData * data) const800 void Encode(VariantTensorData* data) const {
801 data->set_type_name(TypeName());
802 data->set_metadata(metadata_);
803 Tensor val_t = Tensor(DT_INT64, TensorShape({}));
804 val_t.scalar<int64>()() = value_;
805 *(data->add_tensors()) = val_t;
806 }
Decode(const VariantTensorData & data)807 bool Decode(const VariantTensorData& data) {
808 EXPECT_EQ(data.type_name(), TypeName());
809 data.get_metadata(&metadata_);
810 EXPECT_EQ(data.tensors_size(), 1);
811 value_ = data.tensors(0).scalar<int64>()();
812 return true;
813 }
operator ==(const VariantObject other) const814 bool operator==(const VariantObject other) const {
815 return metadata_ == other.metadata_ && value_ == other.value_;
816 }
817 string metadata_;
818 int64 value_;
819 };
820
821 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject");
822
TEST(TensorBundleTest, VariantTensors) {
  // Builds the two-element DT_VARIANT tensor that is both written and expected
  // back on read.
  auto MakeVariantTensor = []() {
    return test::AsTensor<Variant>(
        {VariantObject("test", 10), VariantObject("test1", 20)});
  };
  const string bundle_prefix = Prefix("foo");

  // Write phase: the writer is scoped so it is destroyed before reading.
  {
    BundleWriter writer(Env::Default(), bundle_prefix);
    TF_EXPECT_OK(writer.Add("variant_tensor", MakeVariantTensor()));
    TF_ASSERT_OK(writer.Finish());
  }

  // Read phase: the decoded variants must compare equal to the originals.
  {
    BundleReader reader(Env::Default(), bundle_prefix);
    TF_ASSERT_OK(reader.status());
    ExpectVariant<VariantObject>(&reader, "variant_tensor",
                                 MakeVariantTensor());
  }
}
841
TEST(TensorBundleTest, DirectoryStructure) {
  Env* const env = Env::Default();

  // Expects every file in `expected_files` to exist in the directory that
  // contains `bundle_prefix`.
  auto CheckDirFiles = [env](const string& bundle_prefix,
                             gtl::ArraySlice<string> expected_files) {
    const StringPiece parent_dir = io::Dirname(bundle_prefix);
    for (const string& file_name : expected_files) {
      TF_EXPECT_OK(env->FileExists(io::JoinPath(parent_dir, file_name)));
    }
  };

  // Two independent single-shard bundles, one tensor each.
  const std::vector<string> worker_prefixes = {Prefix("worker0"),
                                               Prefix("worker1")};
  for (int worker = 0; worker < 2; ++worker) {
    BundleWriter writer(env, worker_prefixes[worker]);
    TF_EXPECT_OK(writer.Add(strings::StrCat("tensor", worker),
                            Constant_2x3<float>(0.)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Each worker bundle is one index file plus one data shard.
  CheckDirFiles(worker_prefixes[0],
                {"worker0.index", "worker0.data-00000-of-00001"});
  CheckDirFiles(worker_prefixes[1],
                {"worker1.index", "worker1.data-00000-of-00001"});

  // Merging a single bundle amounts to a rename to the new prefix.
  const string renamed_prefix = Prefix("another");
  TF_ASSERT_OK(MergeBundles(env, {worker_prefixes[0]}, renamed_prefix));
  CheckDirFiles(renamed_prefix,
                {"another.index", "another.data-00000-of-00001"});

  // A real two-way merge keeps one data shard per input bundle.
  const string merged_prefix = Prefix("merged");
  TF_ASSERT_OK(MergeBundles(env, {renamed_prefix, worker_prefixes[1]},
                            merged_prefix));
  CheckDirFiles(merged_prefix,
                {"merged.index", "merged.data-00000-of-00002",
                 "merged.data-00001-of-00002"});
}
887
TEST(TensorBundleTest,Error)888 TEST(TensorBundleTest, Error) {
889 { // Dup keys.
890 BundleWriter writer(Env::Default(), Prefix("dup"));
891 TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
892 EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok());
893 EXPECT_TRUE(absl::StrContains(writer.status().ToString(), "duplicate key"));
894 EXPECT_FALSE(writer.Finish().ok());
895 }
896 { // Double finish
897 BundleWriter writer(Env::Default(), Prefix("bad"));
898 EXPECT_TRUE(writer.Finish().ok());
899 EXPECT_FALSE(writer.Finish().ok());
900 }
901 { // Not found.
902 BundleReader reader(Env::Default(), Prefix("nonexist"));
903 EXPECT_TRUE(absl::StrContains(reader.status().ToString(), "Not found"));
904 }
905 }
906
TEST(TensorBundleTest, Checksum) {
  // Inverts one byte of the single data shard of bundle `prefix`.
  // With `exact_pos`, the byte at offset `pos_lhs` is flipped. Otherwise a
  // byte in [pos_lhs, end of data file) is chosen by a default-seeded (hence
  // deterministic) mt19937.
  auto FlipByte = [](const string& prefix, int pos_lhs,
                     bool exact_pos = false) {
    // ASSERT rather than DCHECK: DCHECK is compiled out in optimized builds,
    // and a negative offset would then index out of bounds.
    ASSERT_GE(pos_lhs, 0);
    const string datafile = DataFilename(Prefix(prefix), 0, 1);
    string data;
    TF_ASSERT_OK(ReadFileToString(Env::Default(), datafile, &data));
    // Without this check, the exact path reads out of range and the random
    // path builds a distribution with lower bound > upper bound (UB).
    ASSERT_LT(static_cast<size_t>(pos_lhs), data.size());

    int byte_pos = pos_lhs;
    if (!exact_pos) {
      std::mt19937 rng;
      std::uniform_int_distribution<int> dist(
          pos_lhs, static_cast<int>(data.size()) - 1);
      byte_pos = dist(rng);
    }
    data[byte_pos] = ~data[byte_pos];
    TF_ASSERT_OK(WriteStringToFile(Env::Default(), datafile, data));
  };
  // Expects Lookup(key) to fail with DataLoss and a message containing
  // `expected_msg`.
  auto ExpectLookupFails = [](const string& prefix, const string& key,
                              const string& expected_msg, Tensor& val) {
    BundleReader reader(Env::Default(), Prefix(prefix));
    // Only the data shard is corrupted, so opening should still succeed and
    // the failure must come from Lookup() itself. This matches
    // TruncatedTensorContents, which asserts an OK reader after corrupting
    // data.
    TF_ASSERT_OK(reader.status());
    const Status status = reader.Lookup(key, &val);
    EXPECT_TRUE(errors::IsDataLoss(status));
    EXPECT_TRUE(absl::StrContains(status.ToString(), expected_msg));
  };

  // Corrupts a float tensor.
  {
    BundleWriter writer(Env::Default(), Prefix("singleton"));
    TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
    TF_ASSERT_OK(writer.Finish());

    FlipByte("singleton", 0 /* corrupts any byte */);
    Tensor val(DT_FLOAT, TensorShape({2, 3}));
    ExpectLookupFails("singleton", "foo",
                      "Checksum does not match" /* expected fail msg */, val);
  }
  // Corrupts a string tensor.
  {
    auto WriteStrings = []() {
      BundleWriter writer(Env::Default(), Prefix("strings"));
      TF_EXPECT_OK(
          writer.Add("foo", test::AsTensor<tstring>({"hello", "world"})));
      TF_ASSERT_OK(writer.Finish());
    };
    // Corrupts the first two bytes, which are the varint32-encoded lengths
    // of the two string elements. Should hit mismatch on length cksum.
    for (int i = 0; i < 2; ++i) {
      WriteStrings();
      FlipByte("strings", i, true /* corrupts exactly byte i */);
      Tensor val(DT_STRING, TensorShape({2}));
      ExpectLookupFails(
          "strings", "foo",
          "length checksum does not match" /* expected fail msg */, val);
    }
    // Corrupts the string bytes, should hit an overall cksum mismatch.
    WriteStrings();
    FlipByte("strings", 2 /* corrupts starting from byte 2 */);
    Tensor val(DT_STRING, TensorShape({2}));
    ExpectLookupFails("strings", "foo",
                      "Checksum does not match" /* expected fail msg */, val);
  }
}
974
TEST(TensorBundleTest, TruncatedTensorContents) {
  Env* const env = Env::Default();
  const string bundle_prefix = Prefix("end");

  BundleWriter writer(env, bundle_prefix);
  TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
  TF_ASSERT_OK(writer.Finish());

  // Drop the last byte of the only data shard so that reading the tensor
  // runs into end-of-file.
  const string datafile = DataFilename(bundle_prefix, 0, 1);
  string contents;
  TF_ASSERT_OK(ReadFileToString(env, datafile, &contents));
  ASSERT_FALSE(contents.empty());
  const StringPiece truncated(contents.data(), contents.size() - 1);
  TF_ASSERT_OK(WriteStringToFile(env, datafile, truncated));

  // Opening still succeeds; the short read surfaces as OutOfRange on lookup.
  BundleReader reader(env, bundle_prefix);
  TF_ASSERT_OK(reader.status());
  Tensor val(DT_FLOAT, TensorShape({2, 3}));
  EXPECT_TRUE(errors::IsOutOfRange(reader.Lookup("key", &val)));
}
994
TEST(TensorBundleTest, HeaderEntry) {
  const string bundle_prefix = Prefix("b");
  {
    BundleWriter writer(Env::Default(), bundle_prefix);
    TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Seek to the record stored under kHeaderEntryKey and parse its raw bytes
  // as a BundleHeaderProto.
  BundleHeaderProto header;
  {
    BundleReader reader(Env::Default(), bundle_prefix);
    TF_ASSERT_OK(reader.status());
    reader.Seek(kHeaderEntryKey);
    ASSERT_TRUE(reader.Valid());
    const auto raw_header = reader.value();
    ASSERT_TRUE(
        ParseProtoUnlimited(&header, raw_header.data(), raw_header.size()));
  }

  // A single writer produces exactly one shard.
  EXPECT_EQ(1, header.num_shards());

  // The recorded endianness must match the host that wrote the bundle.
  const auto expected_endianness = port::kLittleEndian
                                       ? BundleHeaderProto::LITTLE
                                       : BundleHeaderProto::BIG;
  EXPECT_EQ(expected_endianness, header.endianness());

  // The header carries the current format version.
  EXPECT_GT(kTensorBundleVersion, 0);
  EXPECT_EQ(kTensorBundleVersion, header.version().producer());
  EXPECT_EQ(kTensorBundleMinConsumer, header.version().min_consumer());
}
1026
TEST(TensorBundleTest, VersionTest) {
  // min_consumer newer than this binary: expects the "above current
  // version" error.
  {
    VersionDef too_new_for_us;
    too_new_for_us.set_producer(kTensorBundleVersion + 1);
    too_new_for_us.set_min_consumer(kTensorBundleVersion + 1);
    const string expected_error = strings::StrCat(
        "Checkpoint min consumer version ", kTensorBundleVersion + 1,
        " above current version ", kTensorBundleVersion, " for TensorFlow");
    VersionTest(too_new_for_us, expected_error);
  }
  // Producer older than the minimum supported one.
  {
    VersionDef too_old_producer;
    too_old_producer.set_producer(kTensorBundleMinProducer - 1);
    const string expected_error = strings::StrCat(
        "Checkpoint producer version ", kTensorBundleMinProducer - 1,
        " below min producer ", kTensorBundleMinProducer,
        " supported by TensorFlow");
    VersionTest(too_old_producer, expected_error);
  }
  // This binary's version is explicitly listed as a bad consumer.
  {
    VersionDef blacklists_us;
    blacklists_us.set_producer(kTensorBundleVersion + 1);
    blacklists_us.add_bad_consumers(kTensorBundleVersion);
    const string expected_error = strings::StrCat(
        "Checkpoint disallows consumer version ", kTensorBundleVersion,
        ". Please upgrade TensorFlow: this version is likely buggy.");
    VersionTest(blacklists_us, expected_error);
  }
}
1061
1062 class TensorBundleAlignmentTest : public ::testing::Test {
1063 protected:
1064 template <typename T>
ExpectAlignment(BundleReader * reader,const string & key,int alignment)1065 void ExpectAlignment(BundleReader* reader, const string& key, int alignment) {
1066 BundleEntryProto full_tensor_entry;
1067 TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry));
1068 EXPECT_EQ(0, full_tensor_entry.offset() % alignment);
1069 }
1070 };
1071
TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
  constexpr int kAlignment = 42;
  const string bundle_prefix = Prefix("foo");
  // Sorted keys; kSortedKeys[i] holds Constant_2x3<float>(i).
  const std::vector<string> kSortedKeys = {"foo_000", "foo_001", "foo_002",
                                           "foo_003"};

  // Write the tensors deliberately out of key order.
  {
    BundleWriter::Options opts;
    opts.data_alignment = kAlignment;
    BundleWriter writer(Env::Default(), bundle_prefix, opts);
    for (const int i : {3, 0, 2, 1}) {
      TF_EXPECT_OK(writer.Add(kSortedKeys[i], Constant_2x3<float>(i)));
    }
    TF_ASSERT_OK(writer.Finish());
  }

  // Keyed lookups: keys come back sorted and every value is intact.
  {
    BundleReader reader(Env::Default(), bundle_prefix);
    TF_ASSERT_OK(reader.status());
    EXPECT_EQ(AllTensorKeys(&reader), kSortedKeys);
    for (int i = 0; i < 4; ++i) {
      Expect<float>(&reader, kSortedKeys[i], Constant_2x3<float>(i));
    }
  }

  // Sequential iteration visits the four tensors in key order, then stops.
  {
    BundleReader reader(Env::Default(), bundle_prefix);
    TF_ASSERT_OK(reader.status());
    for (int i = 0; i < 4; ++i) {
      ExpectNext<float>(&reader, Constant_2x3<float>(i));
    }
    EXPECT_TRUE(reader.Valid());
    reader.Next();
    EXPECT_FALSE(reader.Valid());
  }

  // Every tensor starts on a kAlignment-byte boundary.
  {
    BundleReader reader(Env::Default(), bundle_prefix);
    TF_ASSERT_OK(reader.status());
    for (const string& key : kSortedKeys) {
      ExpectAlignment<float>(&reader, key, kAlignment);
    }
  }
}
1114
BM_BundleAlignmentByteOff(::testing::benchmark::State & state,int alignment,int tensor_size)1115 static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state,
1116 int alignment, int tensor_size) {
1117 {
1118 BundleWriter::Options opts;
1119 opts.data_alignment = alignment;
1120 BundleWriter writer(Env::Default(), Prefix("foo"), opts);
1121 TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
1122 TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
1123 TF_CHECK_OK(writer.Finish());
1124 }
1125 BundleReader reader(Env::Default(), Prefix("foo"));
1126 TF_CHECK_OK(reader.status());
1127 for (auto s : state) {
1128 Tensor t;
1129 TF_CHECK_OK(reader.Lookup("big", &t));
1130 }
1131 }
1132
1133 #define BM_BundleAlignment(ALIGN, SIZE) \
1134 static void BM_BundleAlignment_##ALIGN##_##SIZE( \
1135 ::testing::benchmark::State& state) { \
1136 BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \
1137 } \
1138 BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
1139
1140 BM_BundleAlignment(1, 512);
1141 BM_BundleAlignment(1, 4096);
1142 BM_BundleAlignment(1, 1048576);
1143 BM_BundleAlignment(4096, 512);
1144 BM_BundleAlignment(4096, 4096);
1145 BM_BundleAlignment(4096, 1048576);
1146
1147 } // namespace tensorflow
1148