/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"

namespace ApiConverter {
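// Converts the ServiceExecutableRunOptions for one execution into the
// SE_ExecutableRunOptions struct that crosses the TPU C API boundary:
// allocator, streams, device ordinal, RNG seed, run/launch ids, and, when
// present, a serialized DeviceAssignment proto (freed by the caller after the
// execute call).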
static SE_ExecutableRunOptions ToC(
    const xla::ServiceExecutableRunOptions& options) {
  SE_ExecutableRunOptions se_options;
  se_options.allocator = ApiConverter::ToC(options.run_options().allocator());
  se_options.device_ordinal = options.run_options().device_ordinal();
  if (options.run_options().host_to_device_stream() != nullptr) {
    se_options.host_to_device_stream =
        static_cast<tensorflow::tpu::TpuStream*>(
            options.run_options().host_to_device_stream()->implementation())
            ->se_stream();
  } else {
    se_options.host_to_device_stream = nullptr;
  }

  if (options.run_options().device_assignment() != nullptr) {
    xla::DeviceAssignmentProto dev_assign_proto;
    options.run_options()
        .device_assignment()
        ->Serialize(&dev_assign_proto)
        .IgnoreError();
    se_options.device_assignment =
        stream_executor::tpu::SerializeProto(dev_assign_proto);
  } else {
    se_options.device_assignment.bytes = nullptr;
    se_options.device_assignment.size = 0;
  }

  se_options.rng_seed = options.run_options().rng_seed();
  se_options.run_id = options.run_options().run_id().ToInt();
  se_options.launch_id = options.run_options().launch_id();

  CHECK_EQ(options.run_options().then_execute_function(), nullptr)
      << "ThenExecuteFunction not supported by this platform.";

  auto impl =
      const_cast<stream_executor::Stream*>(options.stream())->implementation();
  se_options.stream =
      static_cast<tensorflow::tpu::TpuStream*>(impl)->se_stream();
  return se_options;
}
}  // namespace ApiConverter

namespace xla {

namespace {

using ::tensorflow::tpu::ExecutorApiFn;

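// Executable backed by the TPU C API. The compiled program lives on the other
// side of the API as an SE_Executable; this wrapper translates arguments and
// results between XLA types and their C counterparts and forwards all calls
// through ExecutorApiFn().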
class TpuExecutable : public TpuExecutableInterface {
 public:
  TpuExecutable(SE_Executable* se_executable,
                std::shared_ptr<HloModule> hlo_module)
      : TpuExecutableInterface(std::move(hlo_module)),
        se_executable_(se_executable) {}

  ~TpuExecutable() override {
    ExecutorApiFn()->TpuExecutable_FreeFn(se_executable_);
  }

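  // Converts the run options and every ExecutionInput into C structs, runs
  // the program via TpuExecutable_ExecuteAsyncOnStreamFn, and rebuilds an
  // ExecutionOutput from the returned SE_ExecutionOutput, freeing the C-side
  // allocations along the way.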
  StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      std::vector<ExecutionInput> arguments,
      HloExecutionProfile* hlo_execution_profile) override {
    SE_ExecutableRunOptions se_run_options = ApiConverter::ToC(*run_options);
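    // Translate each argument's shape tree, buffers, and unowned (aliased)
    // indices into an SE_ExecutionInput owned by this function.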
    SE_ExecutionInput** se_args = new SE_ExecutionInput*[arguments.size()];
    for (int i = 0; i < arguments.size(); ++i) {
      auto& arg = arguments[i];
      se_args[i] = new SE_ExecutionInput;

      ApiConverter::ToC(arg.shape(), &se_args[i]->shape_tree.shape);
      auto* arg_buffers = arg.MutableBuffers();
      absl::InlinedVector<SE_MaybeOwningDeviceMemory, 2> se_buffers;
      for (auto& pair : *arg_buffers) {
        bool aliased = arg.unowned_indices().count(pair.first) > 0;
        se_buffers.push_back(ApiConverter::ToC(pair.second, aliased));
      }
      se_args[i]->shape_tree.buffers =
          new SE_MaybeOwningDeviceMemory[se_buffers.size()];
      for (int j = 0; j < se_buffers.size(); ++j) {
        se_args[i]->shape_tree.buffers[j] = se_buffers[j];
      }

      ApiConverter::ToC(arg.shape(), &se_args[i]->dynamic_shape);
      const auto& unowned_indices = arg.unowned_indices();
      se_args[i]->unowned_indices_size = unowned_indices.size();
      se_args[i]->unowned_indices = new XLA_ShapeIndex[unowned_indices.size()];
      int j = 0;
      for (auto& idx : unowned_indices) {
        se_args[i]->unowned_indices[j] = ApiConverter::ToC(idx);
        ++j;
      }
    }
    SE_ExecutionOutput se_execution_output;
    StatusHelper status;
    ExecutorApiFn()->TpuExecutable_ExecuteAsyncOnStreamFn(
        se_executable_, &se_run_options, se_args, arguments.size(), nullptr,
        &se_execution_output, status.c_status);

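    // The C argument structures were only needed for the call; free them
    // before inspecting the status.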
    if (se_run_options.device_assignment.bytes != nullptr) {
      stream_executor::tpu::SerializedProto_Free(
          se_run_options.device_assignment);
    }
    for (int i = 0; i < arguments.size(); ++i) {
      ApiConverter::Free(&se_args[i]->shape_tree.shape);
      ApiConverter::Free(&se_args[i]->dynamic_shape);
      delete[] se_args[i]->unowned_indices;
      delete[] se_args[i]->shape_tree.buffers;
      delete se_args[i];
    }
    delete[] se_args;

    if (!status.ok()) {
      return status.status();
    }

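    // Take ownership of the result buffers by wrapping them in a
    // ScopedShapedBuffer tied to this executor's allocator.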
    xla::ScopedShapedBuffer result(
        ApiConverter::FromC(&se_execution_output.result),
        run_options->stream()->parent()->GetAllocator());
    ApiConverter::Free(&se_execution_output.result);

    ExecutionOutput output(std::move(result));
    for (int i = 0; i < se_execution_output.aliased_indices_size; ++i) {
      output.AddAliasedIndex(
          ApiConverter::FromC(&se_execution_output.aliased_indices[i]));
    }
    ExecutorApiFn()->TpuExecutable_FreeXlaShapeIndexArrayFn(
        se_execution_output.aliased_indices);

    for (int i = 0; i < se_execution_output.to_be_released_size; ++i) {
      output.AddToBeReleased(
          ApiConverter::FromC(&se_execution_output.to_be_released[i],
                              run_options->stream()->parent()->GetAllocator())
              .Release()
              .value());
    }
    ExecutorApiFn()->TpuExecutable_FreeMaybeOwningDeviceMemoryArrayFn(
        se_execution_output.to_be_released);

    return output;
  }

  absl::string_view fingerprint() const override {
    const char* data;
    size_t size;
    ExecutorApiFn()->TpuExecutable_FingerprintFn(se_executable_, &data, &size);
    return absl::string_view(data, size);
  }

 private:
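  // ExecuteAsyncOnStream is overridden above to go through the C API, so the
  // TpuExecutableInterface hooks below are never reached on this path.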
  Status LoadProgramAndEnqueueToStream(
      const ServiceExecutableRunOptions& run_options,
      absl::Span<const stream_executor::DeviceMemoryBase> arguments,
      stream_executor::DeviceMemoryBase result,
      absl::optional<stream_executor::DeviceMemoryBase>
          cross_program_prefetch_addr) override {
    LOG(FATAL) << "LoadProgramAndEnqueueToStream unimplemented";
  }

  Shape HostShapeToDeviceShape(const Shape& host_shape) override {
    LOG(FATAL) << "HostShapeToDeviceShape unimplemented";
  }

  int64 ShapeSize(const Shape& shape) override {
    LOG(FATAL) << "ShapeSize unimplemented";
  }

  SE_Executable* se_executable_;
};

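// Compiler implementation that delegates RunHloPasses, RunBackend, and
// Compile to the TPU compiler behind the C API. HLO modules and their configs
// cross the boundary as serialized protos and XLA_* structs.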
class TpuCompiler : public Compiler {
 public:
  TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
  ~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }

  stream_executor::Platform::Id PlatformId() const override {
    return tensorflow::tpu::GetTpuPlatformId();
  }

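  // Serializes the module and its config, runs the TPU backend's HLO passes
  // via TpuCompiler_RunHloPassesFn, and deserializes the optimized module.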
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module,
      stream_executor::StreamExecutor* executor,
      const CompileOptions& options) override {
    XLA_HloModule hlo_module;
    auto cleanup = xla::MakeCleanup([&hlo_module]() {
      stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
      ApiConverter::Free(&hlo_module.module_config);
    });
    hlo_module.module_config = ApiConverter::ToC(module->config());
    hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
    auto allocator = ApiConverter::ToC(options.device_allocator);
    XLA_HloModule result;
    StatusHelper status;
    ExecutorApiFn()->TpuCompiler_RunHloPassesFn(
        compiler_, &hlo_module,
        static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
            ->se_executor(),
        &allocator, &result, status.c_status);
    if (!status.ok()) {
      return status.status();
    }
    HloModuleProto result_proto =
        stream_executor::tpu::DeserializeProto<HloModuleProto>(result.proto);
    stream_executor::tpu::SerializedProto_Free(result.proto);
    return HloModule::CreateFromProto(result_proto, module->config());
  }

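  // Compiles a single optimized module into an SE_Executable and wraps it in
  // a TpuExecutable.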
  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module,
      stream_executor::StreamExecutor* executor,
      const CompileOptions& options) override {
    XLA_HloModule hlo_module;
    auto cleanup = xla::MakeCleanup([&hlo_module]() {
      stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
      ApiConverter::Free(&hlo_module.module_config);
    });
    SE_Executable* result;
    hlo_module.module_config = ApiConverter::ToC(module->config());
    hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
    auto allocator = ApiConverter::ToC(options.device_allocator);

    StatusHelper status;
    ExecutorApiFn()->TpuCompiler_RunBackendFn(
        compiler_, &hlo_module,
        static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
            ->se_executor(),
        &allocator, &result, status.c_status);
    if (!status.ok()) {
      return status.status();
    }

    std::unique_ptr<Executable> exec =
        absl::make_unique<TpuExecutable>(result, std::move(module));
    return exec;
  }

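  // Compiles a whole module group. The stream executor lists are converted to
  // SE_StreamExecutorList structs, and the returned executables are built
  // from the HLO modules reported back by the compiled SE_Executables (see
  // the comment inside the loop below).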
  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<stream_executor::StreamExecutor*>> stream_exec,
      const CompileOptions& options) override {
    XLA_HloModuleGroup se_module_group;
    se_module_group.proto =
        stream_executor::tpu::SerializeProto(module_group->ToProto());
    se_module_group.module_config =
        new XLA_HloModuleConfig[module_group->size()];
    int module_group_size = module_group->size();
    auto cleanup_config =
        xla::MakeCleanup([&se_module_group, module_group_size]() {
          for (auto i = 0; i < module_group_size; ++i) {
            ApiConverter::Free(&se_module_group.module_config[i]);
          }
          delete[] se_module_group.module_config;
        });
    for (int i = 0; i < module_group->size(); ++i) {
      const auto& config = module_group->module(i).config();
      se_module_group.module_config[i] = ApiConverter::ToC(config);
    }
    std::vector<SE_StreamExecutorList> se_lists(stream_exec.size());
    std::vector<std::vector<SE_StreamExecutor*>> se_lists_storage;
    for (int i = 0; i < stream_exec.size(); ++i) {
      se_lists[i].count = stream_exec[i].size();
      se_lists_storage.emplace_back(stream_exec[i].size());
      se_lists[i].exec = se_lists_storage.back().data();
      for (int j = 0; j < stream_exec[i].size(); ++j) {
        se_lists[i].exec[j] = static_cast<tensorflow::tpu::TpuExecutor*>(
                                  stream_exec[i][j]->implementation())
                                  ->se_executor();
      }
    }

    SE_DeviceMemoryAllocator allocator =
        ApiConverter::ToC(options.device_allocator);

    SE_Executable** se_executables = new SE_Executable*[module_group->size()];

    StatusHelper status;

    ExecutorApiFn()->TpuCompiler_CompileFn(
        compiler_, &se_module_group, se_lists.data(), stream_exec.size(),
        &allocator, se_executables, status.c_status);

    if (!status.ok()) {
      return status.status();
    }

    std::vector<std::unique_ptr<Executable>> executables;
    for (int i = 0; i < module_group->size(); ++i) {
      // We get the HloModule from the compiled executable, rather than reusing
      // the input module from 'module_group', in case the module changed in
      // some way. For example, if the computation is automatically partitioned
      // via XLA, the executable's module may have different input/output shapes
      // than the input module.
      XLA_HloModule c_module =
          ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executables[i]);
      auto cleanup_c_module =
          xla::MakeCleanup([&c_module]() { ApiConverter::Free(&c_module); });
      TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                          ApiConverter::FromC(c_module));
      std::shared_ptr<HloModule> module_shared(module.release());
      executables.emplace_back(absl::make_unique<TpuExecutable>(
          se_executables[i], std::move(module_shared)));
    }

    stream_executor::tpu::SerializedProto_Free(se_module_group.proto);
    delete[] se_executables;

    return executables;
  }

  // Compiles the HLO module group for ahead-of-time execution.  This is
  // intended for use in static compilation.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options) override {
    return Unimplemented("This compiler does not support CompileAheadOfTime.");
  }

  // Returns a function that computes the size in bytes of the logical
  // buffer that contains a shape.
  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
    return [this](const xla::Shape& shape) {
      XLA_Shape c_shape;
      ApiConverter::ToC(shape, &c_shape);
      int64 bytes =
          ExecutorApiFn()->TpuCompiler_ShapeSizeFn(compiler_, &c_shape);
      ApiConverter::Free(&c_shape);
      return bytes;
    };
  }

 private:
  Tpu_Compiler* compiler_;
};

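// Registers the TPU compiler factory for the TPU platform id at static
// initialization time. After registration, callers elsewhere in XLA can look
// the compiler up by platform, roughly (sketch, assuming a TPU se::Platform*
// named `platform`):
//   TF_ASSIGN_OR_RETURN(xla::Compiler* const compiler,
//                       xla::Compiler::GetForPlatform(platform));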
static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(
      tensorflow::tpu::GetTpuPlatformId(),
      []() { return absl::make_unique<TpuCompiler>(); });
  return true;
}

static bool module_initialized = InitModule();

}  // namespace
}  // namespace xla