1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
17 
18 #include <functional>
19 #include <memory>
20 
21 #include "tensorflow/compiler/jit/flags.h"
22 #include "tensorflow/compiler/jit/xla_cluster_util.h"
23 #include "tensorflow/compiler/tf2xla/type_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_context.h"
25 #include "tensorflow/compiler/xla/client/client_library.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/local_device.h"
28 #include "tensorflow/core/framework/device_base.h"
29 #include "tensorflow/core/framework/kernel_def.pb.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op_def_util.h"
32 #include "tensorflow/core/platform/mem.h"
33 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
34 
35 namespace tensorflow {
36 
// Device-type names used by the XLA bridge.
//
// The *_XLA_JIT names are the compilation devices that
// XlaOpRegistry::GetCompilationDevice() lazily associates with the regular CPU
// and GPU devices.
const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT";
const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT";
// NOTE(review): not referenced elsewhere in this file; presumably the names of
// the standalone XLA_CPU / XLA_GPU devices -- confirm against the header.
const char* const DEVICE_XLA_CPU = "XLA_CPU";
const char* const DEVICE_XLA_GPU = "XLA_GPU";
41 
LaunchOpHasKernelForDevice(const DeviceType & device_type)42 static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
43   const OpDef* op_def;
44   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def));
45   NodeDef node_def;
46   node_def.set_name("_XlaLaunch-op");
47   node_def.set_op("XlaLaunch");
48   string kernel_class_name;
49   TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr,
50                                    &kernel_class_name));
51   VLOG(1) << "LaunchOpHasKernelForDevice"
52           << " kernel_class_name: " << kernel_class_name;
53   return Status::OK();
54 }
55 
56 XlaOpRegistry::XlaOpRegistry() = default;
57 XlaOpRegistry::~XlaOpRegistry() = default;
58 
59 // TODO(b/64575122) consider adding more sophisticated definitions of
60 // compatibility if needed by future use cases.
IsCompatible(const OpRegistration & x,const OpRegistration & y)61 /* static */ bool XlaOpRegistry::IsCompatible(const OpRegistration& x,
62                                               const OpRegistration& y) {
63   if (x.name != y.name) return true;
64   // The registrations refer to the same Op: ensures they are compatible and
65   // are restricted to different device whitelists.
66   if (x.compilation_only != y.compilation_only) {
67     LOG(WARNING) << "Registrations of " << x.name
68                  << " have incompatible compilation_only settings.";
69     return false;
70   }
71   if (x.allow_resource_types != y.allow_resource_types) {
72     LOG(WARNING) << "Registrations of " << x.name
73                  << " have incompatible allow_resource_types settings.";
74     return false;
75   }
76   if (x.allow_variant_types != y.allow_variant_types) {
77     LOG(WARNING) << "Registrations of " << x.name
78                  << " have incompatible allow_variant_types settings.";
79     return false;
80   }
81   if (!x.has_device_whitelist && !y.has_device_whitelist) {
82     LOG(WARNING) << "Duplicate registrations of " << x.name
83                  << "with no device whitelists.";
84     return false;
85   }
86   if (x.has_device_whitelist && y.has_device_whitelist) {
87     for (const auto& device : x.device_whitelist) {
88       if (y.device_whitelist.count(device) != 0) {
89         LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
90                      << device;
91         return false;
92       }
93     }
94   }
95   if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
96     LOG(WARNING) << "Registrations of " << x.name
97                  << " have incompatible compile time constant inputs.";
98     return false;
99   }
100   if (x.is_metadata_op != y.is_metadata_op) {
101     LOG(WARNING) << "Registrations of " << x.name
102                  << " have incompatible values for is_metadata_op.";
103     return false;
104   }
105   return true;
106 }
107 
RegisterCompilationDevice(const string & device_name,const DeviceRegistration & registration)108 /* static */ void XlaOpRegistry::RegisterCompilationDevice(
109     const string& device_name, const DeviceRegistration& registration) {
110   XlaOpRegistry& registry = Instance();
111   mutex_lock lock(registry.mutex_);
112   auto result =
113       registry.compilation_devices_.emplace(device_name, registration);
114   CHECK(result.second || result.first->second.compilation_device_name ==
115                              registration.compilation_device_name);
116 }
117 
RegisterBackend(const string & compilation_device_name,absl::Span<const DataType> supported_types,BackendOpFilter op_filter)118 /* static */ void XlaOpRegistry::RegisterBackend(
119     const string& compilation_device_name,
120     absl::Span<const DataType> supported_types, BackendOpFilter op_filter) {
121   XlaOpRegistry& registry = Instance();
122   mutex_lock lock(registry.mutex_);
123   auto result = registry.backends_.emplace(compilation_device_name, Backend());
124   CHECK(result.second) << "Duplicate XLA backend registration "
125                        << compilation_device_name;
126   result.first->second.supported_types.insert(supported_types.begin(),
127                                               supported_types.end());
128   result.first->second.op_filter = op_filter;
129 }
130 
GetCompilationDevice(const string & device_name,const DeviceRegistration ** registration)131 /* static */ bool XlaOpRegistry::GetCompilationDevice(
132     const string& device_name, const DeviceRegistration** registration) {
133   XlaOpRegistry& registry = Instance();
134 
135   // Lazily register the CPU and GPU JIT devices the first time
136   // GetCompilationDevice is called.
137   static void* registration_init = [&registry]() {
138     MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
139     bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
140     VLOG(2) << "tf_xla_cpu_global_jit = " << cpu_global_jit;
141 
142     mutex_lock lock(registry.mutex_);
143     if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
144       DeviceRegistration& registration =
145           registry.compilation_devices_[DEVICE_CPU];
146       registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
147       registration.autoclustering_policy =
148           cpu_global_jit
149               ? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
150               : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
151       registration.compile_all_resource_ops = false;
152     }
153     if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
154       DeviceRegistration& registration =
155           registry.compilation_devices_[DEVICE_GPU];
156       registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
157       registration.autoclustering_policy =
158           XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
159       registration.compile_all_resource_ops = false;
160     }
161     return nullptr;
162   }();
163   (void)registration_init;
164 
165   mutex_lock lock(registry.mutex_);
166   auto it = registry.compilation_devices_.find(device_name);
167   if (it == registry.compilation_devices_.end()) return false;
168   *registration = &it->second;
169   return true;
170 }
171 
RegisterCompilationKernels()172 void XlaOpRegistry::RegisterCompilationKernels() {
173   XlaOpRegistry& registry = Instance();
174   mutex_lock lock(registry.mutex_);
175 
176   if (registry.jit_kernels_registered_) return;
177   registry.jit_kernels_registered_ = true;
178 
179   OpRegistryInterface* op_registry = OpRegistry::Global();
180   // Order of op registration:
181   // The goal is to allow the co-existence of backend-specific kernels and
182   // generic kernels. To achieve this, we enforce the following order of
183   // registrations for one op:
184   // 1. Process op registration with device whitelists:
185   //      this pass registers backend-specific kernels for this op.
186   // 2. Process op registration without device whitelists:
187   //      this pass registers the kernels for all the other supported backends.
188   for (auto& ops : registry.ops_) {
189     const string& op_name = ops.first;
190     std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
191     // Partition the op registration so that the ones with device whitelists
192     // precede the one without device whitelist.
193     std::partition(op_registrations.begin(), op_registrations.end(),
194                    [](const std::unique_ptr<OpRegistration>& op_reg) {
195                      return op_reg->has_device_whitelist;
196                    });
197 
198     // Collect a set of backend registered by ops with device whitelists.
199     // The op registration without whitelists will register a generic kernel
200     // for all other backends not in this set.
201     std::unordered_set<string> whitelisted_backend;
202     for (auto& op_registration : op_registrations) {
203       if (op_registration->has_device_whitelist) {
204         whitelisted_backend.insert(op_registration->device_whitelist.begin(),
205                                    op_registration->device_whitelist.end());
206       }
207     }
208 
209     for (auto& op_registration : op_registrations) {
210       const OpDef* op_def;
211       Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
212       if (!lookup_status.ok()) {
213         LOG(ERROR) << lookup_status.error_message();
214         XLA_LOG_LINES(
215             ERROR,
216             "Ops registered: \n" +
217                 dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
218       }
219       TF_CHECK_OK(lookup_status);
220 
221       std::unordered_set<string> type_attrs;
222       for (const OpDef::AttrDef& attr_def : op_def->attr()) {
223         if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
224           type_attrs.insert(attr_def.name());
225         }
226       }
227 
228       // Checks there are no type constraints referring to unknown attributes.
229       for (const auto& constraint : op_registration->type_constraints) {
230         if (type_attrs.find(constraint.first) == type_attrs.end()) {
231           LOG(FATAL) << "Unknown type attribute " << constraint.first
232                      << " in XLA op registration for " << op_name;
233         }
234       }
235 
236       for (auto& backend : registry.backends_) {
237         // If the operator has a device whitelist, only register on whitelisted
238         // devices.
239         if (op_registration->has_device_whitelist &&
240             op_registration->device_whitelist.find(backend.first) ==
241                 op_registration->device_whitelist.end()) {
242           continue;
243         }
244 
245         // If the operator does NOT has a device whitelist, skip all devices
246         // that has already been registered.
247         if (!op_registration->has_device_whitelist &&
248             whitelisted_backend.find(backend.first) !=
249                 whitelisted_backend.end()) {
250           continue;
251         }
252 
253         std::unique_ptr<KernelDef> kdef(new KernelDef);
254         kdef->set_op(op_registration->name);
255         kdef->set_device_type(backend.first);
256 
257         // Constrain each type attribute to the intersection of:
258         // a) the types supported by the backend, and
259         // b) the types allowed by the OpDef, and
260         // c) the type constraints.
261         bool unsatisfiable_type_constraint = false;
262         for (const string& type_attr : type_attrs) {
263           KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
264           attr_constraint->set_name(type_attr);
265           auto* allowed_values =
266               attr_constraint->mutable_allowed_values()->mutable_list();
267 
268           const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
269           const auto* op_def_allowed_types =
270               op_def_attr.has_allowed_values()
271                   ? &op_def_attr.allowed_values().list().type()
272                   : nullptr;
273           auto constraint_it =
274               op_registration->type_constraints.find(type_attr);
275           const std::set<DataType>* type_constraints =
276               constraint_it != op_registration->type_constraints.end()
277                   ? &constraint_it->second
278                   : nullptr;
279           for (DataType dtype : backend.second.supported_types) {
280             // Filter out types that aren't allowed by the OpDef.
281             if (op_def_allowed_types != nullptr &&
282                 std::find(op_def_allowed_types->begin(),
283                           op_def_allowed_types->end(),
284                           dtype) == op_def_allowed_types->end()) {
285               continue;
286             }
287             // Filter out types based on the type constraints.
288             if (type_constraints != nullptr &&
289                 type_constraints->find(dtype) == type_constraints->end()) {
290               continue;
291             }
292             // Passed all the filters, this type is allowed.
293             allowed_values->add_type(dtype);
294           }
295           if (op_registration->allow_resource_types) {
296             allowed_values->add_type(DT_RESOURCE);
297           }
298           if (op_registration->allow_variant_types) {
299             allowed_values->add_type(DT_VARIANT);
300           }
301           // Don't build KernelDefs that have unsatisfiable type constraints.
302           if (allowed_values->type().empty()) {
303             unsatisfiable_type_constraint = true;
304             break;
305           }
306         }
307         if (unsatisfiable_type_constraint) continue;
308 
309         if (backend.second.op_filter != nullptr &&
310             !backend.second.op_filter(kdef.get())) {
311           continue;
312         }
313         VLOG(2) << "XLA op registration: device: " << backend.first
314                 << " op: " << op_name;
315         registry.kernel_registrars_.emplace_back(
316             new kernel_factory::OpKernelRegistrar(
317                 new KernelDef(*kdef), "XlaJitOp", op_registration->factory));
318         backend.second.kernel_defs.push_back(std::move(kdef));
319       }
320     }
321   }
322 }
323 
DeviceKernels(const string & compilation_device_name,bool include_compilation_only_kernels)324 std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
325     const string& compilation_device_name,
326     bool include_compilation_only_kernels) {
327   // Ensure compilation kernels registered.
328   RegisterCompilationKernels();
329   std::vector<const KernelDef*> kernels;
330   XlaOpRegistry& registry = Instance();
331   mutex_lock lock(registry.mutex_);
332   auto it = registry.backends_.find(compilation_device_name);
333   CHECK(it != registry.backends_.end())
334       << "Unknown backend " << compilation_device_name;
335   for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
336     auto op_iter = registry.ops_.find(k->op());
337     CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty());
338     // The test in IsCompatible ensures that if there are multiple matching
339     // registrations for this op name, they all have the same value of
340     // compilation_only, so only the first match needs to be tested.
341     if (include_compilation_only_kernels ||
342         !op_iter->second.front()->compilation_only) {
343       kernels.push_back(k.get());
344     }
345   }
346   return kernels;
347 }
348 
GetAllRegisteredOps()349 /*static*/ std::vector<string> XlaOpRegistry::GetAllRegisteredOps() {
350   std::vector<string> ops;
351   XlaOpRegistry& registry = Instance();
352   mutex_lock lock(registry.mutex_);
353   for (const auto& pair : registry.ops_) {
354     ops.push_back(pair.first);
355   }
356   std::sort(ops.begin(), ops.end());
357   return ops;
358 }
359 
CompileTimeConstantInputs(const NodeDef & node_def,const OpKernel * op_kernel,const OpDef * op_def,std::vector<int> * result)360 /* static */ Status XlaOpRegistry::CompileTimeConstantInputs(
361     const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def,
362     std::vector<int>* result) {
363   result->clear();
364 
365   DCHECK(op_def != nullptr || op_kernel != nullptr);
366 
367   std::unordered_set<string> compile_time_constant_inputs_from_attr;
368   std::vector<string> compile_time_constant_inputs_vect_from_attr;
369 
370   const std::unordered_set<string>* compile_time_constant_inputs;
371 
372   if (GetNodeAttr(node_def, kXlaCompileTimeConstantInputsAttr,
373                   &compile_time_constant_inputs_vect_from_attr)
374           .ok()) {
375     absl::c_copy(compile_time_constant_inputs_vect_from_attr,
376                  std::inserter(compile_time_constant_inputs_from_attr,
377                                compile_time_constant_inputs_from_attr.end()));
378     compile_time_constant_inputs = &compile_time_constant_inputs_from_attr;
379   } else {
380     const string& op = node_def.op();
381 
382     XlaOpRegistry& registry = Instance();
383     mutex_lock lock(registry.mutex_);
384     auto it = registry.ops_.find(op);
385     if (it == registry.ops_.end() || it->second.empty()) {
386       return Status::OK();
387     } else {
388       // The test in IsCompatible ensures that if there are multiple matching
389       // registrations for this op name, they all have the same value of
390       // compile_time_constant_inputs, so only the first match is returned.
391       //
392       // TODO(sanjoy): This can probably be a std::vector<string>.
393       compile_time_constant_inputs =
394           &it->second.front()->compile_time_constant_inputs;
395     }
396   }
397 
398   for (const string& input : *compile_time_constant_inputs) {
399     if (op_def) {
400       NameRangeMap input_name_ranges;
401       TF_RETURN_IF_ERROR(
402           NameRangesForNode(node_def, *op_def, &input_name_ranges, nullptr));
403       auto name_range = input_name_ranges.find(input);
404       if (name_range == input_name_ranges.end()) {
405         continue;
406       }
407 
408       for (int i = name_range->second.first; i < name_range->second.second;
409            i++) {
410         result->push_back(i);
411       }
412     } else {
413       int start, stop;
414       TF_CHECK_OK(op_kernel->InputRange(input, &start, &stop));
415       for (int i = start; i < stop; ++i) {
416         result->push_back(i);
417       }
418     }
419   }
420 
421   absl::c_sort(*result);
422   return Status::OK();
423 }
424 
IsMetadataOp(const string & op)425 /*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) {
426   XlaOpRegistry& registry = Instance();
427   mutex_lock lock(registry.mutex_);
428   auto it = registry.ops_.find(op);
429   if (it == registry.ops_.end() || it->second.empty()) {
430     return false;
431   }
432 
433   // The test in IsCompatible ensures that if there are multiple matching
434   // registrations for this op name, they all have the same value of
435   // is_metadata_op, so only the first match is returned.
436   return it->second.front()->is_metadata_op;
437 }
438 
BackendNames()439 std::vector<string> XlaOpRegistry::BackendNames() {
440   std::vector<string> names;
441   XlaOpRegistry& registry = Instance();
442   mutex_lock lock(registry.mutex_);
443   for (const auto& backend_pair : registry.backends_) {
444     names.push_back(backend_pair.first);
445   }
446   return names;
447 }
448 
IsBackendRegistered(const string & name)449 bool XlaOpRegistry::IsBackendRegistered(const string& name) {
450   XlaOpRegistry& registry = Instance();
451   mutex_lock lock(registry.mutex_);
452   return registry.backends_.find(name) != registry.backends_.end();
453 }
454 
Instance()455 XlaOpRegistry& XlaOpRegistry::Instance() {
456   static XlaOpRegistry* r = new XlaOpRegistry;
457   return *r;
458 }
459 
XlaOpRegistrationBuilder(absl::string_view name)460 XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) {
461   registration_.reset(new XlaOpRegistry::OpRegistration);
462   registration_->name = string(name);
463 }
464 
Name(absl::string_view name)465 XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
466     absl::string_view name) {
467   XlaOpRegistrationBuilder registration(name);
468   return registration;
469 }
470 
Device(absl::Span<const absl::string_view> devices)471 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
472     absl::Span<const absl::string_view> devices) {
473   registration_->has_device_whitelist = true;
474   for (absl::string_view device : devices) {
475     registration_->device_whitelist.emplace(device);
476   }
477   return *this;
478 }
479 
Device(absl::string_view device)480 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
481     absl::string_view device) {
482   registration_->has_device_whitelist = true;
483   registration_->device_whitelist.emplace(device);
484   return *this;
485 }
486 
CompilationOnly()487 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompilationOnly() {
488   registration_->compilation_only = true;
489   return *this;
490 }
491 
AllowResourceTypes()492 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() {
493   registration_->allow_resource_types = true;
494   return *this;
495 }
496 
AllowVariantTypes()497 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowVariantTypes() {
498   registration_->allow_variant_types = true;
499   return *this;
500 }
501 
TypeConstraint(absl::string_view attr_name,DataType allowed)502 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
503     absl::string_view attr_name, DataType allowed) {
504   std::set<DataType>& types =
505       registration_->type_constraints[string(attr_name)];
506   types.insert(allowed);
507   return *this;
508 }
509 
TypeConstraint(absl::string_view attr_name,absl::Span<const DataType> allowed)510 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
511     absl::string_view attr_name, absl::Span<const DataType> allowed) {
512   std::set<DataType>& types =
513       registration_->type_constraints[string(attr_name)];
514   for (DataType t : allowed) {
515     types.insert(t);
516   }
517   return *this;
518 }
519 
CompileTimeConstantInput(absl::string_view input_name)520 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstantInput(
521     absl::string_view input_name) {
522   registration_->compile_time_constant_inputs.emplace(input_name);
523   return *this;
524 }
525 
IsMetadataOp()526 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
527   registration_->is_metadata_op = true;
528   return *this;
529 }
530 
Build(XlaOpRegistry::Factory factory)531 std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
532     XlaOpRegistry::Factory factory) {
533   registration_->factory = factory;
534   return std::move(registration_);
535 }
536 
XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration)537 XlaOpRegistrar::XlaOpRegistrar(
538     std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
539   XlaOpRegistry& registry = XlaOpRegistry::Instance();
540   mutex_lock lock(registry.mutex_);
541   auto& existing_ops = registry.ops_[registration->name];
542   for (auto& existing : existing_ops) {
543     if (!XlaOpRegistry::IsCompatible(*existing, *registration)) {
544       LOG(FATAL)
545           << "XLA op registration " << registration->name
546           << " is incompatible with existing registration of the same name.";
547     }
548   }
549   existing_ops.emplace_back(std::move(registration));
550 }
551 
XlaBackendRegistrar(absl::string_view name,absl::Span<const DataType> types,XlaOpRegistry::BackendOpFilter op_filter)552 XlaBackendRegistrar::XlaBackendRegistrar(
553     absl::string_view name, absl::Span<const DataType> types,
554     XlaOpRegistry::BackendOpFilter op_filter) {
555   XlaOpRegistry& registry = XlaOpRegistry::Instance();
556   registry.RegisterBackend(string(name), types, op_filter);
557 }
558 
559 }  // namespace tensorflow
560