1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
17 
18 #include <algorithm>
19 #include <cstdlib>
20 #include <cstring>
21 #include <memory>
22 #include <utility>
23 
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.pb_text.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/framework/types.pb_text.h"
30 #include "tensorflow/core/framework/variant.h"
31 #include "tensorflow/core/framework/variant_op_registry.h"
32 #include "tensorflow/core/framework/variant_tensor_data.h"
33 #include "tensorflow/core/framework/versions.h"
34 #include "tensorflow/core/framework/versions.pb.h"
35 #include "tensorflow/core/lib/core/coding.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/gtl/map_util.h"
38 #include "tensorflow/core/lib/gtl/stl_util.h"
39 #include "tensorflow/core/lib/hash/crc32c.h"
40 #include "tensorflow/core/lib/io/path.h"
41 #include "tensorflow/core/lib/io/table_builder.h"
42 #include "tensorflow/core/lib/random/random.h"
43 #include "tensorflow/core/lib/strings/stringprintf.h"
44 #include "tensorflow/core/util/saved_tensor_slice_util.h"
45 #include "tensorflow/core/util/tensor_slice_util.h"
46 
47 namespace tensorflow {
48 
// Version numbers describing the tensor bundle on-disk format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;

// Size, in bytes, of the input buffer used for streaming reads (1 MiB).
static const int kBufferSize = 1 << 20;

// Key of the special BundleHeaderProto entry.  Do not change this: clients
// may assume that the header is always the first entry in the bundle.
const char* const kHeaderEntryKey = "";
61 
62 namespace {
63 
64 // Reads "num_elements" string elements from file[offset, offset+size) into the
65 // length-N "destination".  Discards the original content of "destination".
66 //
67 // Checksums the string lengths (as restored uint32 or uint64, not varint64
68 // bytes) and string bytes, and stores it into "actual_crc32c".
ReadStringTensor(io::InputBuffer * buffered_file,size_t num_elements,size_t offset,size_t size,string * destination,uint32 * actual_crc32c)69 Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
70                         size_t offset, size_t size, string* destination,
71                         uint32* actual_crc32c) {
72   if (size == 0) return Status::OK();
73   CHECK_GT(size, 0);
74 
75   // Reads "num_elements" varint64's from "buffered_file".
76   TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
77   std::vector<uint64> string_lengths(num_elements);
78   for (size_t i = 0; i < num_elements; ++i) {
79     TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
80     if (string_lengths[i] <= UINT32_MAX) {
81       // We need to do this because older checkpoints only used uint32s and we
82       // should still support them.
83       const uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
84       *actual_crc32c = crc32c::Extend(
85           *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
86           sizeof(uint32));
87     } else {
88       *actual_crc32c = crc32c::Extend(
89           *actual_crc32c, reinterpret_cast<const char*>(&string_lengths[i]),
90           sizeof(uint64));
91     }
92   }
93   if (offset + size < buffered_file->Tell()) {
94     return errors::DataLoss("String lengths longer than expected offset ",
95                             offset + size);
96   }
97 
98   // Reads the length-checksum.
99   uint32 length_checksum = 0;
100   size_t unused_bytes_read = 0;
101   TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
102       sizeof(uint32), reinterpret_cast<char*>(&length_checksum),
103       &unused_bytes_read));
104   if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
105     return errors::DataLoss(
106         "The length checksum does not match: expected ",
107         strings::Printf("%08u", crc32c::Unmask(length_checksum)),
108         " but actual is ", strings::Printf("%08u", *actual_crc32c));
109   }
110   *actual_crc32c =
111       crc32c::Extend(*actual_crc32c, reinterpret_cast<char*>(&length_checksum),
112                      sizeof(uint32));
113 
114   // Reads the actual string bytes.
115   for (size_t i = 0; i < num_elements; ++i) {
116     const uint64 string_length = string_lengths[i];
117     string* buffer = &destination[i];
118 
119     buffer->resize(string_length);
120     size_t bytes_read = 0;
121     TF_RETURN_IF_ERROR(
122         buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
123     *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
124   }
125   return Status::OK();
126 }
127 
ReadVariantTensor(io::InputBuffer * buffered_file,Tensor * ret,size_t offset,size_t size,uint32 * actual_crc32c)128 Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
129                          size_t offset, size_t size, uint32* actual_crc32c) {
130   // On-disk format:
131   //   [varint64 len1][bytes variant1][4 byte checksum]
132   //   ..
133   //   [varint64 lenN][bytes variantN][4 byte checksum]
134   // Var "crc32c" checksums all the lens, variant bytes, individual variant
135   // checksums (as uint32, not varint32 bytes).
136   if (size == 0) return Status::OK();
137   size_t num_elements = ret->NumElements();
138 
139   // Reads the actual string bytes.
140   TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
141   for (size_t i = 0; i < num_elements; ++i) {
142     // Read the serialized variant length.
143     uint64 string_length = 0;
144     TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
145     *actual_crc32c = crc32c::Extend(
146         *actual_crc32c, reinterpret_cast<const char*>(&string_length),
147         sizeof(uint64));
148     // Read the actual serialized variant.
149     string buffer;
150     buffer.resize(string_length);
151     size_t bytes_read = 0;
152     TF_RETURN_IF_ERROR(
153         buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
154     *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
155     VariantTensorDataProto proto;
156     if (!proto.ParseFromString(buffer)) {
157       return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
158                               "buffer of size ", string_length, ". ",
159                               "Bundle entry offset: ", offset, " size: ", size);
160     }
161     Variant v = proto;
162     if (!DecodeUnaryVariant(&v)) {
163       return errors::Internal("Could not decode variant with type_name: \"",
164                               v.TypeName(), "\".  Perhaps you forgot to ",
165                               "register a decoder via ",
166                               "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
167     }
168 
169     // Read the checksum.
170     uint32 checksum = 0;
171     size_t unused_bytes_read = 0;
172     TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
173         sizeof(uint32), reinterpret_cast<char*>(&checksum),
174         &unused_bytes_read));
175     if (crc32c::Unmask(checksum) != *actual_crc32c) {
176       return errors::DataLoss(
177           "The checksum after Variant ", i, " does not match.",
178           " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
179           " Actual: ", strings::Printf("%08u", *actual_crc32c));
180     }
181     *actual_crc32c = crc32c::Extend(
182         *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));
183 
184     ret->flat<Variant>()(i) = std::move(v);
185   }
186 
187   return Status::OK();
188 }
189 
GetBackingBuffer(const Tensor & val)190 char* GetBackingBuffer(const Tensor& val) {
191   CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
192   return const_cast<char*>(val.tensor_data().data());
193 }
194 
GetStringBackingBuffer(const Tensor & val)195 string* GetStringBackingBuffer(const Tensor& val) {
196   CHECK_EQ(DT_STRING, val.dtype());
197   return const_cast<string*>(val.flat<string>().data());
198 }
199 
ParseEntryProto(StringPiece key,StringPiece value,protobuf::MessageLite * out)200 Status ParseEntryProto(StringPiece key, StringPiece value,
201                        protobuf::MessageLite* out) {
202   if (!out->ParseFromArray(value.data(), value.size())) {
203     return errors::DataLoss("Entry for key ", key, " not parseable.");
204   }
205   return Status::OK();
206 }
207 
208 // Serializes the data bytes of the non-string tensor "val".  Discards the
209 // original content of "bytes_written", and on OK updates it with number of
210 // bytes written.
211 // REQUIRES: val.dtype() != DT_STRING
WriteTensor(const Tensor & val,FileOutputBuffer * out,size_t * bytes_written)212 Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
213                    size_t* bytes_written) {
214   DCHECK_NE(val.dtype(), DT_STRING);
215   DCHECK_NE(val.dtype(), DT_VARIANT);
216   *bytes_written = val.TotalBytes();
217   char* buf = GetBackingBuffer(val);
218   VLOG(1) << "Appending " << *bytes_written << " bytes to file";
219   return out->Append(StringPiece(buf, *bytes_written));
220 }
221 
222 // Serializes string tensor "val".  "bytes_written" is treated in the same
223 // fashion as WriteTensor().
224 //
225 // Checksums all bytes written and stores it into "crc32c".
226 // REQUIRES: val.dtype() == DT_STRING
WriteStringTensor(const Tensor & val,FileOutputBuffer * out,size_t * bytes_written,uint32 * crc32c)227 Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
228                          size_t* bytes_written, uint32* crc32c) {
229   // On-disk format:
230   //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
231   // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
232   // the length-checksum, and all the string bytes.
233   DCHECK_EQ(val.dtype(), DT_STRING);
234   const string* strings = GetStringBackingBuffer(val);
235 
236   // Writes the varint lengths.
237   string lengths;
238   lengths.reserve(val.NumElements());  // At least 1 byte per element.
239   *crc32c = 0;
240   for (int64 i = 0; i < val.NumElements(); ++i) {
241     const string* elem = &strings[i];
242     DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
243     const uint64 elem_size = static_cast<uint64>(elem->size());
244 
245     core::PutVarint64(&lengths, elem_size);
246     if (elem_size <= UINT32_MAX) {
247       // We need to do this because older checkpoints only used uint32s and we
248       // should still support them.
249       const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
250       *crc32c = crc32c::Extend(*crc32c,
251                                reinterpret_cast<const char*>(&elem_size_uint32),
252                                sizeof(uint32));
253     } else {
254       *crc32c = crc32c::Extend(
255           *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
256     }
257   }
258   TF_RETURN_IF_ERROR(out->Append(lengths));
259   *bytes_written = lengths.size();
260 
261   // Writes the length checksum.
262   const uint32 length_checksum = crc32c::Mask(*crc32c);
263   TF_RETURN_IF_ERROR(out->Append(StringPiece(
264       reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
265   *crc32c = crc32c::Extend(
266       *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
267   *bytes_written += sizeof(uint32);
268 
269   // Writes all the string bytes out.
270   for (int64 i = 0; i < val.NumElements(); ++i) {
271     const string* string = &strings[i];
272     TF_RETURN_IF_ERROR(out->Append(*string));
273     *bytes_written += string->size();
274     *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
275   }
276   return Status::OK();
277 }
278 
WriteVariantTensor(const Tensor & val,FileOutputBuffer * out,size_t * bytes_written,uint32 * crc32c)279 Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
280                           size_t* bytes_written, uint32* crc32c) {
281   // On-disk format:
282   //   [varint64 len1][bytes variant1][4 byte checksum]
283   //   ..
284   //   [varint64 lenN][bytes variantN][4 byte checksum]
285   // Var "crc32c" checksums all the lens, variant bytes, individual variant
286   // checksums (as uint32, not varint32 bytes).
287   DCHECK_EQ(val.dtype(), DT_VARIANT);
288 
289   *crc32c = 0;
290   *bytes_written = 0;
291   for (int64 i = 0; i < val.NumElements(); ++i) {
292     VariantTensorData data;
293     val.flat<Variant>()(i).Encode(&data);
294     VariantTensorDataProto proto;
295     data.ToProto(&proto);
296     string elem;
297     proto.SerializeToString(&elem);
298 
299     // Write the length of the serialized variant.
300     DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
301     const auto elem_size = static_cast<uint64>(elem.size());
302     string len;
303     core::PutVarint64(&len, elem_size);
304     TF_RETURN_IF_ERROR(out->Append(len));
305     *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
306                              sizeof(uint64));
307     *bytes_written += len.size();
308 
309     // Write the serialized variant.
310     TF_RETURN_IF_ERROR(out->Append(elem));
311     *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
312     *bytes_written += elem.size();
313 
314     // Write the checksum.
315     const uint32 length_checksum = crc32c::Mask(*crc32c);
316     TF_RETURN_IF_ERROR(out->Append(StringPiece(
317         reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
318     *crc32c =
319         crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
320                        sizeof(uint32));
321     *bytes_written += sizeof(uint32);
322   }
323 
324   return Status::OK();
325 }
326 
327 // Returns whether "slice_spec" is a full slice, with respect to the full shape.
328 //
329 // This can happen say, when "slice_spec" is
330 // "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
331 // dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
IsFullSlice(const TensorSlice & slice_spec,const TensorShape & full_tensor_shape)332 bool IsFullSlice(const TensorSlice& slice_spec,
333                  const TensorShape& full_tensor_shape) {
334   if (slice_spec.IsFull()) {
335     return true;
336   } else {
337     TensorShape sliced_shape;
338     slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
339     return sliced_shape == full_tensor_shape;
340   }
341 }
342 
// Builds an error status explaining that "filename" could not be read, with
// "detail" as the specific reason.  If "in_status" is an error, its code is
// kept and its message is appended; otherwise the result is an Internal error.
Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  const string explanation =
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, ")");
  if (in_status.ok()) {
    return errors::Internal(explanation);
  }
  return Status(in_status.code(),
                strings::StrCat(explanation, ": ", in_status.error_message()));
}
359 
TableBuilderOptions()360 table::Options TableBuilderOptions() {
361   table::Options o;
362   // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
363   // To smoothen the transition, compressed writes are disabled for now
364   // (version 1.2) with the intention that they will be enabled again at
365   // some point (perhaps the 1.3 release?).
366   o.compression = table::kNoCompression;
367   return o;
368 }
369 
370 // Writes zeros to output buffer to align the next write to the requested
371 // alignment. "size" is the current size of the buffer and is updated to the
372 // new size.
PadAlignment(FileOutputBuffer * out,int alignment,int64 * size)373 Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
374   int bytes_over = *size % alignment;
375   if (bytes_over == 0) {
376     return Status::OK();
377   }
378   int bytes_to_write = alignment - bytes_over;
379   Status status = out->Append(string(bytes_to_write, '\0'));
380   if (status.ok()) {
381     *size += bytes_to_write;
382   }
383   return status;
384 }
385 
386 }  // namespace
387 
BundleWriter(Env * env,StringPiece prefix,const Options & options)388 BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
389     : env_(env),
390       options_(options),
391       prefix_(prefix),
392       tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
393                                          random::New64())),
394       tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
395                                      random::New64())),
396       out_(nullptr),
397       size_(0) {
398   status_ = env_->CreateDir(string(io::Dirname(prefix_)));
399   if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
400     return;
401   }
402   const string filename = DataFilename(prefix_, 0, 1);
403   std::unique_ptr<WritableFile> wrapper;
404   status_ = env_->NewWritableFile(tmp_data_path_, &wrapper);
405   if (!status_.ok()) return;
406   out_ = std::unique_ptr<FileOutputBuffer>(
407       new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));
408 
409   VLOG(1) << "Writing to file " << tmp_data_path_;
410 }
411 
Add(StringPiece key,const Tensor & val)412 Status BundleWriter::Add(StringPiece key, const Tensor& val) {
413   if (!status_.ok()) return status_;
414   CHECK_NE(key, kHeaderEntryKey);
415   const string key_string(key);
416   if (entries_.find(key_string) != entries_.end()) {
417     status_ = errors::InvalidArgument("Adding duplicate key: ", key);
418     return status_;
419   }
420 
421   BundleEntryProto* entry = &entries_[key_string];
422   entry->set_dtype(val.dtype());
423   val.shape().AsProto(entry->mutable_shape());
424   entry->set_shard_id(0);
425   entry->set_offset(size_);
426 
427   // Updates the data file.
428   size_t data_bytes_written = 0;
429   uint32 crc32c = 0;
430   out_->clear_crc32c();
431   if (val.dtype() == DT_STRING) {
432     status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
433   } else if (val.dtype() == DT_VARIANT) {
434     status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
435   } else {
436     status_ = WriteTensor(val, out_.get(), &data_bytes_written);
437     crc32c = out_->crc32c();
438   }
439 
440   if (status_.ok()) {
441     entry->set_size(data_bytes_written);
442     entry->set_crc32c(crc32c::Mask(crc32c));
443     size_ += data_bytes_written;
444     status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
445   }
446   return status_;
447 }
448 
AddSlice(StringPiece full_tensor_key,const TensorShape & full_tensor_shape,const TensorSlice & slice_spec,const Tensor & slice_tensor)449 Status BundleWriter::AddSlice(StringPiece full_tensor_key,
450                               const TensorShape& full_tensor_shape,
451                               const TensorSlice& slice_spec,
452                               const Tensor& slice_tensor) {
453   if (!status_.ok()) return status_;
454   CHECK_NE(full_tensor_key, kHeaderEntryKey);
455 
456   // If just a singleton full slice, use the regular Add() to be more efficient.
457   if (IsFullSlice(slice_spec, full_tensor_shape)) {
458     return Add(full_tensor_key, slice_tensor);
459   }
460 
461   // Inserts/updates the full tensor's metadata entry.
462   //
463   // In the case of a sharded save, MergeBundles() is responsible for merging
464   // the "slices" field of multiple metadata entries corresponding to the same
465   // full tensor.
466   const string full_tensor_key_string(full_tensor_key);
467   BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
468   if (full_entry->dtype() != DT_INVALID) {
469     CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
470   }
471   if (full_entry->has_shape()) {
472     CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
473   }
474 
475   // Populates dtype, shape, and slices.  Intentionally leaving out shard_id and
476   // offset, which do not make sense for this full tensor entry.
477   full_entry->set_dtype(slice_tensor.dtype());
478   full_tensor_shape.AsProto(full_entry->mutable_shape());
479   TensorSliceProto* slice_proto = full_entry->add_slices();
480   slice_spec.AsProto(slice_proto);
481 
482   // The slice itself is handled by a regular Add(), which includes adding its
483   // own metadata entry, and writing out the slice's values.
484   const string slice_name =
485       checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
486   status_ = Add(slice_name, slice_tensor);
487   return status_;
488 }
489 
490 // TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
491 // the orphaned data file.
Finish()492 Status BundleWriter::Finish() {
493   if (out_) {
494     status_.Update(out_->Close());
495     out_ = nullptr;
496     if (status_.ok()) {
497       status_ = Env::Default()->RenameFile(tmp_data_path_,
498                                            DataFilename(prefix_, 0, 1));
499     } else {
500       Env::Default()->DeleteFile(tmp_data_path_).IgnoreError();
501     }
502   }
503   if (!status_.ok()) return status_;
504   // Build key -> BundleEntryProto table.
505   std::unique_ptr<WritableFile> file;
506   status_ = env_->NewWritableFile(tmp_metadata_path_, &file);
507   if (!status_.ok()) return status_;
508   {
509     // N.B.: the default use of Snappy compression may not be supported on all
510     // platforms (e.g. Android).  The metadata file is small, so this is fine.
511     table::Options options;
512     options.compression = table::kNoCompression;
513     table::TableBuilder builder(options, file.get());
514     // Header entry.
515     BundleHeaderProto header;
516     header.set_num_shards(1);
517     header.set_endianness(BundleHeaderProto::LITTLE);
518     if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
519     VersionDef* version = header.mutable_version();
520     version->set_producer(kTensorBundleVersion);
521     version->set_min_consumer(kTensorBundleMinConsumer);
522 
523     builder.Add(kHeaderEntryKey, header.SerializeAsString());
524 
525     // All others.
526     for (const auto& p : entries_) {
527       builder.Add(p.first, p.second.SerializeAsString());
528     }
529     status_ = builder.Finish();
530   }
531   status_.Update(file->Close());
532   if (!status_.ok()) {
533     Env::Default()->DeleteFile(tmp_metadata_path_).IgnoreError();
534     return status_;
535   } else {
536     status_ =
537         Env::Default()->RenameFile(tmp_metadata_path_, MetaFilename(prefix_));
538     if (!status_.ok()) return status_;
539   }
540   status_ = errors::Internal("BundleWriter is closed");
541   return Status::OK();
542 }
543 
544 // Merging tensor bundles.
545 
546 // Accumulator of metadata states during a merge.
547 struct MergeState {
548   // Accumulated from the header entries.
549   int num_shards = 0;
550 
551   // Derives "endianness" and "version" from the first bundle merged (hence the
552   // "seen_first_bundle" guard).  The two fields must be the same for all
553   // bundles in a merge.
554   bool seen_first_bundle = false;
555   BundleHeaderProto_Endianness endianness;
556   VersionDef version;
557 
558   // Tensor key -> BundleEntryProto.
559   std::map<string, BundleEntryProto> entries;
560   // Data file path -> new shard id in the final merged bundle.
561   std::unordered_map<string, int32> shard_ids;
562 };
563 
564 // Merges entries of "prefix" into the accumulator state "merge".
565 // Returns OK iff the merge succeeds.
MergeOneBundle(Env * env,StringPiece prefix,MergeState * merge_state)566 static Status MergeOneBundle(Env* env, StringPiece prefix,
567                              MergeState* merge_state) {
568   VLOG(1) << "Merging bundle:" << prefix;
569   const string filename = MetaFilename(prefix);
570   uint64 file_size;
571   TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
572   std::unique_ptr<RandomAccessFile> file;
573   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
574 
575   table::Table* table = nullptr;
576   TF_RETURN_IF_ERROR(
577       table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
578   std::unique_ptr<table::Table> table_deleter(table);
579   std::unique_ptr<table::Iterator> iter(table->NewIterator());
580 
581   int num_shards;
582   // Process header.
583   {
584     iter->Seek(kHeaderEntryKey);
585     if (!iter->Valid()) {
586       return CorruptFileError(iter->status(), filename,
587                               "failed to seek to header entry");
588     }
589     BundleHeaderProto header;
590     Status s = ParseEntryProto(iter->key(), iter->value(), &header);
591     if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");
592 
593     merge_state->num_shards += header.num_shards();
594     if (!merge_state->seen_first_bundle) {
595       merge_state->seen_first_bundle = true;
596       merge_state->endianness = header.endianness();
597       merge_state->version = header.version();
598     } else {
599       // Validates "endianness".
600       if (merge_state->endianness != header.endianness()) {
601         return errors::InvalidArgument(
602             "Merging bundles with conflicting endianness; inputs corrupted?");
603       }
604       // Validates "version".
605       string curr_version, merge_version;
606       header.version().SerializeToString(&curr_version);
607       merge_state->version.SerializeToString(&merge_version);
608       if (curr_version != merge_version) {
609         return errors::InvalidArgument(
610             "Merging bundles with different format versions: merged ",
611             merge_version, " vs. curr ", curr_version);
612       }
613     }
614     num_shards = header.num_shards();
615     iter->Next();
616   }
617 
618   // Loops through the non-header to-merge entries.
619   BundleEntryProto to_merge_entry;
620   for (; iter->Valid(); iter->Next()) {
621     const string key(iter->key());
622     const auto entry_iter = merge_state->entries.find(key);
623 
624     // Illegal: the duplicated entry is a non-slice tensor.
625     if (entry_iter != merge_state->entries.end() &&
626         entry_iter->second.slices().empty()) {
627       return errors::InvalidArgument(
628           "Duplicate tensor keyed by ", key,
629           " encountered, when merging prefix: ", prefix);
630     }
631 
632     TF_RETURN_IF_ERROR(
633         ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));
634 
635     // The duplicated entry holds metadata for a sliced full tensor.
636     // Allows the duplication and merges "slices".
637     if (entry_iter != merge_state->entries.end()) {
638       BundleEntryProto& existing_entry = entry_iter->second;
639       if (to_merge_entry.slices().empty()) {
640         return errors::Internal(
641             "Duplicate tensor keyed by ", key,
642             "; attempting to merge in a non-slice bundle entry");
643       }
644       // Only needs merge the "slices" field (and validate dtype/shape).
645       for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
646         TensorSliceProto* slot = existing_entry.add_slices();
647         *slot = to_merge_entry.slices(i);
648       }
649       CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
650       CHECK(TensorShape(existing_entry.shape()) ==
651             TensorShape(to_merge_entry.shape()));
652       continue;
653     }
654 
655     // Key doesn't duplicate: a fresh tensor/slice entry.
656     auto result = merge_state->shard_ids.insert(
657         {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
658          merge_state->shard_ids.size()});
659     to_merge_entry.set_shard_id(result.first->second);
660     merge_state->entries[key] = to_merge_entry;
661   }
662   return Status::OK();
663 }
664 
MergeBundles(Env * env,gtl::ArraySlice<string> prefixes,StringPiece merged_prefix)665 Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
666                     StringPiece merged_prefix) {
667   // Merges all metadata tables.
668   // TODO(zhifengc): KeyValue sorter if it becomes too big.
669   MergeState merge;
670   Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
671   if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
672   for (int i = 0; i < prefixes.size(); ++i) {
673     TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
674   }
675 
676   // Renames data files to contain the merged bundle prefix.
677   for (const auto& p : merge.shard_ids) {
678     VLOG(1) << "Renaming " << p.first << " to "
679             << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
680     TF_RETURN_IF_ERROR(env->RenameFile(
681         p.first,
682         DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
683   }
684 
685   // Writes the final metadata table under the merged prefix.
686   std::unique_ptr<WritableFile> merged_metadata;
687   TF_RETURN_IF_ERROR(
688       env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
689   {
690     table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
691     // Header entry.
692     BundleHeaderProto header;
693     header.set_num_shards(merge.num_shards);
694     header.set_endianness(merge.endianness);
695     *header.mutable_version() = merge.version;
696     builder.Add(kHeaderEntryKey, header.SerializeAsString());
697     // All others.
698     for (const auto& p : merge.entries) {
699       builder.Add(p.first, p.second.SerializeAsString());
700     }
701     status = builder.Finish();
702   }
703   status.Update(merged_metadata->Close());
704   if (!status.ok()) return status;
705   VLOG(1) << "Merged bundles to:" << merged_prefix;
706 
707   // Cleanup: best effort based and ignores errors.
708   for (const string& prefix : prefixes) {
709     env->DeleteFile(MetaFilename(prefix)).IgnoreError();
710   }
711   return status;
712 }
713 
714 // Interface for reading a tensor bundle.
715 
BundleReader(Env * env,StringPiece prefix)716 BundleReader::BundleReader(Env* env, StringPiece prefix)
717     : env_(env),
718       prefix_(prefix),
719       metadata_(nullptr),
720       table_(nullptr),
721       iter_(nullptr) {
722   const string filename = MetaFilename(prefix_);
723   uint64 file_size;
724   status_ = env_->GetFileSize(filename, &file_size);
725   if (!status_.ok()) return;
726 
727   // Opens the metadata table.
728   std::unique_ptr<RandomAccessFile> wrapper;
729   status_ = env_->NewRandomAccessFile(filename, &wrapper);
730   if (!status_.ok()) return;
731   metadata_ = wrapper.release();
732   status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_);
733   if (!status_.ok()) return;
734   iter_ = table_->NewIterator();
735 
736   // Reads "num_shards_" from the first entry.
737   iter_->Seek(kHeaderEntryKey);
738   if (!iter_->Valid()) {
739     status_ = CorruptFileError(iter_->status(), filename,
740                                "failed to seek to header entry");
741     return;
742   }
743   BundleHeaderProto header;
744   status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
745   if (!status_.ok()) {
746     status_ = CorruptFileError(status_, filename, "unable to parse header");
747     return;
748   }
749   num_shards_ = header.num_shards();
750   if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
751       (header.endianness() == BundleHeaderProto::LITTLE &&
752        !port::kLittleEndian)) {
753     status_ = errors::Unimplemented(
754         "Reading a bundle with different endianness from the reader");
755     return;
756   }
757   status_ = CheckVersions(header.version(), kTensorBundleVersion,
758                           kTensorBundleMinProducer, "Checkpoint", "checkpoint");
759 }
760 
~BundleReader()761 BundleReader::~BundleReader() {
762   delete metadata_;
763   delete iter_;
764   delete table_;
765   // InputBuffer does not own the underlying RandomAccessFile.
766   for (auto pair : data_) {
767     if (pair.second != nullptr && pair.second->file() != nullptr) {
768       delete pair.second->file();
769     }
770   }
771   gtl::STLDeleteValues(&data_);
772   gtl::STLDeleteValues(&tensor_slices_);
773 }
774 
GetBundleEntryProto(StringPiece key,BundleEntryProto * entry)775 Status BundleReader::GetBundleEntryProto(StringPiece key,
776                                          BundleEntryProto* entry) {
777   entry->Clear();
778   TF_CHECK_OK(status_);
779   Seek(key);
780   if (!iter_->Valid() || iter_->key() != key) {
781     return errors::NotFound("Key ", key, " not found in checkpoint");
782   }
783 
784   BundleEntryProto entry_copy;
785   TF_RETURN_IF_ERROR(
786       ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
787   if (!TensorShape::IsValid(entry_copy.shape())) {
788     return errors::DataLoss("Invalid tensor shape: ", key, " ",
789                             ProtoShortDebugString(entry_copy.shape()));
790   }
791 
792   *entry = entry_copy;
793   return Status::OK();
794 }
795 
GetValue(const BundleEntryProto & entry,Tensor * val)796 Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
797   Tensor* ret = val;
798   const TensorShape stored_shape(TensorShape(entry.shape()));
799   if (val->NumElements() == 0) {
800     ret = new Tensor(entry.dtype(), stored_shape);
801   }
802 
803   // Validates the "size" field.
804   if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
805     if (entry.size() != ret->TotalBytes()) {
806       return errors::DataLoss("Invalid size in bundle entry: key ", key(),
807                               "; stored size ", entry.size(),
808                               "; expected size ", ret->TotalBytes());
809     }
810   } else if (entry.dtype() == DT_STRING) {
811     // Relaxes the check for string tensors as follows:
812     //   entry.size() == bytes(varint lengths) + bytes(data)
813     //                >= NumElems + bytes(data), since size bytes(varint) >= 1.
814     //   TotalBytes() == sizeof(string) * NumElems + bytes(data)
815     // Since we don't know bytes(varint lengths), we just check an inequality.
816     const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
817                                sizeof(string) * ret->NumElements();
818     if (entry.size() < lower_bound) {
819       return errors::DataLoss("Invalid size in bundle entry: key ", key(),
820                               "; stored size ", entry.size(),
821                               "; expected size is at least ", lower_bound);
822     }
823   }
824 
825   // Open the data file if it has not been opened.
826   io::InputBuffer* buffered_file = data_[entry.shard_id()];
827   if (buffered_file == nullptr) {
828     std::unique_ptr<RandomAccessFile> file = nullptr;
829     TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
830         DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
831     buffered_file = new io::InputBuffer(file.release(), kBufferSize);
832     // The InputBuffer and RandomAccessFile objects are both released in dtor.
833     data_[entry.shard_id()] = buffered_file;
834   }
835   CHECK(buffered_file != nullptr);
836 
837   TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
838   uint32 actual_crc32c = 0;
839 
840   if (DataTypeCanUseMemcpy(entry.dtype())) {
841     char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
842     size_t unused_bytes_read;
843     if (entry.size() > kBufferSize) {
844       StringPiece sp;
845       TF_RETURN_IF_ERROR(buffered_file->file()->Read(
846           entry.offset(), entry.size(), &sp, backing_buffer));
847       if (sp.data() != backing_buffer) {
848         memmove(backing_buffer, sp.data(), entry.size());
849       }
850     } else {
851       TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
852                                                    &unused_bytes_read));
853     }
854     actual_crc32c = crc32c::Value(backing_buffer, entry.size());
855   } else if (entry.dtype() == DT_VARIANT) {
856     // Relies on io::InputBuffer's buffering, because we issue many neighboring
857     // reads for a single string tensor.
858     TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
859                                          entry.size(), &actual_crc32c));
860   } else {
861     // Relies on io::InputBuffer's buffering, because we issue many neighboring
862     // reads for a single string tensor.
863     TF_RETURN_IF_ERROR(ReadStringTensor(
864         buffered_file, ret->NumElements(), entry.offset(), entry.size(),
865         GetStringBackingBuffer(*ret), &actual_crc32c));
866   }
867   if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
868     return errors::DataLoss(
869         "Checksum does not match: stored ",
870         strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
871         " vs. calculated on the restored bytes ", actual_crc32c);
872   }
873 
874   *val = *ret;
875   if (ret != val) delete ret;
876   return Status::OK();
877 }
878 
Lookup(StringPiece key,Tensor * val)879 Status BundleReader::Lookup(StringPiece key, Tensor* val) {
880   CHECK(val != nullptr);
881   BundleEntryProto entry;
882   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
883 
884   if (entry.slices().empty()) {
885     return GetValue(entry, val);
886   } else {
887     return GetSliceValue(
888         key, entry,
889         /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
890   }
891 }
892 
ReadCurrent(Tensor * val)893 Status BundleReader::ReadCurrent(Tensor* val) {
894   CHECK(val != nullptr);
895   BundleEntryProto entry;
896   TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
897   if (!TensorShape::IsValid(entry.shape())) {
898     return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
899                             ProtoShortDebugString(entry.shape()));
900   }
901 
902   if (entry.slices().empty()) {
903     return GetValue(entry, val);
904   } else {
905     return GetSliceValue(
906         iter_->key(), entry,
907         /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
908   }
909 }
910 
LookupTensorSlices(StringPiece key,std::vector<TensorSlice> * slices)911 Status BundleReader::LookupTensorSlices(StringPiece key,
912                                         std::vector<TensorSlice>* slices) {
913   slices->clear();
914   BundleEntryProto entry;
915   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
916   slices->reserve(entry.slices_size());
917   for (const auto& slice : entry.slices()) {
918     slices->emplace_back(slice);
919   }
920   return Status::OK();
921 }
922 
LookupSlice(StringPiece full_tensor_key,const TensorSlice & slice_spec,Tensor * val)923 Status BundleReader::LookupSlice(StringPiece full_tensor_key,
924                                  const TensorSlice& slice_spec, Tensor* val) {
925   CHECK(val != nullptr);
926   BundleEntryProto entry;
927   TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
928   return GetSliceValue(full_tensor_key, entry, slice_spec, val);
929 }
930 
// Restores the portion of the tensor "full_tensor_key" selected by
// "slice_spec" into "*val", assembling it from the slices stored in the
// bundle.
//
// "full_tensor_entry" is the already-parsed entry for the full tensor.  The
// set of slices known for the key is cached in tensor_slices_ the first time
// the key is seen.
//
// Returns InvalidArgument if the known slices cannot satisfy "slice_spec" or
// if the dtype is not one of the types handled by the copy switch below, and
// propagates errors from the entry lookup and from GetValue().  Note that
// those propagated errors are also stored into status_.
//
// NOTE(review): on the non-optimized path the data is written through
// val->flat<T>(), so "*val" presumably must already be allocated with the
// shape of "slice_spec" -- confirm against callers.
Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  // Filled in by QueryMeta() below with (stored slice, tag) pairs.
  std::vector<std::pair<TensorSlice, string>> details;
  // Owning copy of the key, used for cache lookups and slice-name encoding.
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices.  We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    // Registers every slice listed in the entry (no-op when there are none).
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    // After registration the key must be present in the cache.
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec".  Performs the
  // copies from each.
  // Starts out as the full tensor's entry; replaced per iteration by the entry
  // of each non-full stored slice.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks for the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()?  Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate these
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      // Returns immediately: any remaining entries in "details" are not
      // processed once a stored slice matches the request directly.
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    // General case: reads the stored slice into a temporary tensor, then
    // copies its overlap with "slice_spec" into "*val".
    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    // HANDLE_COPY expands to one switch case per supported element type; a
    // failed copy is a CHECK failure rather than a returned error.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                 \
  case DataTypeToEnum<T>::value:                                       \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
        full_shape, stored_slice, slice_spec,                          \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return Status::OK();
}
1044 
Contains(StringPiece key)1045 bool BundleReader::Contains(StringPiece key) {
1046   Seek(key);
1047   return Valid() && (this->key() == key);
1048 }
1049 
LookupDtypeAndShape(StringPiece key,DataType * dtype,TensorShape * shape)1050 Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
1051                                          TensorShape* shape) {
1052   BundleEntryProto entry;
1053   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
1054   *dtype = entry.dtype();
1055   *shape = TensorShape(entry.shape());
1056   return Status::OK();
1057 }
1058 
LookupTensorShape(StringPiece key,TensorShape * shape)1059 Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
1060   DataType ignored;
1061   return LookupDtypeAndShape(key, &ignored, shape);
1062 }
1063 
DebugString()1064 string BundleReader::DebugString() {
1065   // Format used below emulates that of TensorSliceReader::DebugString().
1066   string shape_str;
1067   BundleEntryProto entry;
1068   Seek(kHeaderEntryKey);
1069   for (Next(); Valid(); Next()) {
1070     CHECK(entry.ParseFromArray(value().data(), value().size()));
1071     if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.
1072 
1073     strings::StrAppend(&shape_str, key(), " (",
1074                        EnumName_DataType(entry.dtype()), ") ",
1075                        TensorShape(entry.shape()).DebugString());
1076     strings::StrAppend(&shape_str, "\n");
1077   }
1078   return shape_str;
1079 }
1080 
~FileOutputBuffer()1081 FileOutputBuffer::~FileOutputBuffer() { delete file_; }
1082 
Append(StringPiece data)1083 Status FileOutputBuffer::Append(StringPiece data) {
1084   // In the below, it is critical to calculate the checksum on the actually
1085   // copied bytes, not the source bytes.  This is because "data" typically
1086   // points to tensor buffers, which may be concurrently written.
1087   if (data.size() + position_ <= buffer_size_) {
1088     // Can fit into the current buffer.
1089     memcpy(&buffer_[position_], data.data(), data.size());
1090     crc32c_ = crc32c::Extend(crc32c_, &buffer_[position_], data.size());
1091   } else if (data.size() <= buffer_size_) {
1092     // Cannot fit, but can fit after flushing.
1093     TF_RETURN_IF_ERROR(FlushBuffer());
1094     memcpy(&buffer_[0], data.data(), data.size());
1095     crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], data.size());
1096   } else {
1097     // Cannot fit even after flushing.  So we break down "data" by chunk, and
1098     // flush/checksum each chunk.
1099     TF_RETURN_IF_ERROR(FlushBuffer());
1100     for (size_t i = 0; i < data.size(); i += buffer_size_) {
1101       const size_t nbytes = std::min(data.size() - i, buffer_size_);
1102       memcpy(&buffer_[0], data.data() + i, nbytes);
1103       crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], nbytes);
1104       position_ = nbytes;
1105       TF_RETURN_IF_ERROR(FlushBuffer());
1106     }
1107     return Status::OK();
1108   }
1109   position_ += data.size();
1110   return Status::OK();
1111 }
1112 
Close()1113 Status FileOutputBuffer::Close() {
1114   TF_RETURN_IF_ERROR(FlushBuffer());
1115   return file_->Close();
1116 }
1117 
FlushBuffer()1118 Status FileOutputBuffer::FlushBuffer() {
1119   if (position_ > 0) {
1120     TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], position_)));
1121     position_ = 0;
1122   }
1123   return Status::OK();
1124 }
1125 
1126 }  // namespace tensorflow
1127