1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/xla_device.h"
17 
#include <stdlib.h>
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/jit/defs.h"
23 #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
24 #include "tensorflow/compiler/jit/xla_device_context.h"
25 #include "tensorflow/compiler/jit/xla_device_ops.h"
26 #include "tensorflow/compiler/tf2xla/shape_util.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/compiler/xla/client/client_library.h"
29 #include "tensorflow/compiler/xla/service/stream_pool.h"
30 #include "tensorflow/core/common_runtime/device.h"
31 #include "tensorflow/core/common_runtime/device_factory.h"
32 #include "tensorflow/core/common_runtime/dma_helper.h"
33 #include "tensorflow/core/common_runtime/function.h"
34 #include "tensorflow/core/common_runtime/renamed_device.h"
35 #include "tensorflow/core/framework/allocator.h"
36 #include "tensorflow/core/framework/device_base.h"
37 #include "tensorflow/core/framework/function.h"
38 #include "tensorflow/core/framework/kernel_def.pb.h"
39 #include "tensorflow/core/framework/node_def_builder.h"
40 #include "tensorflow/core/framework/op_kernel.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor.pb.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/graph/graph_constructor.h"
45 #include "tensorflow/core/lib/core/notification.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
49 #include "tensorflow/core/platform/tracing.h"
50 #include "tensorflow/core/public/session_options.h"
51 #include "tensorflow/core/public/version.h"
52 #include "tensorflow/core/util/device_name_utils.h"
53 #include "tensorflow/core/util/dump_graph.h"
54 #include "tensorflow/core/util/ptr_util.h"
55 #include "tensorflow/core/util/stream_executor_util.h"
56 
57 namespace tensorflow {
58 
59 // Caches a XlaDeviceAllocator per <backend, device ordinal> pair. A
60 // XlaDeviceAllocator is created on demand and is associated with a
61 // XlaDevice. It outlives the device itself (for instance, the buffer
62 // backing a tensor holds a pointer to the allocator for book-keeping,
63 // and this buffer can outlast the device).
64 class XlaDeviceAllocatorState {
65  public:
66   // Creates or returns a cached XlaDeviceAllocator for a given
67   // backend and device_ordinal.
68   static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
69       const xla::Backend* backend, int device_ordinal);
70 
71  private:
72   // Returns the singleton instance of XlaDeviceAllocatorState.
73   static XlaDeviceAllocatorState& Singleton();
74   XlaDeviceAllocatorState();
75   ~XlaDeviceAllocatorState();
76 
77   mutex allocator_mutex_;  // Guards the singleton allocator state.
78   std::unordered_map<std::pair<const xla::Backend*, int>,
79                      std::unique_ptr<XlaDeviceAllocator>,
80                      hash<std::pair<const xla::Backend*, int>>>
81       allocators_ GUARDED_BY(allocator_mutex_);
82 
83   TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
84 };
85 
Singleton()86 /* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
87   static auto a = new XlaDeviceAllocatorState;
88   return *a;
89 }
90 
// The constructor and destructor are declared private in the class and are
// defined here, out of line, with the compiler-generated behavior.
XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;
93 
GetOrCreateXlaDeviceAllocator(const xla::Backend * backend,int device_ordinal)94 XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
95     const xla::Backend* backend, int device_ordinal) {
96   XlaDeviceAllocatorState& state = Singleton();
97   mutex_lock lock(state.allocator_mutex_);
98 
99   auto it = state.allocators_.find({backend, device_ordinal});
100   if (it != state.allocators_.end()) {
101     return it->second.get();
102   }
103 
104   std::unique_ptr<XlaDeviceAllocator> alloc =
105       absl::make_unique<XlaDeviceAllocator>(
106           backend->stream_executors()[device_ordinal]);
107   XlaDeviceAllocator* alloc_ptr = alloc.get();
108   state.allocators_[{backend, device_ordinal}] = std::move(alloc);
109   return alloc_ptr;
110 }
111 
112 namespace {
113 
114 // Default PaddedShapeFn implementation that simply returns the unpadded
115 // on-device shape. This is accurate for CPU and GPU devices that neither
116 // transpose nor pad tensors.
DefaultPaddedShapeFn(const Tensor & tensor,xla::Shape * shape)117 Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
118   const tensorflow::XlaTensor* xla_tensor =
119       tensorflow::XlaTensor::FromTensor(&tensor);
120   if (xla_tensor == nullptr) {
121     return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
122   }
123 
124   const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
125   *shape = shaped_buffer.on_device_shape();
126   return Status::OK();
127 }
128 
BuildXlaDeviceAttributes(const string & name_prefix,const string & device_name,int device_ordinal)129 static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
130                                                  const string& device_name,
131                                                  int device_ordinal) {
132   return Device::BuildDeviceAttributes(
133       absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
134       DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
135       absl::StrCat("device: ", device_name, " device"));
136 }
137 
138 }  // namespace
139 
// Stores the per-device metadata. The two callables are taken by value and
// moved into their members; every other argument is copied as is.
XlaDevice::Metadata::Metadata(
    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform),
      shape_representation_fn_(std::move(shape_representation_fn)),
      padded_shape_fn_(std::move(padded_shape_fn)),
      use_multiple_streams_(use_multiple_streams) {}
150 
// Returns the device ordinal supplied at construction.
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

// Returns the platform pointer supplied at construction (non-owning here).
se::Platform* XlaDevice::Metadata::platform() const { return platform_; }
154 
client() const155 xla::LocalClient* XlaDevice::Metadata::client() const {
156   auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
157   return client.ValueOrDie();
158 }
159 
// Returns the DeviceType stored at construction. XlaDevice's constructor
// builds it from Options::compilation_device_name.
const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}
163 
GetMetadataFromDevice(DeviceBase * device,const XlaDevice::Metadata ** metadata)164 /*static*/ Status XlaDevice::GetMetadataFromDevice(
165     DeviceBase* device, const XlaDevice::Metadata** metadata) {
166   *metadata = nullptr;
167   XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
168   if (xla_device == nullptr) {
169     return errors::Internal(
170         "Cannot get XLA metadata from non-XLA device \"", device->name(),
171         "\". GetMetadata must only be called on an XLA device. Either an "
172         "internal bug has been triggered, or an XLA-specific op has been "
173         "placed on the wrong device.");
174   }
175   *metadata = &(xla_device->xla_metadata_);
176   return Status::OK();
177 }
178 
GetMetadata(OpKernelContext * ctx,const Metadata ** metadata)179 /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
180                                            const Metadata** metadata) {
181   return GetMetadataFromDevice(ctx->device(), metadata);
182 }
183 
GetMetadata(OpKernelConstruction * ctx,const Metadata ** metadata)184 /* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
185                                            const Metadata** metadata) {
186   return GetMetadataFromDevice(ctx->device(), metadata);
187 }
188 
// Constructs an XLA device from `options`.
//
// The LocalDevice base is initialized with attributes derived from the device
// name prefix, device name and ordinal. The metadata object uses
// options.padded_shape_fn when it is set and DefaultPaddedShapeFn otherwise.
// No streams, allocator or device context are created here; the body only
// creates the thread pool and sizes the device-to-device stream vector.
XlaDevice::XlaDevice(const SessionOptions& session_options,
                     const Options& options)
    : LocalDevice(session_options,
                  BuildXlaDeviceAttributes(options.device_name_prefix,
                                           options.device_name,
                                           options.device_ordinal)),
      xla_metadata_(options.device_ordinal, options.platform,
                    DeviceType(options.compilation_device_name),
                    options.shape_representation_fn,
                    options.padded_shape_fn ? options.padded_shape_fn
                                            : DefaultPaddedShapeFn,
                    options.use_multiple_streams),
      device_ordinal_(options.device_ordinal),
      jit_device_name_(options.compilation_device_name),
      platform_(options.platform),
      use_multiple_streams_(options.use_multiple_streams),
      shape_representation_fn_(options.shape_representation_fn),
      allowed_devices_(options.allowed_devices) {
  VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
          << this;
  // Single-threaded pool; it is later handed to the XlaDeviceContext created
  // in GetDeviceContextLocked().
  thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
                                            /*num_threads=*/1));

  // We have multiple device to device streams to allow for some concurrency
  // between transfers. The particular value of '4' is chosen fairly
  // arbitrarily. It may be necessary to make this tunable via
  // XlaDevice::Options.
  static constexpr int kNumDeviceToDeviceStreams = 4;
  // The entries start out as null shared_ptrs; GetDeviceContextLocked() fills
  // them in when multiple streams are in use.
  device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}
