1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/kernels/save_restore_tensor.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
37
38 namespace tensorflow {
39
SaveTensors(OpKernelContext * context,checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,bool save_slices)40 void SaveTensors(
41 OpKernelContext* context,
42 checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
43 bool save_slices) {
44 const Tensor& filename_t = context->input(0);
45 {
46 const int64 size = filename_t.NumElements();
47 OP_REQUIRES(
48 context, size == 1,
49 errors::InvalidArgument(
50 "Input 0 (filename) must be a string scalar; got a tensor of ",
51 size, "elements"));
52 }
53
54 // Path, names, and slices if save_slices is true.
55 const int kFixedInputs = save_slices ? 3 : 2;
56 const Tensor& tensor_names_t = context->input(1);
57 OP_REQUIRES(context,
58 FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
59 std::numeric_limits<int>::max()),
60 errors::InvalidArgument("Too many inputs to SaveTensors"));
61 const int N = static_cast<int>(tensor_names_t.NumElements());
62 const string* tensor_shapes_and_slices_ptr = nullptr;
63 if (save_slices) {
64 const Tensor& tensor_shapes_and_slices_t = context->input(2);
65 OP_REQUIRES(
66 context,
67 tensor_shapes_and_slices_t.NumElements() == static_cast<int64>(N),
68 errors::InvalidArgument("Expected ", N,
69 " elements for the tensor "
70 "shapes and slices but got ",
71 tensor_shapes_and_slices_t.NumElements()));
72 tensor_shapes_and_slices_ptr =
73 tensor_shapes_and_slices_t.flat<string>().data();
74 }
75 OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
76 errors::InvalidArgument("Expected totally ", N + kFixedInputs,
77 " inputs as input #1 (which is a string "
78 "tensor of saved names) contains ",
79 N, " names, but received ",
80 context->num_inputs(), " inputs"));
81
82 VLOG(1) << "About to save tensors to file " << filename_t.flat<string>()(0)
83 << "...";
84 checkpoint::TensorSliceWriter writer(filename_t.flat<string>()(0),
85 std::move(builder_func));
86
87 Status s;
88 auto tensor_names_flat = tensor_names_t.flat<string>();
89
90 // Process tensors in sorted name order. This allows us to avoid seeking
91 // during restoration in the common case where we are restoring a full
92 // checkpoint.
93 std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
94 std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
95 std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
96 [&tensor_names_flat](size_t a, size_t b) {
97 return tensor_names_flat(a) < tensor_names_flat(b);
98 });
99
100 for (const size_t i : sorted_name_idx) {
101 const string& name = tensor_names_flat(i);
102 const Tensor& input = context->input(i + kFixedInputs);
103 TensorShape shape(input.shape());
104 TensorSlice slice(input.dims());
105 if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
106 const string& shape_spec = tensor_shapes_and_slices_ptr[i];
107 TensorShape slice_shape;
108 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
109 shape_spec, &shape, &slice, &slice_shape));
110 OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
111 errors::InvalidArgument(
112 "Slice in shape_and_slice "
113 "specification does not match the "
114 "shape of the tensor to save: ",
115 shape_spec, ", tensor: ", input.shape().DebugString()));
116 }
117
118 #define WRITER_ADD(T) \
119 case DataTypeToEnum<T>::value: \
120 s = writer.Add(name, shape, slice, input.flat<T>().data()); \
121 break;
122
123 switch (input.dtype()) {
124 TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
125 default:
126 context->SetStatus(errors::Unimplemented("Saving data type ",
127 DataTypeString(input.dtype()),
128 " not yet supported"));
129 return;
130 }
131 #undef WRITER_ADD
132 if (!s.ok()) {
133 context->SetStatus(s);
134 return;
135 }
136 }
137
138 s = writer.Finish();
139 if (!s.ok()) {
140 context->SetStatus(s);
141 }
142 }
143
RestoreTensor(OpKernelContext * context,checkpoint::TensorSliceReader::OpenTableFunction open_func,int preferred_shard,bool restore_slice,int restore_index)144 void RestoreTensor(OpKernelContext* context,
145 checkpoint::TensorSliceReader::OpenTableFunction open_func,
146 int preferred_shard, bool restore_slice, int restore_index) {
147 const Tensor& file_pattern_t = context->input(0);
148 {
149 const int64 size = file_pattern_t.NumElements();
150 OP_REQUIRES(
151 context, size == 1,
152 errors::InvalidArgument(
153 "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
154 size, "elements"));
155 }
156 const string& file_pattern = file_pattern_t.flat<string>()(0);
157
158 const Tensor& tensor_name_t = context->input(1);
159 const string& tensor_name = tensor_name_t.flat<string>()(restore_index);
160
161 // If we cannot find a cached reader we will allocate our own.
162 std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
163
164 const checkpoint::TensorSliceReader* reader = nullptr;
165
166 if (context->slice_reader_cache()) {
167 reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
168 preferred_shard);
169 }
170 if (!reader) {
171 allocated_reader.reset(new checkpoint::TensorSliceReader(
172 file_pattern, open_func, preferred_shard));
173 reader = allocated_reader.get();
174 }
175 OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
176
177 // Get the shape and type from the save file.
178 DataType type;
179 TensorShape saved_shape;
180 OP_REQUIRES(
181 context, reader->HasTensor(tensor_name, &saved_shape, &type),
182 errors::NotFound("Tensor name \"", tensor_name,
183 "\" not found in checkpoint files ", file_pattern));
184 OP_REQUIRES(
185 context, type == context->expected_output_dtype(restore_index),
186 errors::InvalidArgument("Expected to restore a tensor of type ",
187 DataTypeString(context->expected_output_dtype(0)),
188 ", got a tensor of type ", DataTypeString(type),
189 " instead: tensor_name = ", tensor_name));
190
191 // Shape of the output and slice to load.
192 TensorShape output_shape(saved_shape);
193 TensorSlice slice_to_load(saved_shape.dims());
194 if (restore_slice) {
195 const string& shape_spec = context->input(2).flat<string>()(restore_index);
196 if (!shape_spec.empty()) {
197 TensorShape parsed_shape;
198 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
199 shape_spec, &parsed_shape, &slice_to_load,
200 &output_shape));
201 OP_REQUIRES(
202 context, parsed_shape.IsSameSize(saved_shape),
203 errors::InvalidArgument(
204 "Shape in shape_and_slice spec does not match the shape in the "
205 "save file: ",
206 parsed_shape.DebugString(),
207 ", save file shape: ", saved_shape.DebugString()));
208 }
209 }
210
211 Tensor* t = nullptr;
212 OP_REQUIRES_OK(context,
213 context->allocate_output(restore_index, output_shape, &t));
214
215 if (output_shape.num_elements() == 0) return;
216
217 #define READER_COPY(T) \
218 case DataTypeToEnum<T>::value: \
219 OP_REQUIRES(context, \
220 reader->CopySliceData(tensor_name, slice_to_load, \
221 t->flat<T>().data()), \
222 errors::InvalidArgument("Error copying slice data")); \
223 break;
224
225 switch (type) {
226 TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
227 default:
228 context->SetStatus(errors::Unimplemented(
229 "Restoring data type ", DataTypeString(type), " not yet supported"));
230 }
231 #undef READER_COPY
232 }
233
234 namespace {
235
236 // Tensors larger than this threshold will be restored from a thread-pool.
237 const int64 kLargeShapeThreshold = 16 << 20; // 16M
238
239 // A restore operation for a single tensor. Small tensors may be restored
240 // directly from the op thread to improve read locality. Large tensors can be
241 // restored from a thread pool: this requires creating a separate BundleReader
242 // for each restore.
243 struct RestoreOp {
244 RestoreOp& operator=(const RestoreOp&) = delete;
245
should_run_in_pooltensorflow::__anon4ff9bc700211::RestoreOp246 bool should_run_in_pool(BundleReader* reader) const {
247 TensorShape restored_full_shape;
248
249 // Ignore status here; we'll catch the error later.
250 if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
251 return false;
252 }
253
254 return restored_full_shape.num_elements() > kLargeShapeThreshold;
255 }
256
257 // Run this restore operation using a new BundleReader.
run_with_new_readertensorflow::__anon4ff9bc700211::RestoreOp258 void run_with_new_reader() {
259 BundleReader reader(Env::Default(), reader_prefix);
260 if (!reader.status().ok()) {
261 status = reader.status();
262 return;
263 }
264
265 status = run(&reader);
266 }
267
runtensorflow::__anon4ff9bc700211::RestoreOp268 Status run(BundleReader* reader) {
269 TensorShape restored_full_shape;
270 TF_RETURN_IF_ERROR(
271 reader->LookupTensorShape(tensor_name, &restored_full_shape));
272
273 VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
274 << restored_full_shape.num_elements();
275 Tensor* restored_tensor;
276 if (shape_and_slice.empty()) {
277 // Lookup the full tensor.
278 TF_RETURN_IF_ERROR(
279 context->allocate_output(idx, restored_full_shape, &restored_tensor));
280 TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
281 } else {
282 // Lookup the slice.
283 TensorShape parsed_full_shape;
284 TensorSlice parsed_slice;
285 TensorShape parsed_slice_shape;
286
287 TF_RETURN_IF_ERROR(
288 checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
289 &parsed_slice, &parsed_slice_shape));
290
291 if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
292 return errors::InvalidArgument(
293 "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
294 parsed_full_shape.DebugString(),
295 " does not match the shape stored in checkpoint: ",
296 restored_full_shape.DebugString());
297 }
298 TF_RETURN_IF_ERROR(
299 context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
300 TF_RETURN_IF_ERROR(
301 reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
302 }
303 return Status::OK();
304 }
305
306 OpKernelContext* context;
307 size_t idx;
308 string tensor_name;
309 string shape_and_slice;
310 string reader_prefix;
311
312 ::tensorflow::Status status;
313 };
314
315 } // namespace
316
RestoreTensorsV2(OpKernelContext * context,const Tensor & prefix,const Tensor & tensor_names,const Tensor & shape_and_slices,gtl::ArraySlice<DataType> dtypes)317 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
318 const Tensor& tensor_names,
319 const Tensor& shape_and_slices,
320 gtl::ArraySlice<DataType> dtypes) {
321 const string& prefix_string = prefix.scalar<string>()();
322
323 const auto& tensor_names_flat = tensor_names.flat<string>();
324 const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
325
326 // Sort lookup keys to improve locality when reading multiple tensors.
327 std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
328 std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
329 std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
330 [&tensor_names_flat](size_t a, size_t b) {
331 return tensor_names_flat(a) < tensor_names_flat(b);
332 });
333
334 std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops;
335 std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops;
336
337 BundleReader default_reader(Env::Default(), prefix_string);
338 TF_RETURN_IF_ERROR(default_reader.status());
339
340 std::vector<string> mismatched_errors;
341 for (const size_t i : sorted_name_idx) {
342 TensorShape restored_full_shape;
343 DataType original_dtype;
344 const string& tensor_name = tensor_names_flat(i);
345 TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
346 tensor_name, &original_dtype, &restored_full_shape));
347 if (dtypes[i] != original_dtype) {
348 string error_msg = strings::StrCat(
349 "tensor_name = ", tensor_name, "; expected dtype ",
350 DataTypeString(dtypes[i]), " does not equal original dtype ",
351 DataTypeString(original_dtype));
352 mismatched_errors.emplace_back(error_msg);
353 }
354 }
355 if (!mismatched_errors.empty()) {
356 const string error_msg = str_util::Join(mismatched_errors, "\n");
357 return errors::InvalidArgument(error_msg);
358 }
359
360 for (auto i : sorted_name_idx) {
361 const string& tensor_name = tensor_names_flat(i);
362 const string& shape_and_slice = shape_and_slices_flat(i);
363 auto op =
364 new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string};
365 if (op->should_run_in_pool(&default_reader)) {
366 pool_restore_ops.emplace_back(op);
367 } else {
368 direct_restore_ops.emplace_back(op);
369 }
370 }
371
372 {
373 // Schedule any threaded operations first, skipping thread pool creation if
374 // we don't have any expensive operations.
375 std::unique_ptr<thread::ThreadPool> reader_pool;
376 if (!pool_restore_ops.empty()) {
377 reader_pool.reset(
378 new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
379 for (auto& op : pool_restore_ops) {
380 reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
381 }
382 }
383
384 // Read small tensors from the op thread
385 for (auto& op : direct_restore_ops) {
386 TF_RETURN_IF_ERROR(op->run(&default_reader));
387 }
388 }
389
390 // Check status of pool ops; this must come after the pool shuts down.
391 for (auto& op : pool_restore_ops) {
392 TF_RETURN_IF_ERROR(op->status);
393 }
394
395 for (auto i : sorted_name_idx) {
396 const string& tensor_name = tensor_names_flat(i);
397 if (dtypes[i] != context->mutable_output(i)->dtype()) {
398 return errors::InvalidArgument(
399 "tensor_name = ", tensor_name, "; expected dtype ",
400 DataTypeString(dtypes[i]), " does not equal restored dtype ",
401 DataTypeString(context->mutable_output(i)->dtype()));
402 }
403 }
404
405 return Status::OK();
406 }
407
408 } // namespace tensorflow
409