/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_

#include <array>
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the Tensorflow/XLA bridge to perform symbolic execution of
// a Tensorflow graph.

extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

// Canonical lists of DataTypes, used when registering backends below:
// floating-point types only, all numeric types, and the full per-device
// supported sets for the CPU and GPU JIT devices.
constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
constexpr std::array<DataType, 9> kNumericTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BFLOAT16}};

constexpr std::array<DataType, 8> kCpuAllTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BOOL}};

constexpr std::array<DataType, 8> kGpuAllTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BOOL}};

// Class that manages registrations of operators and devices for the XLA JIT.
// NOTE(review): all mutable state below is annotated GUARDED_BY(mutex_), which
// suggests the singleton is safe for concurrent use — confirm against the .cc
// implementation before relying on it from multiple threads.
class XlaOpRegistry {
 public:
  // Signature of the factory function used to construct a symbolic-execution
  // OpKernel for a registered op.
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  // Describes how to compile operators assigned to a device.
  struct DeviceRegistration {
    // The name of the XLA compilation device to use to compile code.
    string compilation_device_name;

    // Do operators assigned to this device require compilation?
    bool requires_compilation;

    // If !requires_compilation, should we try to JIT operators on this device
    // when XLA JIT compilation is enabled globally via the SessionOptions?
    // (It is still possible to explicitly mark operators to JIT compile, even
    // if enable_jit_by_default is false.)
    bool enable_jit_by_default;

    // Enable compilation of operators that use DT_RESOURCE types?
    bool compile_resource_ops = false;
  };

  // Registers an XLA backend. `compilation_device_name` is the name of the
  // device used for symbolic execution during compilation. `supported_types`
  // is the list of non-resource types supported by the device. Each operator
  // will be registered for the intersection of the operator's supported types
  // and the device's supported types. `op_filter` is a function used
  // to exclude or modify operator registrations on the device; it may be
  // nullptr, in which case all ops are included.
  // `op_filter` should return true if the op should be registered on
  // the device; it may optionally modify the KernelDef.
  typedef bool (*BackendOpFilter)(KernelDef* kdef);
  static void RegisterBackend(const string& compilation_device_name,
                              gtl::ArraySlice<DataType> supported_types,
                              BackendOpFilter op_filter);

  // Returns the names of the registered backends.
  static std::vector<string> BackendNames();

  // Returns true iff a backend with the given name is registered.
  static bool IsBackendRegistered(const string& name);

  // Registers `device_name` for XLA compilation, using information from
  // `registration`.
  static void RegisterCompilationDevice(const string& device_name,
                                        const DeviceRegistration& registration);

  // Looks up the DeviceRegistration associated with 'device_name' and returns
  // it via '*registration'. Returns false and leaves '*registration' unchanged
  // if no matching compilation device is registered.
  // '(*registration)->enable_jit_by_default' is true if we should try to JIT
  // operators on this device when the JIT is enabled via the Session
  // OptimizerOptions.
  static bool GetCompilationDevice(const string& device_name,
                                   const DeviceRegistration** registration);

  // Registers all JIT kernels on JIT devices, if not already registered.
  // Does nothing otherwise.
  static void RegisterCompilationKernels();

  // Returns KernelDefs for compilation ops registered on
  // 'compilation_device_name'. Kernels registered as CompilationOnly are
  // omitted when include_compilation_only_kernels is false.
  static std::vector<const KernelDef*> DeviceKernels(
      const string& compilation_device_name,
      bool include_compilation_only_kernels);

  // Returns the set of compile-time constant inputs to 'op'. Returns nullptr
  // if the op is not registered.
  static const std::unordered_set<string>* CompileTimeConstantInputs(
      const string& op);

 private:
  friend class XlaBackendRegistrar;
  friend class XlaOpRegistrar;
  friend class XlaOpRegistrationBuilder;

  // Returns the process-wide singleton registry.
  static XlaOpRegistry& Instance();