219 
~XlaDevice()220 XlaDevice::~XlaDevice() {
221   VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
222   mutex_lock lock(mu_);
223   if (device_context_) {
224     device_context_->Unref();
225   }
226 }
227 
client() const228 xla::LocalClient* XlaDevice::client() const {
229   // We lazily create the client because the platform commits to the
230   // details of the host hardware when the client is created, so we
231   // don't want to do it until we get a chance to hook the platform up
232   // to a simulator.
233 
234   // TODO(b/78468222): This can fail, at least when the backend is GPU and
235   // there is no GPU on the host.
236   return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_)
237       .ValueOrDie();
238 }
239 
GetAllocator(AllocatorAttributes attr)240 Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
241   mutex_lock lock(mu_);
242   return GetAllocatorLocked(attr);
243 }
244 
GetAllocatorLocked(AllocatorAttributes attr)245 Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
246   if (attr.on_host()) {
247     return cpu_allocator();
248   }
249 
250   if (xla_allocator_ == nullptr) {
251     xla::Backend* backend = client()->mutable_backend();
252     xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
253         backend, device_ordinal_);
254   }
255   return xla_allocator_;
256 }
257 
EnsureDeviceContextOk()258 Status XlaDevice::EnsureDeviceContextOk() {
259   mutex_lock lock(mu_);
260   return GetDeviceContextLocked().status();
261 }
262 
EnsureStreamOkLocked(xla::Backend * backend,const string & name,std::shared_ptr<se::Stream> * stream,bool * stream_was_changed)263 Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
264                                        const string& name,
265                                        std::shared_ptr<se::Stream>* stream,
266                                        bool* stream_was_changed) {
267   if (!(*stream) || !(*stream)->ok()) {
268     xla::StreamPool::Ptr ptr;
269     TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
270     *stream = std::shared_ptr<se::Stream>(std::move(ptr));
271     VLOG(1) << "XlaDevice " << this << " new " << name << " "
272             << (*stream)->DebugStreamPointers();
273     *stream_was_changed = true;
274   }
275   return Status::OK();
276 }
277 
// Returns the current device context, creating a new one when none exists yet
// or when any of the streams it depends on had to be replaced. Every caller in
// this file holds mu_ while calling this function.
//
// When a new context is created, the reference this device held on the old
// one (if any) is dropped, and, if use_gpu_device_info_ is set, a fresh
// GpuDeviceInfo pointing at the new stream and context is installed.
//
// Returns an error if a replacement stream could not be borrowed.
xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
  xla::Backend* backend = client()->mutable_backend();

  // Ensure all our streams are valid, borrowing new streams if necessary.
  // need_new_device_context starts out true only when there is no context at
  // all; each EnsureStreamOkLocked call below may additionally set it to true.
  bool need_new_device_context = !device_context_;
  TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                          &need_new_device_context));

  // Local copies of the streams that would be handed to a new context.
  std::shared_ptr<se::Stream> host_to_device_stream;
  std::shared_ptr<se::Stream> device_to_host_stream;
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
  if (use_multiple_streams_) {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                            &host_to_device_stream_,
                                            &need_new_device_context));
    for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
      TF_RETURN_IF_ERROR(
          EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
                               &need_new_device_context));
    }
    host_to_device_stream = host_to_device_stream_;
    device_to_device_streams = device_to_device_streams_;
    // The data transfer requests from device to host could arrive out of order,
    // so a single stream would cause deadlock. For this case,
    // xla_device_context would borrow a stream for each transfer request.
    device_to_host_stream = nullptr;
  } else {
    // Single-stream mode: every transfer direction shares the compute stream.
    host_to_device_stream = stream_;
    device_to_host_stream = stream_;
    device_to_device_streams = {stream_};
  }

  // Nothing changed and a context already exists: reuse it.
  if (!need_new_device_context) {
    return device_context_;
  }

  // At this point we know we need a new device context.
  // Call GetAllocator for the side-effect of ensuring the allocator is created.
  GetAllocatorLocked({});
  if (device_context_) {
    device_context_->Unref();
  }
  // The XlaDeviceContext keeps a reference count to the streams, and the
  // XlaDeviceContext remains live for the duration of a Executor run. This
  // ensures that the streams remain live for the duration of a run, even if
  // an error is encountered and the streams are replaced with new ones.
  device_context_ = new XlaDeviceContext(
      stream_, std::move(host_to_device_stream),
      std::move(device_to_host_stream), std::move(device_to_device_streams),
      client(), shape_representation_fn_, thread_pool_.get());
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
          << device_context_;

  // Create and set a new GpuDeviceInfo, if necessary.
  //
  // TODO(b/78232898): This isn't thread-safe; there is a race between the call
  // to set_tensorflow_gpu_device_info() with ops that call the getter
  // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
  // to those methods; see the bug for details. Our only saving grace at the
  // moment is that this race doesn't seem to occur in practice.
  if (use_gpu_device_info_) {
    auto gpu_device_info = absl::make_unique<GpuDeviceInfo>();
    gpu_device_info->stream = stream_.get();
    gpu_device_info->default_context = device_context_;
    // Publish the raw pointer first, then move ownership into the member,
    // which replaces whatever GpuDeviceInfo was held before.
    set_tensorflow_gpu_device_info(gpu_device_info.get());
    gpu_device_info_ = std::move(gpu_device_info);
    VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
            << gpu_device_info_.get();
  }

  return device_context_;
}
350 
UseGpuDeviceInfo()351 Status XlaDevice::UseGpuDeviceInfo() {
352   mutex_lock lock(mu_);
353   use_gpu_device_info_ = true;
354   return GetDeviceContextLocked().status();
355 }
356 
FillContextMap(const Graph * graph,DeviceContextMap * device_context_map)357 Status XlaDevice::FillContextMap(const Graph* graph,
358                                  DeviceContextMap* device_context_map) {
359   VLOG(1) << "XlaDevice::FillContextMap";
360   mutex_lock lock(mu_);
361   TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
362                       GetDeviceContextLocked());
363 
364   device_context_map->resize(graph->num_node_ids());
365   for (Node* n : graph->nodes()) {
366     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
367     device_context->Ref();
368     (*device_context_map)[n->id()] = device_context;
369   }
370   return Status::OK();
371 }
372 
// Runs a synchronous kernel by calling op_kernel->Compute(context) directly,
// after logging the kernel's name and type at verbosity level 2.
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->Compute(context);
}
378 
ComputeAsync(AsyncOpKernel * op_kernel,OpKernelContext * context,AsyncOpKernel::DoneCallback done)379 void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
380                              AsyncOpKernel::DoneCallback done) {
381   VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
382           << op_kernel->type_string();
383   tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
384                                    op_kernel->IsExpensive());
385   op_kernel->ComputeAsync(context, done);
386 }
387 
Sync()388 Status XlaDevice::Sync() {
389   VLOG(1) << "XlaDevice::Sync";
390   tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true);
391   std::shared_ptr<se::Stream> stream;
392   {
393     mutex_lock lock(mu_);
394     stream = stream_;
395   }
396   if (!stream) return Status::OK();
397 
398   Status status = stream->BlockHostUntilDone();
399   TF_RETURN_IF_ERROR(status);
400   if (!stream->ok()) {
401     return errors::Internal("XlaDevice::Sync() failed.");
402   }
403   VLOG(1) << "XlaDevice::Sync completed";
404   return Status::OK();
405 }
406 
407 // TODO(b/112409994): This is no longer necessary. Consolidate it with the
408 // synchronous version.
Sync(const DoneCallback & done)409 void XlaDevice::Sync(const DoneCallback& done) {
410   VLOG(1) << "XlaDevice::Sync (asynchronous)";
411   std::shared_ptr<se::Stream> stream;
412   {
413     mutex_lock lock(mu_);
414     stream = stream_;
415   }
416   if (!stream) {
417     done(Status::OK());
418     return;
419   }
420 
421   // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at
422   // the end of the stream, after everything that has already been enqueued
423   // there at this moment. When the host callback is called, everything before
424   // it must have already finished, and the host callback will then place the
425   // task below onto a background thread. (See the implementation of
426   // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done
427   // callback is finally called from that background thread, we know for sure
428   // that everything enqueued onto the stream (i.e., the device) at this very
429   // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
430   // This achieves a device-wide sync.
431   stream->ThenEnqueueOnBackgroundThread(
432       [stream, done](se::StreamExecutor*) {
433         tracing::ScopedActivity activity("XlaDevice::Sync::Callback",
434                                          /*is_expensive=*/true);
435         done(stream->ok() ? Status::OK()
436                           : errors::Internal("XlaDevice::Sync() failed."));
437       });
438 }
439 
// Materializes `tensor_proto` as a Tensor in the memory selected by
// `alloc_attrs`.
//
// The proto is always parsed into a host tensor first. For host allocations
// that tensor is returned as is. Otherwise a device tensor of the same dtype
// and shape is allocated, and the host tensor is copied into it through the
// device context; this function waits for that copy to finish.
//
// Returns InvalidArgument if the proto cannot be parsed, the device-context
// error if no context can be obtained, or the status reported by the copy.
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";

  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  // Stays OK on the host path; overwritten by the copy callback otherwise.
  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    // NOTE(review): mu_ stays held for the rest of this branch, including the
    // blocking WaitForNotification() below -- confirm that the copy's
    // completion path never needs mu_.
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
                        GetDeviceContextLocked());
    Allocator* allocator = GetAllocatorLocked(alloc_attrs);
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    Notification n;
    // The callback captures `n` and `status` by reference; both outlive it
    // because this function waits on `n` before returning.
    device_context->CopyCPUTensorToDevice(&parsed, this, &copy,
                                          [&n, &status](const Status& s) {
                                            status = s;
                                            n.Notify();
                                          });
    n.WaitForNotification();
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}
472 
SetAllowsSyncOnCompletion(bool sync_on_completion)473 void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) {
474   mutex_lock lock(mu_);
475   sync_on_completion_ = sync_on_completion;
476 }
477 
AllowsSyncOnCompletion() const478 bool XlaDevice::AllowsSyncOnCompletion() const {
479   mutex_lock lock(mu_);
480   return sync_on_completion_;
481 }
482 
SetHandleDeviceErrorCallback(std::function<Status ()> callback)483 void XlaDevice::SetHandleDeviceErrorCallback(std::function<Status()> callback) {
484   mutex_lock lock(mu_);
485   device_error_callback_ = callback;
486 }
487 
HandleDeviceError()488 Status XlaDevice::HandleDeviceError() {
489   std::function<Status()> local_device_error_callback;
490   {
491     mutex_lock lock(mu_);
492     local_device_error_callback = device_error_callback_;
493   }
494   if (local_device_error_callback != nullptr) {
495     return local_device_error_callback();
496   }
497   return Status::OK();
498 }
499 
RefreshStatus()500 Status XlaDevice::RefreshStatus() {
501   std::shared_ptr<se::Stream> stream;
502   {
503     mutex_lock lock(mu_);
504     stream = stream_;
505   }
506   if (!stream) {
507     return Status::OK();
508   }
509   Status status = stream->RefreshStatus();
510   if (!status.ok()) {
511     // Ignore errors from HandleDeviceError, since by definition the status is
512     // already non-ok, so there's nothing extra to report if HandleDeviceError
513     // itself returns an error.
514     HandleDeviceError().IgnoreError();
515   }
516   return status;
517 }
518 
RegisterXlaDeviceKernels(const char * device,const char * jit_device)519 XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
520                                                    const char* jit_device) {
521   // Any op assigned to the device that isn't rewritten by the graph rewriter
522   // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes
523   // it just-in-time.
524   OpKernel* (*factory)(OpKernelConstruction*) =
525       [](OpKernelConstruction* context) -> OpKernel* {
526     return new XlaCompileOnDemandOp(context);
527   };
528   XlaOpRegistry::RegisterCompilationKernels();
529   XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
530   for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
531            jit_device,
532            /*include_compilation_only_kernels=*/false)) {
533     KernelDef* def = new KernelDef(*jit_def);
534     def->set_device_type(device);
535     registrations->op_kernel_registrars.emplace_back(
536         new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
537                                               factory));
538   }
539   return registrations;
540 }
541 
542 }  // namespace tensorflow
543