/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {
namespace {

using tensorflow::tpu::CompilationCacheEntryRef;
using tensorflow::tpu::TpuCompilationCacheEntry;
using tensorflow::tpu::TpuCompilationCacheLookup;
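// Signature of the callback that materializes the argument buffers for an
// execution; see ExecuteTPUProgram() below.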
using GetBufferFunction =
    std::function<xla::StatusOr<std::vector<xla::ExecutionInput>>()>;

// Looks up the input `key` in the compilation cache.
Status GetComputationCacheEntry(
    ResourceMgr* rm, int64 key, int core_index_in_replica,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  profiler::TraceMe trace_me("XRTExecuteOp::LookupProto", /*level=*/2);
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                tpu::kCompiledProtoCacheResourceName,
                                &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key, core_index_in_replica, entry));
  return Status::OK();
}

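// Returns, for each input of `executable_proto`, whether its compile-time
// shape is dynamic (i.e. not fully static).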
std::vector<bool> GetDynamicInputInfo(
    const TPUExecutableInfoProto& executable_proto) {
  std::vector<bool> input_is_dynamic;
  input_is_dynamic.reserve(executable_proto.input_shapes().size());
  for (int64 i = 0; i < executable_proto.input_shapes().size(); ++i) {
    input_is_dynamic.push_back(
        !xla::Shape(executable_proto.input_shapes(i)).is_static());
  }
  return input_is_dynamic;
}

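// Resolves the input allocations of a chained-execute op, creating aliased
// sub-buffers where an output sub-index is specified, and validates each
// input shape against the executable's expected input shapes.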
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputs(
    const xrt::XRTChainedExecuteOp& op,
    absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs,
    const TPUExecutableInfoProto& executable_proto) {
  if (op.inputs_size() != executable_proto.input_shapes_size()) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        op.inputs_size(), " vs. ", executable_proto.input_shapes_size());
  }

  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(op.inputs_size());
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    const RefPtr<XRTTupleAllocation>& tuple = op_inputs[i];
    // Thanks to the greatness of proto3, there is no way to query for
    // explicitly set fields, so the default for output_index (zero) means no
    // sub-index. As a consequence, the real index is output_index - 1.
    if (input.output_index() == 0) {
      input_tuples.push_back(tuple);
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), {input.output_index() - 1}, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
    if (!InputShapeMatches(xla::Shape(executable_proto.input_shapes(i)),
                           input_tuples.back()->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          op.computation_handle(), "). Expected ",
          executable_proto.input_shapes(i).DebugString(), "; got ",
          tuple->on_host_shape().DebugString());
    }
  }
  return std::move(input_tuples);
}

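// Builds the HLO input/output alias configuration for the executable loaded
// on the given core.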
xla::StatusOr<xla::HloInputOutputAliasConfig> GetExecutableAliasConfig(
    const tpu::TpuProgramGroup* tpu_program_group, xla::Backend* const backend,
    int core_index) {
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);
  return xla::HloInputOutputAliasConfig::CreateFromProto(
      backend->transfer_manager()->HostShapeToDeviceShape(
          xla::Shape(executable.output_shape())),
      tpu_program_group->hlo_metadata(core_index)
          ->hlo_module()
          .input_output_alias());
}

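// Wraps the execution result buffers into an XRTTupleAllocation, reading back
// any dynamic output shapes from the device and re-establishing the buffers
// that were aliased from the inputs.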
xla::StatusOr<RefPtr<XRTTupleAllocation>> AllocateOutputTuple(
    tpu::TpuNodeContext* node_context, se::Stream* stream,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
  auto output_shaped_buffer = output_scoped_buffer.release();

  xla::Shape output_device_shape = output_shaped_buffer.on_device_shape();
  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(
        node_context->backend()->transfer_manager()->ReadDynamicShapes(
            stream, &output_shaped_buffer, &output_device_shape));
  }

  XRTTupleAllocation* output_tuple;
  xla::Shape output_host_shape =
      xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape);

  TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
      output_shaped_buffer, output_host_shape, output_device_shape,
      node_context->backend(), device_ordinal, &output_tuple));
  RefPtr<XRTTupleAllocation> output_tuple_ptr(output_tuple);

  // If the input tuples had to release some buffers in order to provide the
  // proper temporary ownership transfer, we patch the holes here by aliasing
  // the buffers from the result tuple. The device address we patch back here
  // will essentially be the same one we carved out in the DoWork() function.
  TF_RETURN_IF_ERROR(
      RebuildOutputAliases(output_tuple_ptr, input_tuples, input_output_alias));

  return std::move(output_tuple_ptr);
}

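// Allocates the output tuple for an execution and writes its handle (or the
// handles of its exploded elements) to the op's output tensor.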
Status AllocateOutputTensors(
    OpKernelContext* context, XRTMemoryManager* memory_manager,
    tpu::TpuNodeContext* node_context, se::Stream* stream,
    const xrt::XRTExecutionConfig& config_proto,
    const TPUExecutableInfoProto& executable_proto,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
  TF_ASSIGN_OR_RETURN(
      RefPtr<XRTTupleAllocation> output_tuple,
      AllocateOutputTuple(node_context, stream, input_tuples,
                          input_output_alias, std::move(output_scoped_buffer),
                          device_ordinal));
  return CreateExecuteOutput(context, memory_manager, std::move(output_tuple),
                             config_proto.return_exploded_tuple());
}

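// Runs the given TPU executable with the supplied arguments and blocks until
// the execution stream has completed.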
xla::StatusOr<xla::ExecutionOutput> RunExecutable(
    OpKernelContext* context, tpu::TpuNodeContext* node_context,
    const TPUExecutableInfoProto& executable,
    std::vector<xla::ExecutionInput> arguments, const string& execution_id,
    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
    xla::Backend* const backend, se::Stream* stream, int core_index,
    int device_ordinal, string rendezvous_key_base) {
  profiler::TraceMe trace_me("RunExecutable", /*level=*/2);

  // se::StreamExecutor* executor = node->stream_executor();

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
                                               executable.device_assignment()));
  }
  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (why we need a separate
  // transfer stream) is between the executable writing tuple tables and
  // TPUExecute()'s deregister_stream; if they come from the same stream pool
  // antidependencies will occur. XlaBackend has a different pool of streams
  // to the stream->GetOrCreateSubStream() that TPUExecute() uses, so these
  // will never refer to the same stream.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index),
                 std::move(arguments), rendezvous_key_base, rng_seed,
                 node_context, device_assignment.get(),
                 context->cancellation_manager(), context, stream,
                 transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));

  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  return output;
}

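// Executes the TPU program through the XRTMemoryManager, which may free
// device memory and retry the run if buffer allocation fails.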
xla::StatusOr<xla::ExecutionOutput> ExecuteTPUProgram(
    OpKernelContext* context, tpu::TpuNodeContext* node_context,
    XRTMemoryManager* memory_manager, const TPUExecutableInfoProto& executable,
    const GetBufferFunction& get_buffers_fn, const string& execution_id,
    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
    xla::Backend* const backend, se::Stream* stream, int core_index,
    int device_ordinal, string rendezvous_key_base) {
  auto runfn = [&]() -> xla::StatusOr<xla::ExecutionOutput> {
    TF_ASSIGN_OR_RETURN(auto arguments, get_buffers_fn());
    return RunExecutable(context, node_context, executable,
                         std::move(arguments), execution_id, rng_seed,
                         tpu_program_group, backend, stream, core_index,
                         device_ordinal, rendezvous_key_base);
  };
  return memory_manager->Run<xla::ExecutionOutput>(
      runfn, backend, device_ordinal, /*requested_free_size=*/0);
}

// XRTExecuteOp

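// Op kernel that executes a TPU program previously stored in the compilation
// cache, identified by the computation handle passed as the op's first input.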
class XRTExecuteOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status XRTExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();
  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
                      tpu::TpuNodeContext::Create(device_ordinal));
  xla::Backend* const backend = node_context->backend();
  se::Stream* stream = context->op_device_context()->stream();

  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell());
  profiler::TraceMe trace_me(
      [context] {
        return profiler::TraceMeEncode("TpuExecuteOp",
                                       {{"step_id", context->step_id()}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("XRTExecuteOp::Init", /*level=*/2);

  auto* rm = GetTPUConfigResourceMgr();
  TF_RET_CHECK(rm != nullptr);

  const Tensor& execution_input = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
  int64 compilation_handle = execution_input.scalar<int64>()();

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTExecutionConfig config_proto;
  TF_RET_CHECK(
      config_proto.ParseFromString(execution_config.scalar<tstring>()()));

  int core_index_in_replica = config_proto.core_index_in_replica();
  bool release_inputs = config_proto.release_input_handles();
  bool release_compilation = config_proto.release_compilation_handle();

  string rendezvous_key_base = std::to_string(compilation_handle);
  std::unique_ptr<CompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(GetComputationCacheEntry(rm, compilation_handle,
                                              core_index_in_replica, &entry));

  TpuCompilationCacheEntry centry = entry->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          centry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);

  if (release_compilation) {
    // Process-wide cache of Tpu executables.
    tpu::TpuCompilationCacheInterface* cache;
    TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuCompilationCacheInterface>(
        rm->default_container(), tpu::kCompilationCacheResourceName, &cache));
    core::ScopedUnref cache_unref(cache);
    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
    VLOG(2) << "Released compilation handle " << compilation_handle;
  }

  const int core_index = centry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);

  TF_ASSIGN_OR_RETURN(
      xla::HloInputOutputAliasConfig input_output_alias,
      GetExecutableAliasConfig(tpu_program_group, backend, core_index));
  TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
                      GetComputationInputs(context, "input_handles"));

  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  XRTMemoryManager::WorkingSet working_set(memory_manager);
  TF_ASSIGN_OR_RETURN(
      std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
      GetInputTupleAllocations(
          input_coords, &working_set, backend, executable.input_shapes_size(),
          [&](int64 i) { return xla::Shape(executable.input_shapes(i)); },
          release_inputs));
  auto get_buffers_fn = [&]() {
    return GetArgumentsBuffers(input_output_alias, input_tuples,
                               input_is_dynamic, release_inputs);
  };
  trace_me_init.Stop();

  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      ExecuteTPUProgram(
          context, node_context.get(), memory_manager.get(), executable,
          get_buffers_fn, config_proto.execution_instance_key(),
          config_proto.rng_seed(), tpu_program_group, backend, stream,
          core_index, device_ordinal, rendezvous_key_base));

  // AllocateOutputTensors writes the output tuple handle to the output tensor
  // returned by the op.
  TF_RETURN_IF_ERROR(AllocateOutputTensors(
      context, memory_manager.get(), node_context.get(), stream, config_proto,
      executable, input_tuples, input_output_alias, output.ConsumeResult(),
      device_ordinal));
  return Status::OK();
}

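// Op kernel that executes a chained plan of XRT computations, where outputs
// of earlier operations in the plan can feed the inputs of later ones.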
class XRTExecuteChainedOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteChainedOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context,
                                       DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteChainedOp::Compute";
  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();
  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
                      tpu::TpuNodeContext::Create(device_ordinal));
  xla::Backend* const backend = node_context->backend();
  se::Stream* stream = context->op_device_context()->stream();
  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell());
  profiler::TraceMe trace_me(
      [context] {
        return profiler::TraceMeEncode("TpuExecuteChainedOp",
                                       {{"step_id", context->step_id()}});
      },
      /*level=*/2);
  ResourceMgr* rm = GetTPUConfigResourceMgr();
  TF_RET_CHECK(rm != nullptr);

  const Tensor& execution_plan = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape()));
  xrt::XRTChainedExecutePlan plan;
  TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar<tstring>()()));

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTChainedExecuteConfig config;
  TF_RET_CHECK(config.ParseFromString(execution_config.scalar<tstring>()()));

  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                tpu::kCompiledProtoCacheResourceName,
                                &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
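  // Callback invoked by ExecuteChained() for each op in the plan: it looks up
  // the compiled program, gathers the op's inputs, runs the program, and
  // returns the resulting output tuple.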
  auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
                        absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
      -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
    std::unique_ptr<CompilationCacheEntryRef> entry;
    TF_RETURN_IF_ERROR(proto_lookup->Lookup(
        op.computation_handle(), config.core_index_in_replica(), &entry));
    string rendezvous_key_base = std::to_string(op.computation_handle());
    TpuCompilationCacheEntry centry = entry->get();
    const tpu::TpuProgramGroup* tpu_program_group =
        tensorflow::down_cast<const tpu::TpuProgramGroup*>(
            centry.tpu_program_group());
    CHECK_NE(tpu_program_group, nullptr);
    const int core_index = centry.core_index();
    const TPUExecutableInfoProto& executable =
        tpu_program_group->executable_info(core_index);
    std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);

    TF_ASSIGN_OR_RETURN(
        xla::HloInputOutputAliasConfig input_output_alias,
        GetExecutableAliasConfig(tpu_program_group, backend, core_index));
    TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
                        GetChainedOpInputs(op, op_inputs, executable));
    auto get_buffers_fn = [&]() {
      return GetArgumentsBuffers(input_output_alias, input_tuples,
                                 input_is_dynamic,
                                 /*release_inputs=*/false);
    };
    TF_ASSIGN_OR_RETURN(
        xla::ExecutionOutput output,
        ExecuteTPUProgram(context, node_context.get(), memory_manager.get(),
                          executable, get_buffers_fn,
                          config.execution_instance_key(), config.rng_seed(),
                          tpu_program_group, backend, stream, core_index,
                          device_ordinal, rendezvous_key_base));
    return AllocateOutputTuple(node_context.get(), stream, input_tuples,
                               input_output_alias, output.ConsumeResult(),
                               device_ordinal);
  };

  return ExecuteChained(context, memory_manager, backend, device_ordinal, plan,
                        config, execute_op);
}

}  // namespace

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("execution_plan")
                            .HostMemory("execution_config")
                            .HostMemory("output_handle"),
                        XRTExecuteChainedOp);

}  // namespace tensorflow