1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/checkpoint_reader.h"
17 
18 #include <unordered_set>
19 #include <utility>
20 
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/status.h"
23 #include "tensorflow/core/platform/stringpiece.h"
24 #include "tensorflow/core/platform/types.h"
25 #include "tensorflow/core/util/saved_tensor_slice_util.h"
26 
27 namespace tensorflow {
28 namespace checkpoint {
29 
30 class TensorSliceReader;
31 
CheckpointReader(const string & filename,TF_Status * status)32 CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
33     : reader_(nullptr),
34       v2_reader_(nullptr),
35       var_to_shape_map_(nullptr),
36       var_to_data_type_map_(nullptr) {
37   // Depending on whether this is a V2 ckpt, initializes "reader_" or
38   // "v2_reader_".
39   std::vector<string> v2_path;
40   if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
41       !v2_path.empty()) {
42     v2_reader_.reset(
43         new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
44     if (!v2_reader_->status().ok()) {
45       Set_TF_Status_from_Status(status, v2_reader_->status());
46       return;
47     }
48     auto result = BuildV2VarMaps();
49     var_to_shape_map_.swap(result.first);
50     var_to_data_type_map_.swap(result.second);
51   } else {
52     reader_.reset(new TensorSliceReader(filename));
53     if (!reader_->status().ok()) {
54       Set_TF_Status_from_Status(status, reader_->status());
55       return;
56     }
57     var_to_shape_map_.reset(
58         new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
59     var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
60         reader_->GetVariableToDataTypeMap()));
61   }
62 }
63 
HasTensor(const string & name) const64 bool CheckpointReader::HasTensor(const string& name) const {
65   if (reader_ != nullptr) {
66     return reader_->HasTensor(name, nullptr, nullptr);
67   }
68   return v2_reader_->Contains(name);
69 }
70 
71 const TensorSliceReader::VarToShapeMap&
GetVariableToShapeMap() const72 CheckpointReader::GetVariableToShapeMap() const {
73   CHECK(var_to_shape_map_);
74   return *var_to_shape_map_;
75 }
76 
77 const TensorSliceReader::VarToDataTypeMap&
GetVariableToDataTypeMap() const78 CheckpointReader::GetVariableToDataTypeMap() const {
79   CHECK(var_to_data_type_map_);
80   return *var_to_data_type_map_;
81 }
82 
DebugString() const83 const string CheckpointReader::DebugString() const {
84   if (reader_ != nullptr) return reader_->DebugString();
85   return v2_reader_->DebugString();
86 }
87 
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor,TF_Status * out_status) const88 void CheckpointReader::GetTensor(
89     const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
90     TF_Status* out_status) const {
91   Status status;
92   if (reader_ != nullptr) {
93     status = reader_->GetTensor(name, out_tensor);
94   } else {
95     tensorflow::DataType dtype;
96     tensorflow::TensorShape shape;
97     status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
98     if (status.ok()) {
99       out_tensor->reset(new Tensor(dtype, shape));
100       status = v2_reader_->Lookup(name, out_tensor->get());
101       if (!status.ok()) out_tensor->reset();
102     }
103   }
104   if (!status.ok()) {
105     Set_TF_Status_from_Status(out_status, status);
106   }
107 }
108 
109 std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
110           std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
BuildV2VarMaps()111 CheckpointReader::BuildV2VarMaps() {
112   CHECK(v2_reader_ != nullptr);
113   CHECK(v2_reader_->status().ok());
114 
115   // First pass: filters out the entries of the slices.
116   std::unordered_set<string> filtered_keys;
117   BundleEntryProto entry;
118   v2_reader_->Seek(kHeaderEntryKey);
119   for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
120     CHECK(entry.ParseFromArray(v2_reader_->value().data(),
121                                v2_reader_->value().size()))
122         << entry.InitializationErrorString();
123     for (int i = 0; i < entry.slices_size(); ++i) {
124       const auto& slice_proto = entry.slices(i);
125       CHECK(filtered_keys
126                 .insert(EncodeTensorNameSlice(
127                     string(v2_reader_->key()) /* full var's name */,
128                     TensorSlice(slice_proto)))
129                 .second);
130     }
131   }
132 
133   // Second pass: adds the entries, ignoring the filtered keys.
134   std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
135       new TensorSliceReader::VarToShapeMap);
136   std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
137       new TensorSliceReader::VarToDataTypeMap);
138   v2_reader_->Seek(kHeaderEntryKey);
139   for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
140     if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
141     CHECK(entry.ParseFromArray(v2_reader_->value().data(),
142                                v2_reader_->value().size()))
143         << entry.InitializationErrorString();
144     string key(v2_reader_->key());
145     (*var_to_shape_map)[key] = TensorShape(entry.shape());
146     (*var_to_data_type_map)[key] = DataType(entry.dtype());
147   }
148   // The returned pointers are owned by the caller.
149   return std::make_pair(std::move(var_to_shape_map),
150                         std::move(var_to_data_type_map));
151 }
152 
153 }  // namespace checkpoint
154 }  // namespace tensorflow
155