1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // See docs in ../ops/io_ops.cc.
17
18 #include <string>
19 #include <vector>
20
21 #include "tensorflow/core/framework/bounds_check.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/kernels/save_restore_tensor.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/io/path.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/saved_tensor_slice_util.h"
33 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
34 #include "tensorflow/core/util/tensor_slice_reader.h"
35
36 namespace tensorflow {
37
38 namespace {
39
40 // Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
ValidateInputs(bool is_save_op,OpKernelContext * context,const Tensor & prefix,const Tensor & tensor_names,const Tensor & shape_and_slices)41 void ValidateInputs(bool is_save_op, OpKernelContext* context,
42 const Tensor& prefix, const Tensor& tensor_names,
43 const Tensor& shape_and_slices) {
44 const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices.
45 const int num_tensors = static_cast<int>(tensor_names.NumElements());
46 OP_REQUIRES(
47 context, prefix.NumElements() == 1,
48 errors::InvalidArgument("Input prefix should have a single element, got ",
49 prefix.NumElements(), " instead."));
50 OP_REQUIRES(context,
51 TensorShapeUtils::IsVector(tensor_names.shape()) &&
52 TensorShapeUtils::IsVector(shape_and_slices.shape()),
53 errors::InvalidArgument(
54 "Input tensor_names and shape_and_slices "
55 "should be an 1-D tensors, got ",
56 tensor_names.shape().DebugString(), " and ",
57 shape_and_slices.shape().DebugString(), " instead."));
58 OP_REQUIRES(context,
59 tensor_names.NumElements() == shape_and_slices.NumElements(),
60 errors::InvalidArgument("tensor_names and shape_and_slices "
61 "have different number of elements: ",
62 tensor_names.NumElements(), " vs. ",
63 shape_and_slices.NumElements()));
64 OP_REQUIRES(context,
65 FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
66 std::numeric_limits<int>::max()),
67 errors::InvalidArgument("Too many inputs to the op"));
68 OP_REQUIRES(
69 context, shape_and_slices.NumElements() == num_tensors,
70 errors::InvalidArgument("Expected ", num_tensors,
71 " elements in shapes_and_slices, but got ",
72 context->input(2).NumElements()));
73 if (is_save_op) {
74 OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
75 errors::InvalidArgument(
76 "Got ", num_tensors, " tensor names but ",
77 context->num_inputs() - kFixedInputs, " tensors."));
78 OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
79 errors::InvalidArgument(
80 "Expected a total of ", num_tensors + kFixedInputs,
81 " inputs as input #1 (which is a string "
82 "tensor of saved names) contains ",
83 num_tensors, " names, but received ", context->num_inputs(),
84 " inputs"));
85 }
86 }
87
88 } // namespace
89
90 // Saves a list of named tensors using the tensor bundle library.
91 class SaveV2 : public OpKernel {
92 public:
SaveV2(OpKernelConstruction * context)93 explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}
94
Compute(OpKernelContext * context)95 void Compute(OpKernelContext* context) override {
96 const Tensor& prefix = context->input(0);
97 const Tensor& tensor_names = context->input(1);
98 const Tensor& shape_and_slices = context->input(2);
99 ValidateInputs(true /* is save op */, context, prefix, tensor_names,
100 shape_and_slices);
101
102 const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices.
103 const int num_tensors = static_cast<int>(tensor_names.NumElements());
104 const string& prefix_string = prefix.scalar<string>()();
105 const auto& tensor_names_flat = tensor_names.flat<string>();
106 const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
107
108 BundleWriter writer(Env::Default(), prefix_string);
109 OP_REQUIRES_OK(context, writer.status());
110 VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;
111
112 for (int i = 0; i < num_tensors; ++i) {
113 const string& tensor_name = tensor_names_flat(i);
114 const Tensor& tensor = context->input(i + kFixedInputs);
115
116 if (!shape_and_slices_flat(i).empty()) {
117 const string& shape_spec = shape_and_slices_flat(i);
118 TensorShape shape;
119 TensorSlice slice(tensor.dims());
120 TensorShape slice_shape;
121
122 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
123 shape_spec, &shape, &slice, &slice_shape));
124 OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
125 errors::InvalidArgument("Slice in shape_and_slice "
126 "specification does not match the "
127 "shape of the tensor to save: ",
128 shape_spec, ", tensor: ",
129 tensor.shape().DebugString()));
130
131 OP_REQUIRES_OK(context,
132 writer.AddSlice(tensor_name, shape, slice, tensor));
133 } else {
134 OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
135 }
136 }
137 OP_REQUIRES_OK(context, writer.Finish());
138 }
139 };
140 REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);
141
142 // Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
143 class RestoreV2 : public OpKernel {
144 public:
RestoreV2(OpKernelConstruction * context)145 explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
146 OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
147 }
148
Compute(OpKernelContext * context)149 void Compute(OpKernelContext* context) override {
150 const Tensor& prefix = context->input(0);
151 const Tensor& tensor_names = context->input(1);
152 const Tensor& shape_and_slices = context->input(2);
153 OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
154 errors::InvalidArgument("Got ", tensor_names.NumElements(),
155 " tensor names, but ", dtypes_.size(),
156 " expected dtypes."));
157 ValidateInputs(false /* not save op */, context, prefix, tensor_names,
158 shape_and_slices);
159
160 const string& prefix_string = prefix.scalar<string>()();
161
162 // Intention: we plan to use the RestoreV2 op as a backward-compatible
163 // reader as we upgrade to the V2 format. This allows transparent upgrade.
164 // We here attempt to read a V1 checkpoint, if "prefix_string" does not
165 // refer to a V2 checkpoint.
166 Env* env = Env::Default();
167 std::vector<string> paths;
168 if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
169 paths.empty()) {
170 // Cannot find V2's metadata file, so "prefix_string" does not point to a
171 // V2 checkpoint. Invokes the V1 read path instead.
172 for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
173 RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
174 /* preferred_shard */ -1, /* restore_slice */ true,
175 /* restore_index */ i);
176 if (!context->status().ok()) {
177 return;
178 }
179 }
180 return;
181 }
182 // If found, invokes the V2 reader.
183 OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
184 shape_and_slices, dtypes_));
185 }
186
187 private:
188 // Expected dtypes of the to-restore tensors.
189 std::vector<DataType> dtypes_;
190 };
191 REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);
192
193 // The final step in saving sharded V2 checkpoints: merges metadata files.
194 class MergeV2Checkpoints : public OpKernel {
195 public:
MergeV2Checkpoints(OpKernelConstruction * context)196 explicit MergeV2Checkpoints(OpKernelConstruction* context)
197 : OpKernel(context) {
198 OP_REQUIRES_OK(context,
199 context->GetAttr("delete_old_dirs", &delete_old_dirs_));
200 }
201
Compute(OpKernelContext * context)202 void Compute(OpKernelContext* context) override {
203 const Tensor& checkpoint_prefixes = context->input(0);
204 const Tensor& destination_prefix = context->input(1);
205 OP_REQUIRES(context,
206 TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
207 errors::InvalidArgument(
208 "Input checkpoint_prefixes should be an 1-D tensor, got ",
209 checkpoint_prefixes.shape().DebugString(), " instead."));
210 OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()),
211 errors::InvalidArgument(
212 "Input destination_prefix should be a scalar tensor, got ",
213 destination_prefix.shape().DebugString(), " instead."));
214
215 const gtl::ArraySlice<string> input_prefixes =
216 gtl::ArraySlice<string>(checkpoint_prefixes.flat<string>());
217 Env* env = Env::Default();
218 const string& merged_prefix = destination_prefix.scalar<string>()();
219 OP_REQUIRES_OK(
220 context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
221
222 if (delete_old_dirs_) {
223 const string merged_dir(io::Dirname(merged_prefix));
224 for (const string& input_prefix : input_prefixes) {
225 const string dirname(io::Dirname(input_prefix));
226 if (dirname == merged_dir) continue;
227 Status status = env->DeleteDir(dirname);
228 // For sharded save, only the first delete will go through and all
229 // others will hit NotFound. Use vlog to be less verbose.
230 if (!status.ok()) VLOG(1) << status;
231 }
232 }
233 }
234
235 private:
236 // On merge, whether or not to delete the input (temporary) directories.
237 bool delete_old_dirs_;
238 };
239 REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
240 MergeV2Checkpoints);
241
242 } // namespace tensorflow
243