/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {
namespace {

using tensorflow::tpu::CompilationCacheEntryRef;
using tensorflow::tpu::TpuCompilationCacheEntry;
using tensorflow::tpu::TpuCompilationCacheLookup;
using GetBufferFunction =
    std::function<xla::StatusOr<std::vector<xla::ExecutionInput>>()>;

// Looks up the input `key` in the compilation cache.
Status GetComputationCacheEntry(
    ResourceMgr* rm, int64 key, int core_index_in_replica,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  profiler::TraceMe trace_me("XRTExecuteOp::LookupProto", /*level=*/2);
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                tpu::kCompiledProtoCacheResourceName,
                                &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key, core_index_in_replica, entry));
  return Status::OK();
}

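// Returns, for each input of `executable_proto`, whether that input's shape
// is dynamic (i.e. not fully static).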
std::vector<bool> GetDynamicInputInfo(
    const TPUExecutableInfoProto& executable_proto) {
  std::vector<bool> input_is_dynamic;
  input_is_dynamic.reserve(executable_proto.input_shapes().size());
  for (int64 i = 0; i < executable_proto.input_shapes().size(); ++i) {
    input_is_dynamic.push_back(
        !xla::Shape(executable_proto.input_shapes(i)).is_static());
  }
  return input_is_dynamic;
}

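// Resolves the inputs of a chained-execute op into XRT tuple allocations,
// creating aliased sub-buffers where an op input refers to a sub-index of a
// previous output, and checks each input shape against the executable proto.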
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputs(
    const xrt::XRTChainedExecuteOp& op,
    absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs,
    const TPUExecutableInfoProto& executable_proto) {
  if (op.inputs_size() != executable_proto.input_shapes_size()) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        op.inputs_size(), " vs. ", executable_proto.input_shapes_size());
  }

  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(op.inputs_size());
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    const RefPtr<XRTTupleAllocation>& tuple = op_inputs[i];
    // Proto3 provides no way to query whether a field was explicitly set, so
    // the default for output_index (zero) means no sub-index. As a
    // consequence, the real sub-index is output_index - 1.
    if (input.output_index() == 0) {
      input_tuples.push_back(tuple);
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), {input.output_index() - 1}, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
    if (!InputShapeMatches(xla::Shape(executable_proto.input_shapes(i)),
                           input_tuples.back()->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          op.computation_handle(), "). Expected ",
          executable_proto.input_shapes(i).DebugString(), "; got ",
          input_tuples.back()->on_host_shape().DebugString());
    }
  }
  return std::move(input_tuples);
}

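// Builds the HLO input/output alias configuration for the program compiled
// for `core_index`, using the device shape of the executable's output.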
xla::StatusOr<xla::HloInputOutputAliasConfig> GetExecutableAliasConfig(
    const tpu::TpuProgramGroup* tpu_program_group, xla::Backend* const backend,
    int core_index) {
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);
  return xla::HloInputOutputAliasConfig::CreateFromProto(
      backend->transfer_manager()->HostShapeToDeviceShape(
          xla::Shape(executable.output_shape())),
      tpu_program_group->hlo_metadata(core_index)
          ->hlo_module()
          .input_output_alias());
}

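// Wraps the result buffers of an execution into an XRTTupleAllocation,
// reading back dynamic output shapes from the device if needed and re-aliasing
// any input buffers that were donated to the computation.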
xla::StatusOr<RefPtr<XRTTupleAllocation>> AllocateOutputTuple(
    tpu::TpuNodeContext* node_context, se::Stream* stream,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
  auto output_shaped_buffer = output_scoped_buffer.release();

  xla::Shape output_device_shape = output_shaped_buffer.on_device_shape();
  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(
        node_context->backend()->transfer_manager()->ReadDynamicShapes(
            stream, &output_shaped_buffer, &output_device_shape));
  }

  XRTTupleAllocation* output_tuple;
  xla::Shape output_host_shape =
      xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape);

  TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
      output_shaped_buffer, output_host_shape, output_device_shape,
      node_context->backend(), device_ordinal, &output_tuple));
  RefPtr<XRTTupleAllocation> output_tuple_ptr(output_tuple);

  // If the input tuples had to release some buffers in order to provide the
  // proper temporary ownership transfer, we patch the holes here by aliasing
  // the buffers from the result tuple. The device addresses we patch back here
  // will essentially be the same ones we carved out in the DoWork() function.
  TF_RETURN_IF_ERROR(
      RebuildOutputAliases(output_tuple_ptr, input_tuples, input_output_alias));

  return std::move(output_tuple_ptr);
}

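// Builds the output tuple for an execution and materializes it as the Op's
// output handle(s), optionally exploding the result tuple into separate
// handles as requested by the execution config.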
Status AllocateOutputTensors(
    OpKernelContext* context, XRTMemoryManager* memory_manager,
    tpu::TpuNodeContext* node_context, se::Stream* stream,
    const xrt::XRTExecutionConfig& config_proto,
    const TPUExecutableInfoProto& executable_proto,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
  TF_ASSIGN_OR_RETURN(
      RefPtr<XRTTupleAllocation> output_tuple,
      AllocateOutputTuple(node_context, stream, input_tuples,
                          input_output_alias, std::move(output_scoped_buffer),
                          device_ordinal));
  return CreateExecuteOutput(context, memory_manager, std::move(output_tuple),
                             config_proto.return_exploded_tuple());
}

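// Runs the compiled TPU program via TPUExecute() and blocks until the
// execution stream has completed.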
xla::StatusOr<xla::ExecutionOutput> RunExecutable(
    OpKernelContext* context, tpu::TpuNodeContext* node_context,
    const TPUExecutableInfoProto& executable,
    std::vector<xla::ExecutionInput> arguments, const string& execution_id,
    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
    xla::Backend* const backend, se::Stream* stream, int core_index,
    int device_ordinal, string rendezvous_key_base) {
  profiler::TraceMe trace_me("RunExecutable", /*level=*/2);

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
                                               executable.device_assignment()));
  }
  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (and why we need a
  // separate transfer stream) is between the executable writing tuple tables
  // and TPUExecute()'s deregister_stream; if they come from the same stream
  // pool, anti-dependencies will occur. The XLA backend draws from a different
  // pool of streams than the stream->GetOrCreateSubStream() that TPUExecute()
  // uses, so the two will never refer to the same stream.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index),
                 std::move(arguments), rendezvous_key_base, rng_seed,
                 node_context, device_assignment.get(),
                 context->cancellation_manager(), context, stream,
                 transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));

  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  return output;
}

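// Materializes the argument buffers and runs the executable under the XRT
// memory manager, so that device memory pressure during the run can be
// handled by freeing XRT-owned buffers and retrying.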
xla::StatusOr<xla::ExecutionOutput> ExecuteTPUProgram(
    OpKernelContext* context, tpu::TpuNodeContext* node_context,
    XRTMemoryManager* memory_manager, const TPUExecutableInfoProto& executable,
    const GetBufferFunction& get_buffers_fn, const string& execution_id,
    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
    xla::Backend* const backend, se::Stream* stream, int core_index,
    int device_ordinal, string rendezvous_key_base) {
  auto runfn = [&]() -> xla::StatusOr<xla::ExecutionOutput> {
    TF_ASSIGN_OR_RETURN(auto arguments, get_buffers_fn());
    return RunExecutable(context, node_context, executable,
                         std::move(arguments), execution_id, rng_seed,
                         tpu_program_group, backend, stream, core_index,
                         device_ordinal, rendezvous_key_base);
  };
  return memory_manager->Run<xla::ExecutionOutput>(
      runfn, backend, device_ordinal, /*requested_free_size=*/0);
}

// XRTExecuteOp

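// Op kernel executing a previously compiled TPU program, identified by the
// compilation handle in its first input, against XRT-allocated input handles.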
class XRTExecuteOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status XRTExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();
  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
                      tpu::TpuNodeContext::Create(device_ordinal));
  xla::Backend* const backend = node_context->backend();
  se::Stream* stream = context->op_device_context()->stream();

  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell());
  profiler::TraceMe trace_me(
      [context] {
        return profiler::TraceMeEncode("TpuExecuteOp",
                                       {{"step_id", context->step_id()}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("XRTExecuteOp::Init", /*level=*/2);

  auto* rm = GetTPUConfigResourceMgr();
  TF_RET_CHECK(rm != nullptr);

  const Tensor& execution_input = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
  int64 compilation_handle = execution_input.scalar<int64>()();

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTExecutionConfig config_proto;
  TF_RET_CHECK(
      config_proto.ParseFromString(execution_config.scalar<tstring>()()));

  int core_index_in_replica = config_proto.core_index_in_replica();
  bool release_inputs = config_proto.release_input_handles();
  bool release_compilation = config_proto.release_compilation_handle();

  string rendezvous_key_base = std::to_string(compilation_handle);
  std::unique_ptr<CompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(GetComputationCacheEntry(rm, compilation_handle,
                                              core_index_in_replica, &entry));

  TpuCompilationCacheEntry centry = entry->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          centry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);

  if (release_compilation) {
    // Process-wide cache of TPU executables.
    tpu::TpuCompilationCacheInterface* cache;
    TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuCompilationCacheInterface>(
        rm->default_container(), tpu::kCompilationCacheResourceName, &cache));
    core::ScopedUnref cache_unref(cache);
    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
    VLOG(2) << "Released compilation handle " << compilation_handle;
  }

  const int core_index = centry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);

  TF_ASSIGN_OR_RETURN(
      xla::HloInputOutputAliasConfig input_output_alias,
      GetExecutableAliasConfig(tpu_program_group, backend, core_index));
  TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
                      GetComputationInputs(context, "input_handles"));

  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  XRTMemoryManager::WorkingSet working_set(memory_manager);
  TF_ASSIGN_OR_RETURN(
      std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
      GetInputTupleAllocations(
          input_coords, &working_set, backend, executable.input_shapes_size(),
          [&](int64 i) { return xla::Shape(executable.input_shapes(i)); },
          release_inputs));
  auto get_buffers_fn = [&]() {
    return GetArgumentsBuffers(input_output_alias, input_tuples,
                               input_is_dynamic, release_inputs);
  };
  trace_me_init.Stop();

  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      ExecuteTPUProgram(
          context, node_context.get(), memory_manager.get(), executable,
          get_buffers_fn, config_proto.execution_instance_key(),
          config_proto.rng_seed(), tpu_program_group, backend, stream,
          core_index, device_ordinal, rendezvous_key_base));

  // AllocateOutputTensors writes the output tuple handle to the output tensor
  // returned by the Op.
  TF_RETURN_IF_ERROR(AllocateOutputTensors(
      context, memory_manager.get(), node_context.get(), stream, config_proto,
      executable, input_tuples, input_output_alias, output.ConsumeResult(),
      device_ordinal));
  return Status::OK();
}

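// XRTExecuteChainedOp

// Op kernel executing a chained plan of compiled TPU programs, where the
// outputs of earlier ops in the plan can feed the inputs of later ones.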
class XRTExecuteChainedOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteChainedOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context,
                                       DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteChainedOp::Compute";
  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();
  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
                      tpu::TpuNodeContext::Create(device_ordinal));
  xla::Backend* const backend = node_context->backend();
  se::Stream* stream = context->op_device_context()->stream();
  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell());
  profiler::TraceMe trace_me(
      [context] {
        return profiler::TraceMeEncode("TpuExecuteChainedOp",
                                       {{"step_id", context->step_id()}});
      },
      /*level=*/2);
  ResourceMgr* rm = GetTPUConfigResourceMgr();
  TF_RET_CHECK(rm != nullptr);

  const Tensor& execution_plan = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape()));
  xrt::XRTChainedExecutePlan plan;
  TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar<tstring>()()));

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTChainedExecuteConfig config;
  TF_RET_CHECK(config.ParseFromString(execution_config.scalar<tstring>()()));

  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                tpu::kCompiledProtoCacheResourceName,
                                &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
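  // Callback invoked by ExecuteChained() for each op in the plan: it looks up
  // the op's compiled program, resolves its inputs (possibly outputs of
  // previously executed ops), runs it, and returns the resulting tuple.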
  auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
                        absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
      -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
    std::unique_ptr<CompilationCacheEntryRef> entry;
    TF_RETURN_IF_ERROR(proto_lookup->Lookup(
        op.computation_handle(), config.core_index_in_replica(), &entry));
    string rendezvous_key_base = std::to_string(op.computation_handle());
    TpuCompilationCacheEntry centry = entry->get();
    const tpu::TpuProgramGroup* tpu_program_group =
        tensorflow::down_cast<const tpu::TpuProgramGroup*>(
            centry.tpu_program_group());
    CHECK_NE(tpu_program_group, nullptr);
    const int core_index = centry.core_index();
    const TPUExecutableInfoProto& executable =
        tpu_program_group->executable_info(core_index);
    std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);

    TF_ASSIGN_OR_RETURN(
        xla::HloInputOutputAliasConfig input_output_alias,
        GetExecutableAliasConfig(tpu_program_group, backend, core_index));
    TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
                        GetChainedOpInputs(op, op_inputs, executable));
    auto get_buffers_fn = [&]() {
      return GetArgumentsBuffers(input_output_alias, input_tuples,
                                 input_is_dynamic,
                                 /*release_inputs=*/false);
    };
    TF_ASSIGN_OR_RETURN(
        xla::ExecutionOutput output,
        ExecuteTPUProgram(context, node_context.get(), memory_manager.get(),
                          executable, get_buffers_fn,
                          config.execution_instance_key(), config.rng_seed(),
                          tpu_program_group, backend, stream, core_index,
                          device_ordinal, rendezvous_key_base));
    return AllocateOutputTuple(node_context.get(), stream, input_tuples,
                               input_output_alias, output.ConsumeResult(),
                               device_ordinal);
  };

  return ExecuteChained(context, memory_manager, backend, device_ordinal, plan,
                        config, execute_op);
}

}  // namespace

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("execution_plan")
                            .HostMemory("execution_config")
                            .HostMemory("output_handle"),
                        XRTExecuteChainedOp);

}  // namespace tensorflow