1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
18 
#include <array>
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
24 
25 #include "tensorflow/core/common_runtime/device_factory.h"
26 #include "tensorflow/core/common_runtime/local_device.h"
27 #include "tensorflow/core/framework/device_base.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/mem.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/public/session_options.h"
35 
36 namespace tensorflow {
37 
38 // Names of the XLA compilation devices. These are not user-visible, and are
39 // used internally by the Tensorflow/XLA bridge to perform symbolic execution of
40 // a Tensorflow graph.
41 
// Names of the XLA JIT compilation devices (see comment above): symbolic
// devices used internally by the bridge, never exposed to users.
extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

// Names of the XLA devices.
// NOTE(review): unlike the *_XLA_JIT names above, nothing here documents
// whether these are user-visible device names -- confirm against the .cc file.
extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

// Floating-point types supported by XLA.
constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
// All numeric (integer, floating-point and complex) types supported by XLA.
constexpr std::array<DataType, 9> kNumericTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BFLOAT16}};

// Types registered for the CPU backend. Note that, relative to kNumericTypes,
// this list omits DT_HALF and DT_BFLOAT16 and adds DT_BOOL.
constexpr std::array<DataType, 8> kCpuAllTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BOOL}};

// Types registered for the GPU backend; currently identical to kCpuAllTypes.
constexpr std::array<DataType, 8> kGpuAllTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BOOL}};
61 
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
// NOTE(review): all mutable state below is annotated GUARDED_BY(mutex_), which
// suggests the "Not thread-safe" note above is stale -- confirm against the
// implementation before relying on concurrent registration.
class XlaOpRegistry {
 public:
  // Factory function type that builds the OpKernel performing symbolic
  // execution for a registered operator.
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  // Describes how to compile operators assigned to a device.
  struct DeviceRegistration {
    // The name of the XLA compilation device to use to compile code.
    string compilation_device_name;

    // Do operators assigned to this device require compilation?
    bool requires_compilation;

    // If !requires_compilation, should we try to JIT operators on this device
    // when XLA JIT compilation is enabled globally via the SessionOptions?
    // (It is still possible to explicitly mark operators to JIT compile, even
    // if enable_jit_by_default is false.)
    bool enable_jit_by_default;

    // Enable compilation of operators that use DT_RESOURCE types?
    bool compile_resource_ops = false;
  };

  // Registers an XLA backend. `compilation_device_name` is the name of the
  // device used for symbolic execution during compilation. `supported_types`
  // is the list of non-resource types supported by the device. Each operator
  // will be registered for the intersection of the operator's supported types
  // and the device's supported types. `op_filter` is a function used to
  // exclude or modify operator registrations on the device; it may be
  // nullptr, in which case all ops are included.
  // `op_filter` should return true if the op should be registered on the
  // device; it may optionally modify the KernelDef.
  typedef bool (*BackendOpFilter)(KernelDef* kdef);
  static void RegisterBackend(const string& compilation_device_name,
                              gtl::ArraySlice<DataType> supported_types,
                              BackendOpFilter op_filter);

  // Returns the names of the registered backends.
  static std::vector<string> BackendNames();

  // Returns true iff a backend with the given name is registered.
  static bool IsBackendRegistered(const string& name);

  // Registers `device_name` for XLA compilation, using information from
  // `registration`.
  static void RegisterCompilationDevice(const string& device_name,
                                        const DeviceRegistration& registration);

  // Returns, via '*registration', the DeviceRegistration describing how to
  // compile operators assigned to 'device_name'. Returns false and leaves
  // '*registration' unchanged if no matching device is registered.
  static bool GetCompilationDevice(const string& device_name,
                                   const DeviceRegistration** registration);

  // Registers all JIT kernels on JIT devices, if not already registered.
  // Does nothing otherwise.
  static void RegisterCompilationKernels();

  // Returns KernelDefs for compilation ops registered on
  // 'compilation_device_name'. Kernels registered as CompilationOnly are
  // included only when include_compilation_only_kernels is true.
  static std::vector<const KernelDef*> DeviceKernels(
      const string& compilation_device_name,
      bool include_compilation_only_kernels);

  // Returns the set of compile-time constant inputs to 'op'. Returns nullptr
  // if the op is not registered.
  static const std::unordered_set<string>* CompileTimeConstantInputs(
      const string& op);

 private:
  friend class XlaBackendRegistrar;
  friend class XlaOpRegistrar;
  friend class XlaOpRegistrationBuilder;

  // Returns the process-wide singleton registry instance.
  static XlaOpRegistry& Instance();

  XlaOpRegistry();
  ~XlaOpRegistry();

  // Guards all mutable state below.
  mutex mutex_;

  // Describes an XLA backend.
  struct Backend {
    // Which types are supported by this device?
    std::set<DataType> supported_types;

    // The per-backend operator filter function. See the comment on
    // RegisterBackend() for details.
    BackendOpFilter op_filter;

