1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_
18 
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
35 
36 namespace tensorflow {
37 
38 // Users that want to look up an OpDef by type name should take an
39 // OpRegistryInterface.  Functions accepting a
40 // (const) OpRegistryInterface* may call LookUp() from multiple threads.
41 class OpRegistryInterface {
42  public:
43   virtual ~OpRegistryInterface();
44 
45   // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
46   // registered under that name, otherwise returns the registered OpDef.
47   // Caller must not delete the returned pointer.
48   virtual Status LookUp(const std::string& op_type_name,
49                         const OpRegistrationData** op_reg_data) const = 0;
50 
51   // Shorthand for calling LookUp to get the OpDef.
52   Status LookUpOpDef(const std::string& op_type_name,
53                      const OpDef** op_def) const;
54 };
55 
56 // The standard implementation of OpRegistryInterface, along with a
57 // global singleton used for registering ops via the REGISTER
58 // macros below.  Thread-safe.
59 //
60 // Example registration:
61 //   OpRegistry::Global()->Register(
62 //     [](OpRegistrationData* op_reg_data)->Status {
63 //       // Populate *op_reg_data here.
64 //       return Status::OK();
65 //   });
66 class OpRegistry : public OpRegistryInterface {
67  public:
68   typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
69 
70   OpRegistry();
71   ~OpRegistry() override;
72 
73   void Register(const OpRegistrationDataFactory& op_data_factory);
74 
75   Status LookUp(const std::string& op_type_name,
76                 const OpRegistrationData** op_reg_data) const override;
77 
78   // Returns OpRegistrationData* of registered op type, else returns nullptr.
79   const OpRegistrationData* LookUp(const std::string& op_type_name) const;
80 
81   // Fills *ops with all registered OpDefs (except those with names
82   // starting with '_' if include_internal == false) sorted in
83   // ascending alphabetical order.
84   void Export(bool include_internal, OpList* ops) const;
85 
86   // Returns ASCII-format OpList for all registered OpDefs (except
87   // those with names starting with '_' if include_internal == false).
88   std::string DebugString(bool include_internal) const;
89 
90   // A singleton available at startup.
91   static OpRegistry* Global();
92 
93   // Get all registered ops.
94   void GetRegisteredOps(std::vector<OpDef>* op_defs);
95 
96   // Get all `OpRegistrationData`s.
97   void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
98 
99   // Registers a function that validates op registry.
RegisterValidator(std::function<Status (const OpRegistryInterface &)> validator)100   void RegisterValidator(
101       std::function<Status(const OpRegistryInterface&)> validator) {
102     op_registry_validator_ = std::move(validator);
103   }
104 
105   // Watcher, a function object.
106   // The watcher, if set by SetWatcher(), is called every time an op is
107   // registered via the Register function. The watcher is passed the Status
108   // obtained from building and adding the OpDef to the registry, and the OpDef
109   // itself if it was successfully built. A watcher returns a Status which is in
110   // turn returned as the final registration status.
111   typedef std::function<Status(const Status&, const OpDef&)> Watcher;
112 
113   // An OpRegistry object has only one watcher. This interface is not thread
114   // safe, as different clients are free to set the watcher any time.
115   // Clients are expected to atomically perform the following sequence of
116   // operations :
117   // SetWatcher(a_watcher);
118   // Register some ops;
119   // op_registry->ProcessRegistrations();
120   // SetWatcher(nullptr);
121   // Returns a non-OK status if a non-null watcher is over-written by another
122   // non-null watcher.
123   Status SetWatcher(const Watcher& watcher);
124 
125   // Process the current list of deferred registrations. Note that calls to
126   // Export, LookUp and DebugString would also implicitly process the deferred
127   // registrations. Returns the status of the first failed op registration or
128   // Status::OK() otherwise.
129   Status ProcessRegistrations() const;
130 
131   // Defer the registrations until a later call to a function that processes
132   // deferred registrations are made. Normally, registrations that happen after
133   // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
134   // immediately. Call this to defer future registrations.
135   void DeferRegistrations();
136 
137   // Clear the registrations that have been deferred.
138   void ClearDeferredRegistrations();
139 
140  private:
141   // Ensures that all the functions in deferred_ get called, their OpDef's
142   // registered, and returns with deferred_ empty.  Returns true the first
143   // time it is called. Prints a fatal log if any op registration fails.
144   bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
145 
146   // Calls the functions in deferred_ and registers their OpDef's
147   // It returns the Status of the first failed op registration or Status::OK()
148   // otherwise.
149   Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
150 
151   // Add 'def' to the registry with additional data 'data'. On failure, or if
152   // there is already an OpDef with that name registered, returns a non-okay
153   // status.
154   Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
155       const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
156 
157   const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const;
158 
159   mutable mutex mu_;
160   // Functions in deferred_ may only be called with mu_ held.
161   mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_);
162   // Values are owned.
163   mutable std::unordered_map<string, const OpRegistrationData*> registry_
164       TF_GUARDED_BY(mu_);
165   mutable bool initialized_ TF_GUARDED_BY(mu_);
166 
167   // Registry watcher.
168   mutable Watcher watcher_ TF_GUARDED_BY(mu_);
169 
170   std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
171 };
172 
173 // An adapter to allow an OpList to be used as an OpRegistryInterface.
174 //
175 // Note that shape inference functions are not passed in to OpListOpRegistry, so
176 // it will return an unusable shape inference function for every op it supports;
177 // therefore, it should only be used in contexts where this is okay.
178 class OpListOpRegistry : public OpRegistryInterface {
179  public:
180   // Does not take ownership of op_list, *op_list must outlive *this.
181   explicit OpListOpRegistry(const OpList* op_list);
182   ~OpListOpRegistry() override;
183   Status LookUp(const std::string& op_type_name,
184                 const OpRegistrationData** op_reg_data) const override;
185 
186   // Returns OpRegistrationData* of op type in list, else returns nullptr.
187   const OpRegistrationData* LookUp(const std::string& op_type_name) const;
188 
189  private:
190   // Values are owned.
191   std::unordered_map<string, const OpRegistrationData*> index_;
192 };
193 
194 // Support for defining the OpDef (specifying the semantics of the Op and how
195 // it should be created) and registering it in the OpRegistry::Global()
196 // registry.  Usage:
197 //
198 // REGISTER_OP("my_op_name")
199 //     .Attr("<name>:<type>")
200 //     .Attr("<name>:<type>=<default>")
201 //     .Input("<name>:<type-expr>")
202 //     .Input("<name>:Ref(<type-expr>)")
203 //     .Output("<name>:<type-expr>")
204 //     .Doc(R"(
205 // <1-line summary>
206 // <rest of the description (potentially many lines)>
207 // <name-of-attr-input-or-output>: <description of name>
208 // <name-of-attr-input-or-output>: <description of name;
209 //   if long, indent the description on subsequent lines>
210 // )");
211 //
212 // Note: .Doc() should be last.
213 // For details, see the OpDefBuilder class in op_def_builder.h.
214 
215 namespace register_op {
216 
217 class OpDefBuilderWrapper {
218  public:
OpDefBuilderWrapper(const char name[])219   explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
Attr(std::string spec)220   OpDefBuilderWrapper& Attr(std::string spec) {
221     builder_.Attr(std::move(spec));
222     return *this;
223   }
Input(std::string spec)224   OpDefBuilderWrapper& Input(std::string spec) {
225     builder_.Input(std::move(spec));
226     return *this;
227   }
Output(std::string spec)228   OpDefBuilderWrapper& Output(std::string spec) {
229     builder_.Output(std::move(spec));
230     return *this;
231   }
SetIsCommutative()232   OpDefBuilderWrapper& SetIsCommutative() {
233     builder_.SetIsCommutative();
234     return *this;
235   }
SetIsAggregate()236   OpDefBuilderWrapper& SetIsAggregate() {
237     builder_.SetIsAggregate();
238     return *this;
239   }
SetIsStateful()240   OpDefBuilderWrapper& SetIsStateful() {
241     builder_.SetIsStateful();
242     return *this;
243   }
SetDoNotOptimize()244   OpDefBuilderWrapper& SetDoNotOptimize() {
245     // We don't have a separate flag to disable optimizations such as constant
246     // folding and CSE so we reuse the stateful flag.
247     builder_.SetIsStateful();
248     return *this;
249   }
SetAllowsUninitializedInput()250   OpDefBuilderWrapper& SetAllowsUninitializedInput() {
251     builder_.SetAllowsUninitializedInput();
252     return *this;
253   }
Deprecated(int version,std::string explanation)254   OpDefBuilderWrapper& Deprecated(int version, std::string explanation) {
255     builder_.Deprecated(version, std::move(explanation));
256     return *this;
257   }
Doc(std::string text)258   OpDefBuilderWrapper& Doc(std::string text) {
259     builder_.Doc(std::move(text));
260     return *this;
261   }
SetShapeFn(OpShapeInferenceFn fn)262   OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) {
263     builder_.SetShapeFn(std::move(fn));
264     return *this;
265   }
266 
builder()267   const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
268 
269   InitOnStartupMarker operator()();
270 
271  private:
272   mutable ::tensorflow::OpDefBuilder builder_;
273 };
274 
275 }  // namespace register_op
276 
// Shared implementation of REGISTER_OP and REGISTER_SYSTEM_OP.
//
// Expands to the definition of a `static`, TF_ATTRIBUTE_UNUSED
// `::tensorflow::InitOnStartupMarker const` variable named `register_op<ctr>`
// (`ctr` is token-pasted into the name). It is initialized from
//   TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name))
//       << OpDefBuilderWrapper(name)
// The expansion deliberately ends with the wrapper expression and no `;`, so
// the caller can append builder calls such as `.Attr(...)`. Member access
// binds tighter than `<<`, so those calls apply to the wrapper. The caller
// then terminates the statement.
//
// When `is_system_op` is true, the condition is true regardless of
// SHOULD_REGISTER_OP(name).
// NOTE(review): TF_INIT_ON_STARTUP_IF presumably invokes the wrapper only
// when the condition holds, and TF_NEW_ID_FOR_INIT presumably supplies a
// unique `ctr` per expansion -- confirm in their definitions.
// (No comments inside the macro body: a `//` comment would swallow the line
// continuation.)
#define REGISTER_OP_IMPL(ctr, name, is_system_op)                         \
  static ::tensorflow::InitOnStartupMarker const register_op##ctr         \
      TF_ATTRIBUTE_UNUSED =                                               \
          TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \
          << ::tensorflow::register_op::OpDefBuilderWrapper(name)
282 
// Defines and registers the op `name`; see the usage comment above for the
// builder calls that may follow. This passes is_system_op = false to
// REGISTER_OP_IMPL, so whether the op is registered depends on
// SHOULD_REGISTER_OP(name), i.e. it can be dropped by selective
// registration. The expansion is tagged with TF_ATTRIBUTE_ANNOTATE("tf:op").
// NOTE(review): that tag is presumably consumed by external tooling --
// confirm.
#define REGISTER_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)
286 
// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
// that the op is registered unconditionally even when selective
// registration is used.
//
// It passes is_system_op = true to REGISTER_OP_IMPL, which makes the
// registration condition true without consulting SHOULD_REGISTER_OP(name).
// In addition to "tf:op", the expansion also carries the
// TF_ATTRIBUTE_ANNOTATE("tf:op:system") tag.
#define REGISTER_SYSTEM_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op")        \
  TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true)
294 
295 }  // namespace tensorflow
296 
297 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_H_
298