1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/op_def_builder.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/host_info.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace tensorflow {
33 
DefaultValidator(const OpRegistryInterface & op_registry)34 Status DefaultValidator(const OpRegistryInterface& op_registry) {
35   LOG(WARNING) << "No kernel validator registered with OpRegistry.";
36   return Status::OK();
37 }
38 
39 // OpRegistry -----------------------------------------------------------------
40 
~OpRegistryInterface()41 OpRegistryInterface::~OpRegistryInterface() {}
42 
LookUpOpDef(const string & op_type_name,const OpDef ** op_def) const43 Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
44                                         const OpDef** op_def) const {
45   *op_def = nullptr;
46   const OpRegistrationData* op_reg_data = nullptr;
47   TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
48   *op_def = &op_reg_data->op_def;
49   return Status::OK();
50 }
51 
OpRegistry()52 OpRegistry::OpRegistry()
53     : initialized_(false), op_registry_validator_(DefaultValidator) {}
54 
~OpRegistry()55 OpRegistry::~OpRegistry() {
56   for (const auto& e : registry_) delete e.second;
57 }
58 
Register(const OpRegistrationDataFactory & op_data_factory)59 void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
60   mutex_lock lock(mu_);
61   if (initialized_) {
62     TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
63   } else {
64     deferred_.push_back(op_data_factory);
65   }
66 }
67 
68 namespace {
69 // Helper function that returns Status message for failed LookUp.
OpNotFound(const string & op_type_name)70 Status OpNotFound(const string& op_type_name) {
71   Status status = errors::NotFound(
72       "Op type not registered '", op_type_name, "' in binary running on ",
73       port::Hostname(), ". ",
74       "Make sure the Op and Kernel are registered in the binary running in "
75       "this process. Note that if you are loading a saved graph which used ops "
76       "from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
77       "before importing the graph, as contrib ops are lazily registered when "
78       "the module is first accessed.");
79   VLOG(1) << status.ToString();
80   return status;
81 }
82 }  // namespace
83 
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const84 Status OpRegistry::LookUp(const string& op_type_name,
85                           const OpRegistrationData** op_reg_data) const {
86   if ((*op_reg_data = LookUp(op_type_name))) return Status::OK();
87   return OpNotFound(op_type_name);
88 }
89 
LookUp(const string & op_type_name) const90 const OpRegistrationData* OpRegistry::LookUp(const string& op_type_name) const {
91   {
92     tf_shared_lock l(mu_);
93     if (initialized_) {
94       if (const OpRegistrationData* res =
95               gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
96         return res;
97       }
98     }
99   }
100   return LookUpSlow(op_type_name);
101 }
102 
LookUpSlow(const string & op_type_name) const103 const OpRegistrationData* OpRegistry::LookUpSlow(
104     const string& op_type_name) const {
105   const OpRegistrationData* res = nullptr;
106 
107   bool first_call = false;
108   bool first_unregistered = false;
109   {  // Scope for lock.
110     mutex_lock lock(mu_);
111     first_call = MustCallDeferred();
112     res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
113 
114     static bool unregistered_before = false;
115     first_unregistered = !unregistered_before && (res == nullptr);
116     if (first_unregistered) {
117       unregistered_before = true;
118     }
119     // Note: Can't hold mu_ while calling Export() below.
120   }
121   if (first_call) {
122     TF_QCHECK_OK(op_registry_validator_(*this));
123   }
124   if (res == nullptr) {
125     if (first_unregistered) {
126       OpList op_list;
127       Export(true, &op_list);
128       if (VLOG_IS_ON(3)) {
129         LOG(INFO) << "All registered Ops:";
130         for (const auto& op : op_list.op()) {
131           LOG(INFO) << SummarizeOpDef(op);
132         }
133       }
134     }
135   }
136   return res;
137 }
138 
GetRegisteredOps(std::vector<OpDef> * op_defs)139 void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
140   mutex_lock lock(mu_);
141   MustCallDeferred();
142   for (const auto& p : registry_) {
143     op_defs->push_back(p.second->op_def);
144   }
145 }
146 
GetOpRegistrationData(std::vector<OpRegistrationData> * op_data)147 void OpRegistry::GetOpRegistrationData(
148     std::vector<OpRegistrationData>* op_data) {
149   mutex_lock lock(mu_);
150   MustCallDeferred();
151   for (const auto& p : registry_) {
152     op_data->push_back(*p.second);
153   }
154 }
155 
SetWatcher(const Watcher & watcher)156 Status OpRegistry::SetWatcher(const Watcher& watcher) {
157   mutex_lock lock(mu_);
158   if (watcher_ && watcher) {
159     return errors::AlreadyExists(
160         "Cannot over-write a valid watcher with another.");
161   }
162   watcher_ = watcher;
163   return Status::OK();
164 }
165 
Export(bool include_internal,OpList * ops) const166 void OpRegistry::Export(bool include_internal, OpList* ops) const {
167   mutex_lock lock(mu_);
168   MustCallDeferred();
169 
170   std::vector<std::pair<string, const OpRegistrationData*>> sorted(
171       registry_.begin(), registry_.end());
172   std::sort(sorted.begin(), sorted.end());
173 
174   auto out = ops->mutable_op();
175   out->Clear();
176   out->Reserve(sorted.size());
177 
178   for (const auto& item : sorted) {
179     if (include_internal || !absl::StartsWith(item.first, "_")) {
180       *out->Add() = item.second->op_def;
181     }
182   }
183 }
184 
DeferRegistrations()185 void OpRegistry::DeferRegistrations() {
186   mutex_lock lock(mu_);
187   initialized_ = false;
188 }
189 
ClearDeferredRegistrations()190 void OpRegistry::ClearDeferredRegistrations() {
191   mutex_lock lock(mu_);
192   deferred_.clear();
193 }
194 
ProcessRegistrations() const195 Status OpRegistry::ProcessRegistrations() const {
196   mutex_lock lock(mu_);
197   return CallDeferred();
198 }
199 
DebugString(bool include_internal) const200 string OpRegistry::DebugString(bool include_internal) const {
201   OpList op_list;
202   Export(include_internal, &op_list);
203   string ret;
204   for (const auto& op : op_list.op()) {
205     strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
206   }
207   return ret;
208 }
209 
MustCallDeferred() const210 bool OpRegistry::MustCallDeferred() const {
211   if (initialized_) return false;
212   initialized_ = true;
213   for (size_t i = 0; i < deferred_.size(); ++i) {
214     TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
215   }
216   deferred_.clear();
217   return true;
218 }
219 
CallDeferred() const220 Status OpRegistry::CallDeferred() const {
221   if (initialized_) return Status::OK();
222   initialized_ = true;
223   for (size_t i = 0; i < deferred_.size(); ++i) {
224     Status s = RegisterAlreadyLocked(deferred_[i]);
225     if (!s.ok()) {
226       return s;
227     }
228   }
229   deferred_.clear();
230   return Status::OK();
231 }
232 
RegisterAlreadyLocked(const OpRegistrationDataFactory & op_data_factory) const233 Status OpRegistry::RegisterAlreadyLocked(
234     const OpRegistrationDataFactory& op_data_factory) const {
235   std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
236   Status s = op_data_factory(op_reg_data.get());
237   if (s.ok()) {
238     s = ValidateOpDef(op_reg_data->op_def);
239     if (s.ok() &&
240         !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
241                                  op_reg_data.get())) {
242       s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
243     }
244   }
245   Status watcher_status = s;
246   if (watcher_) {
247     watcher_status = watcher_(s, op_reg_data->op_def);
248   }
249   if (s.ok()) {
250     op_reg_data.release();
251   } else {
252     op_reg_data.reset();
253   }
254   return watcher_status;
255 }
256 
257 // static
Global()258 OpRegistry* OpRegistry::Global() {
259   static OpRegistry* global_op_registry = new OpRegistry;
260   return global_op_registry;
261 }
262 
263 // OpListOpRegistry -----------------------------------------------------------
264 
OpListOpRegistry(const OpList * op_list)265 OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
266   for (const OpDef& op_def : op_list->op()) {
267     auto* op_reg_data = new OpRegistrationData();
268     op_reg_data->op_def = op_def;
269     index_[op_def.name()] = op_reg_data;
270   }
271 }
272 
~OpListOpRegistry()273 OpListOpRegistry::~OpListOpRegistry() {
274   for (const auto& e : index_) delete e.second;
275 }
276 
LookUp(const string & op_type_name) const277 const OpRegistrationData* OpListOpRegistry::LookUp(
278     const string& op_type_name) const {
279   auto iter = index_.find(op_type_name);
280   if (iter == index_.end()) {
281     return nullptr;
282   }
283   return iter->second;
284 }
285 
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const286 Status OpListOpRegistry::LookUp(const string& op_type_name,
287                                 const OpRegistrationData** op_reg_data) const {
288   if ((*op_reg_data = LookUp(op_type_name))) return Status::OK();
289   return OpNotFound(op_type_name);
290 }
291 
292 namespace register_op {
293 
operator ()()294 InitOnStartupMarker OpDefBuilderWrapper::operator()() {
295   OpRegistry::Global()->Register(
296       [builder =
297            std::move(builder_)](OpRegistrationData* op_reg_data) -> Status {
298         return builder.Finalize(op_reg_data);
299       });
300   return {};
301 }
302 
303 }  //  namespace register_op
304 
305 }  // namespace tensorflow
306