    // KernelDefs built by RegisterCompilationKernels() for each op supported
    // by the device.
    std::vector<std::unique_ptr<KernelDef>> kernel_defs;
  };

  // Map from compilation device names to a description of the backend.
  std::unordered_map<string, Backend> backends_ GUARDED_BY(mutex_);

  // Map from Tensorflow device names to the corresponding JIT device metadata.
  std::unordered_map<string, DeviceRegistration> compilation_devices_
      GUARDED_BY(mutex_);

  // A description of a Tensorflow operator that can be compiled to XLA.
  struct OpRegistration {
    // Operator name, as it appears in the Tensorflow graph.
    string name;

    // Should this operator be registered only on compilation devices, without a
    // dummy kernel registered on the corresponding XLA device?
    bool compilation_only = false;

    // Should we allow resource types for type attributes? Used by _Arg to
    // allow DT_RESOURCE.
    bool allow_resource_types = false;

    // Mapping from attribute name to a list of supported types.
    std::unordered_map<string, std::set<DataType>> type_constraints;

    // An optional whitelist of devices. If there is no whitelist, all devices
    // are permitted.
    bool has_device_whitelist = false;
    std::unordered_set<string> device_whitelist;

    // Names of arguments that must be compile-time constants.
    std::unordered_set<string> compile_time_constant_inputs;

    // Factory used to build OpKernels that perform symbolic execution.
    Factory factory;
  };

  // Returns true if registrations x and y can both be added to the registry.
  // This is always the case if they refer to different ops. If they refer to
  // the same op name, they must: have the same values for compilation_only and
  // allow_resource_types; use a device_whitelist; and their
  // whitelists must not intersect.
  static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);

  // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
  // Registrations present under the same key must satisfy IsCompatible above,
  // and this is checked during registration.
  std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
      GUARDED_BY(mutex_);

  // Have we already registered the JIT kernels on the JIT devices?
  bool jit_kernels_registered_ = false;

  // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
  // registrations created by RegisterCompilationKernels() and
  // RegisterDeviceKernels().
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      kernel_registrars_ GUARDED_BY(mutex_);
};
218 
// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
// REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.

// __COUNTER__ gives each expansion a unique registrar variable name within a
// translation unit; the helper macro below pastes it into the identifier.
#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
228 
// Fluent builder used by REGISTER_XLA_OP to construct an
// XlaOpRegistry::OpRegistration, e.g.
//   Name("Add").TypeConstraint("T", kFloatTypes)
// Each setter returns *this so calls can be chained; Build() finalizes the
// registration.
class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(StringPiece name);

  // Specifies a whitelist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(StringPiece devices);
  XlaOpRegistrationBuilder& Device(gtl::ArraySlice<StringPiece> devices);

  // Specifies a type constraint for a type variable attribute. Each constraint
  // specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
                                           gtl::ArraySlice<DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allow DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Mark 'input_name' as an argument whose value must be known at compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);

  // Finishes the chain: attaches `factory` (which builds the symbolic-execution
  // OpKernel) and returns the completed registration.
  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  // Private: use Name() to start a registration chain.
  XlaOpRegistrationBuilder(StringPiece name);

  // The registration being accumulated by the chained setters.
  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};
264 
// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
// REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
// __VA_ARGS__ forwards the supported types (and optional op filter) to the
// XlaBackendRegistrar constructor; __COUNTER__ uniquifies the registrar name.
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)
269 
270 // Implementation details.
271 
// Static-initialization helper: constructing an XlaOpRegistrar adds
// `registration` (built by XlaOpRegistrationBuilder::Build) to the registry.
// Takes ownership of the registration. Used via REGISTER_XLA_OP.
class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};
276 
// Extra level of indirection so that __COUNTER__ is expanded to its numeric
// value before being pasted into the registrar identifier.
#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

// Defines a file-static XlaOpRegistrar whose construction at static-init time
// registers OP under the registration described by the BUILDER expression
// (e.g. Name("Add").TypeConstraint(...)). All names are fully qualified with
// ::tensorflow:: so the macro also works when invoked outside namespace
// tensorflow (previously XlaOpRegistrationBuilder was left unqualified, which
// only resolved from inside the namespace).
#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build(                   \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));
285 
// Static-initialization helper: constructing an XlaBackendRegistrar registers
// an XLA backend named `name` supporting `types`, with an optional per-backend
// op filter (see XlaOpRegistry::RegisterBackend). Used via
// REGISTER_XLA_BACKEND.
class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(StringPiece name, gtl::ArraySlice<DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};
291 
// Extra level of indirection so that __COUNTER__ is expanded to its numeric
// value before being pasted into the registrar identifier.
#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

// Defines a file-static XlaBackendRegistrar whose construction at static-init
// time registers the backend NAME with the forwarded arguments.
#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);
298 
299 }  // namespace tensorflow
300 
301 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
302