1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
17 #define TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
18 
19 #include "tensorflow/core/util/tensor_slice_reader.h"
20 #include "tensorflow/core/util/tensor_slice_writer.h"
21 
22 namespace tensorflow {
23 
24 class OpKernelContext;
25 
// Legacy / V1 checkpoint format.

// Saves the input tensors of *context to a writer built from builder_func().
//
// context must have the following inputs:
//  0: a single-element string tensor that contains the file name.
//  1: names for the remaining tensors.
// If save_slices is true:
//  2: shape and slice specifications.
// rest: the tensors to save. They follow the inputs listed above, i.e. they
//   start at input 2, or at input 3 when save_slices is true.
//
// NOTE(review): presumably the i-th name in input 1 (and, when save_slices is
// true, the i-th spec in input 2) labels the i-th tensor to save; and since
// this returns void, errors are presumably reported through `context` —
// confirm both against the implementation.
void SaveTensors(
    OpKernelContext* context,
    checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
    bool save_slices);
39 
// Reads a single tensor from the reader built from open_func() and produces
// it as context->output(restore_index). "preferred_shard" has the same meaning
// as the preferred_shard parameter of TensorSliceReader.
//
// context must have the following inputs:
//  0: a single-element string tensor that contains the file name.
//  1: string tensor that names the outputs to be restored.
// If restore_slice is true:
//  2: shape and slice specifications of the tensors to restore.
//
// restore_index selects which entry of input 1 (the tensor name) and, when
// restore_slice is true, which entry of input 2 (the shape and slice
// specification) to look up. It is also the index of the output that
// receives the restored tensor.
void RestoreTensor(OpKernelContext* context,
                   checkpoint::TensorSliceReader::OpenTableFunction open_func,
                   int preferred_shard, bool restore_slice, int restore_index);
55 
// V2 checkpoint format.

// Invokes the V2 checkpoint read path to read tensors.
//
// "context" is only used for allocating outputs. In particular, the inputs are
// passed explicitly as arguments and are not accessed via context's
// "input(i)" methods.
//
// REQUIRES:
//   * "prefix" has 1 element, DT_STRING.
//   * "tensor_names" and "shape_and_slices" are shaped {N}, both DT_STRING.
//   * "dtypes" has N elements: the datatypes of the tensors to restore.
//
// Returns the outcome as a Status (OK on success).
// NOTE(review): presumably the i-th restored tensor is allocated as
// context->output(i) — confirm against the implementation.
Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
                        const Tensor& tensor_names,
                        const Tensor& shape_and_slices,
                        gtl::ArraySlice<DataType> dtypes);
70 
71 }  // namespace tensorflow
72 
73 #endif  // TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
74