1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_
18 
19 #include <functional>
20 #include <unordered_map>
21 
22 #include <vector>
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/framework/op_def_util.h"
25 #include "tensorflow/core/framework/selective_registration.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace tensorflow {
37 
38 // Users that want to look up an OpDef by type name should take an
39 // OpRegistryInterface.  Functions accepting a
40 // (const) OpRegistryInterface* may call LookUp() from multiple threads.
41 class OpRegistryInterface {
42  public:
43   virtual ~OpRegistryInterface();
44 
45   // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
46   // registered under that name, otherwise returns the registered OpDef.
47   // Caller must not delete the returned pointer.
48   virtual Status LookUp(const string& op_type_name,
49                         const OpRegistrationData** op_reg_data) const = 0;
50 
51   // Shorthand for calling LookUp to get the OpDef.
52   Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
53 };
54 
55 // The standard implementation of OpRegistryInterface, along with a
56 // global singleton used for registering ops via the REGISTER
57 // macros below.  Thread-safe.
58 //
59 // Example registration:
60 //   OpRegistry::Global()->Register(
61 //     [](OpRegistrationData* op_reg_data)->Status {
62 //       // Populate *op_reg_data here.
63 //       return Status::OK();
64 //   });
65 class OpRegistry : public OpRegistryInterface {
66  public:
67   typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
68 
69   OpRegistry();
70   ~OpRegistry() override;
71 
72   void Register(const OpRegistrationDataFactory& op_data_factory);
73 
74   Status LookUp(const string& op_type_name,
75                 const OpRegistrationData** op_reg_data) const override;
76 
77   // Fills *ops with all registered OpDefs (except those with names
78   // starting with '_' if include_internal == false) sorted in
79   // ascending alphabetical order.
80   void Export(bool include_internal, OpList* ops) const;
81 
82   // Returns ASCII-format OpList for all registered OpDefs (except
83   // those with names starting with '_' if include_internal == false).
84   string DebugString(bool include_internal) const;
85 
86   // A singleton available at startup.
87   static OpRegistry* Global();
88 
89   // Get all registered ops.
90   void GetRegisteredOps(std::vector<OpDef>* op_defs);
91 
92   // Get all `OpRegistrationData`s.
93   void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
94 
95   // Watcher, a function object.
96   // The watcher, if set by SetWatcher(), is called every time an op is
97   // registered via the Register function. The watcher is passed the Status
98   // obtained from building and adding the OpDef to the registry, and the OpDef
99   // itself if it was successfully built. A watcher returns a Status which is in
100   // turn returned as the final registration status.
101   typedef std::function<Status(const Status&, const OpDef&)> Watcher;
102 
103   // An OpRegistry object has only one watcher. This interface is not thread
104   // safe, as different clients are free to set the watcher any time.
105   // Clients are expected to atomically perform the following sequence of
106   // operations :
107   // SetWatcher(a_watcher);
108   // Register some ops;
109   // op_registry->ProcessRegistrations();
110   // SetWatcher(nullptr);
111   // Returns a non-OK status if a non-null watcher is over-written by another
112   // non-null watcher.
113   Status SetWatcher(const Watcher& watcher);
114 
115   // Process the current list of deferred registrations. Note that calls to
116   // Export, LookUp and DebugString would also implicitly process the deferred
117   // registrations. Returns the status of the first failed op registration or
118   // Status::OK() otherwise.
119   Status ProcessRegistrations() const;
120 
121   // Defer the registrations until a later call to a function that processes
122   // deferred registrations are made. Normally, registrations that happen after
123   // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
124   // immediately. Call this to defer future registrations.
125   void DeferRegistrations();
126 
127   // Clear the registrations that have been deferred.
128   void ClearDeferredRegistrations();
129 
130  private:
131   // Ensures that all the functions in deferred_ get called, their OpDef's
132   // registered, and returns with deferred_ empty.  Returns true the first
133   // time it is called. Prints a fatal log if any op registration fails.
134   bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
135 
136   // Calls the functions in deferred_ and registers their OpDef's
137   // It returns the Status of the first failed op registration or Status::OK()
138   // otherwise.
139   Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
140 
141   // Add 'def' to the registry with additional data 'data'. On failure, or if
142   // there is already an OpDef with that name registered, returns a non-okay
143   // status.
144   Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
145       const EXCLUSIVE_LOCKS_REQUIRED(mu_);
146 
147   Status LookUpSlow(const string& op_type_name,
148                     const OpRegistrationData** op_reg_data) const;
149 
150   mutable mutex mu_;
151   // Functions in deferred_ may only be called with mu_ held.
152   mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
153   // Values are owned.
154   mutable std::unordered_map<string, const OpRegistrationData*> registry_
155       GUARDED_BY(mu_);
156   mutable bool initialized_ GUARDED_BY(mu_);
157 
158   // Registry watcher.
159   mutable Watcher watcher_ GUARDED_BY(mu_);
160 };
161 
162 // An adapter to allow an OpList to be used as an OpRegistryInterface.
163 //
164 // Note that shape inference functions are not passed in to OpListOpRegistry, so
165 // it will return an unusable shape inference function for every op it supports;
166 // therefore, it should only be used in contexts where this is okay.
167 class OpListOpRegistry : public OpRegistryInterface {
168  public:
169   // Does not take ownership of op_list, *op_list must outlive *this.
170   OpListOpRegistry(const OpList* op_list);
171   ~OpListOpRegistry() override;
172   Status LookUp(const string& op_type_name,
173                 const OpRegistrationData** op_reg_data) const override;
174 
175  private:
176   // Values are owned.
177   std::unordered_map<string, const OpRegistrationData*> index_;
178 };
179 
180 // Support for defining the OpDef (specifying the semantics of the Op and how
181 // it should be created) and registering it in the OpRegistry::Global()
182 // registry.  Usage:
183 //
184 // REGISTER_OP("my_op_name")
185 //     .Attr("<name>:<type>")
186 //     .Attr("<name>:<type>=<default>")
187 //     .Input("<name>:<type-expr>")
188 //     .Input("<name>:Ref(<type-expr>)")
189 //     .Output("<name>:<type-expr>")
190 //     .Doc(R"(
191 // <1-line summary>
192 // <rest of the description (potentially many lines)>
193 // <name-of-attr-input-or-output>: <description of name>
194 // <name-of-attr-input-or-output>: <description of name;
195 //   if long, indent the description on subsequent lines>
196 // )");
197 //
198 // Note: .Doc() should be last.
199 // For details, see the OpDefBuilder class in op_def_builder.h.
200 
201 namespace register_op {
202 
203 // OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP
204 // calls. This allows the result of REGISTER_OP to be used in chaining, as in
205 // REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective
206 // registration to turn the entire call-chain into a no-op.
207 template <bool should_register>
208 class OpDefBuilderWrapper;
209 
210 // Template specialization that forwards all calls to the contained builder.
211 template <>
212 class OpDefBuilderWrapper<true> {
213  public:
OpDefBuilderWrapper(const char name[])214   OpDefBuilderWrapper(const char name[]) : builder_(name) {}
Attr(string spec)215   OpDefBuilderWrapper<true>& Attr(string spec) {
216     builder_.Attr(std::move(spec));
217     return *this;
218   }
Input(string spec)219   OpDefBuilderWrapper<true>& Input(string spec) {
220     builder_.Input(std::move(spec));
221     return *this;
222   }
Output(string spec)223   OpDefBuilderWrapper<true>& Output(string spec) {
224     builder_.Output(std::move(spec));
225     return *this;
226   }
SetIsCommutative()227   OpDefBuilderWrapper<true>& SetIsCommutative() {
228     builder_.SetIsCommutative();
229     return *this;
230   }
SetIsAggregate()231   OpDefBuilderWrapper<true>& SetIsAggregate() {
232     builder_.SetIsAggregate();
233     return *this;
234   }
SetIsStateful()235   OpDefBuilderWrapper<true>& SetIsStateful() {
236     builder_.SetIsStateful();
237     return *this;
238   }
SetAllowsUninitializedInput()239   OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
240     builder_.SetAllowsUninitializedInput();
241     return *this;
242   }
Deprecated(int version,string explanation)243   OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
244     builder_.Deprecated(version, std::move(explanation));
245     return *this;
246   }
Doc(string text)247   OpDefBuilderWrapper<true>& Doc(string text) {
248     builder_.Doc(std::move(text));
249     return *this;
250   }
SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))251   OpDefBuilderWrapper<true>& SetShapeFn(
252       Status (*fn)(shape_inference::InferenceContext*)) {
253     builder_.SetShapeFn(fn);
254     return *this;
255   }
builder()256   const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
257 
258  private:
259   mutable ::tensorflow::OpDefBuilder builder_;
260 };
261 
262 // Template specialization that turns all calls into no-ops.
263 template <>
264 class OpDefBuilderWrapper<false> {
265  public:
OpDefBuilderWrapper(const char name[])266   constexpr OpDefBuilderWrapper(const char name[]) {}
Attr(StringPiece spec)267   OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; }
Input(StringPiece spec)268   OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; }
Output(StringPiece spec)269   OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; }
SetIsCommutative()270   OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
SetIsAggregate()271   OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
SetIsStateful()272   OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
SetAllowsUninitializedInput()273   OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
Deprecated(int,StringPiece)274   OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
Doc(StringPiece text)275   OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))276   OpDefBuilderWrapper<false>& SetShapeFn(
277       Status (*fn)(shape_inference::InferenceContext*)) {
278     return *this;
279   }
280 };
281 
282 struct OpDefBuilderReceiver {
283   // To call OpRegistry::Global()->Register(...), used by the
284   // REGISTER_OP macro below.
285   // Note: These are implicitly converting constructors.
286   OpDefBuilderReceiver(
287       const OpDefBuilderWrapper<true>& wrapper);  // NOLINT(runtime/explicit)
OpDefBuilderReceiverOpDefBuilderReceiver288   constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
289   }  // NOLINT(runtime/explicit)
290 };
291 }  // namespace register_op
292 
// REGISTER_OP(name) defines a static OpDefBuilderReceiver initialized from an
// OpDefBuilderWrapper for `name`. SHOULD_REGISTER_OP(name) decides whether the
// registering or the no-op wrapper specialization is used. The variable name
// gets a __COUNTER__ suffix so several registrations in one file do not clash.
//
// The _UNIQ_HELPER level is required: operands of `##` are not macro-expanded,
// so __COUNTER__ must be expanded to a number by an extra macro level before
// it is pasted onto `register_op`.
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_OP_UNIQ_HELPER(counter, name) REGISTER_OP_UNIQ(counter, name)
#define REGISTER_OP_UNIQ(counter, name)                                      \
  static ::tensorflow::register_op::OpDefBuilderReceiver                     \
      register_op##counter TF_ATTRIBUTE_UNUSED =                             \
          ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
              name)>(name)
300 
// REGISTER_SYSTEM_OP(name) works like REGISTER_OP(name), except that it always
// uses OpDefBuilderWrapper<true>. The op is therefore registered
// unconditionally, even when selective registration is used.
// As with REGISTER_OP, the _UNIQ_HELPER level forces __COUNTER__ to expand
// before it is token-pasted into the variable name.
#define REGISTER_SYSTEM_OP(name) \
  REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_SYSTEM_OP_UNIQ_HELPER(counter, name) \
  REGISTER_SYSTEM_OP_UNIQ(counter, name)
#define REGISTER_SYSTEM_OP_UNIQ(counter, name)           \
  static ::tensorflow::register_op::OpDefBuilderReceiver \
      register_op##counter TF_ATTRIBUTE_UNUSED =         \
          ::tensorflow::register_op::OpDefBuilderWrapper<true>(name)
312 
313 }  // namespace tensorflow
314 
315 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_H_
316