1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ 17 #define TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ 18 19 #include "tensorflow/core/util/tensor_slice_reader.h" 20 #include "tensorflow/core/util/tensor_slice_writer.h" 21 22 namespace tensorflow { 23 24 class OpKernelContext; 25 26 // Legacy / V1 checkpoint format. 27 28 // Save input tensors in *context to a writer built from builder_func(). 29 // context must have the following inputs: 30 // 0: a single element string tensor that contains the file name. 31 // 1: names for the remaining tensors 32 // If save_slices is true: 33 // 2: shape and slice specifications. 34 // rest: tensors to save 35 void SaveTensors( 36 OpKernelContext* context, 37 checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, 38 bool save_slices); 39 40 // Reads a single tensor from the reader built from open_func() and produces 41 // it as context->output(restore_index). "preferred_shard" is the same the 42 // TensorSliceReader preferred_shard parameter. 43 // 44 // context must have the following inputs: 45 // 0: a single element string tensor that contains the file name. 46 // 1: string tensor that names the outputs to be restored. 47 // If restore_slice is true: 48 // 2: shape and slice specification of the tensors to restore. 49 // 50 // restore_index indicates the variable name and slice to lookup 51 // in context(1) and (2). 52 void RestoreTensor(OpKernelContext* context, 53 checkpoint::TensorSliceReader::OpenTableFunction open_func, 54 int preferred_shard, bool restore_slice, int restore_index); 55 56 // V2 checkpoint format. 57 58 // Invokes the V2 checkpoint read path to read tensors. 59 // 60 // "context" is only used for allocating outputs. In particular, the inputs are 61 // explicitly provided and not accessed via the "input(i)" methods. 62 // REQUIRES: 63 // * "prefix" has 1 element, DT_STRING. 64 // * "tensor_names" and "shape_and_slices" shaped {N}, both DT_STRING. 65 // * "dtypes" has N elements, the datatypes of the to-restore tensors. 66 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, 67 const Tensor& tensor_names, 68 const Tensor& shape_and_slices, 69 gtl::ArraySlice<DataType> dtypes); 70 71 } // namespace tensorflow 72 73 #endif // TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ 74