  XlaOpRegistry();
  ~XlaOpRegistry();

  // Guards all registry state below.
  mutex mutex_;

  // Describes an XLA backend.
  struct Backend {
    // Which types are supported by this device?
    std::set<DataType> supported_types;

    // The per-backend operator filter function. See the comment on
    // RegisterBackend() for details.
    BackendOpFilter op_filter;

    // KernelDefs built by RegisterCompilationKernels() for each op supported
    // by the device.
    std::vector<std::unique_ptr<KernelDef>> kernel_defs;
  };

  // Map from compilation device names to a description of the backend.
  std::unordered_map<string, Backend> backends_ GUARDED_BY(mutex_);

  // Map from Tensorflow device names to the corresponding JIT device metadata.
  std::unordered_map<string, DeviceRegistration> compilation_devices_
      GUARDED_BY(mutex_);

  // A description of a Tensorflow operator that can be compiled to XLA.
  struct OpRegistration {
    string name;

    // Should this operator be registered only on compilation devices, without a
    // dummy kernel registered on the corresponding XLA device?
    bool compilation_only = false;

    // Should we allow resource types for type attributes? Used by _Arg to
    // allow DT_RESOURCE.
    bool allow_resource_types = false;

    // Mapping from attribute name to a list of supported types.
    std::unordered_map<string, std::set<DataType>> type_constraints;

    // An optional whitelist of devices. If there is no whitelist, all devices
    // are permitted.
    bool has_device_whitelist = false;
    std::unordered_set<string> device_whitelist;

    // Names of arguments that must be compile-time constants.
    std::unordered_set<string> compile_time_constant_inputs;

    // Factory used to build OpKernels that perform symbolic execution.
    Factory factory;
  };

  // Returns true if registrations x and y can both be added to the registry.
  // This is always the case if they refer to different ops. If they refer to
  // the same op name, they must: have the same values for compilation_only and
  // allow_resource_types; use a device_whitelist; and their
  // whitelists must not intersect.
  static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);

  // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
  // Registrations present under the same key must satisfy IsCompatible above,
  // and this is checked during registration.
  std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
      GUARDED_BY(mutex_);

  // Have we already registered the JIT kernels on the JIT devices?
  bool jit_kernels_registered_ = false;

  // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
  // registrations created by RegisterCompilationKernels() and
  // RegisterDeviceKernels().
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      kernel_registrars_ GUARDED_BY(mutex_);
};

// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
// REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.

#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)

// Fluent builder used by REGISTER_XLA_OP to describe an op registration; the
// chain starts with Name() and ends with Build(), which produces the
// OpRegistration handed to XlaOpRegistry.
class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(StringPiece name);

  // Specifies a whitelist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(StringPiece devices);
  XlaOpRegistrationBuilder& Device(gtl::ArraySlice<StringPiece> devices);

  // Specifies a type constraint for a type variable attribute. Each constraint
  // specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
                                           gtl::ArraySlice<DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allow DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Mark 'input_name' as an argument whose value must be known at compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);

  // Finalizes the chain: returns the accumulated registration with 'factory'
  // installed as the kernel factory.
  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  XlaOpRegistrationBuilder(StringPiece name);

  // Registration being accumulated by the builder chain.
  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};

// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
// REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)

// Implementation details.

// Static-initialization helper: constructing an XlaOpRegistrar (via the
// REGISTER_XLA_OP macros below) adds 'registration' to the XlaOpRegistry.
class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};

#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

// Defines a uniquely-named (via __COUNTER__) file-scope registrar whose
// factory lambda constructs the OP kernel during symbolic execution.
#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      XlaOpRegistrationBuilder::BUILDER.Build(                                 \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));

// Static-initialization helper: constructing an XlaBackendRegistrar (via the
// REGISTER_XLA_BACKEND macros) registers a backend with the XlaOpRegistry.
class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(StringPiece name, gtl::ArraySlice<DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};

#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_