1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
16 
17 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
18 #include "tensorflow/compiler/xla/xla.pb.h"
19 #include "tensorflow/core/lib/gtl/cleanup.h"
20 #include "tensorflow/core/platform/casts.h"
21 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
22 #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
23 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
24 #include "tensorflow/core/tpu/tpu_api.h"
25 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
26 #include "tensorflow/stream_executor/tpu/proto_helper.h"
27 #include "tensorflow/stream_executor/tpu/status_helper.h"
28 
29 namespace tensorflow {
30 namespace tpu {
31 namespace {
32 namespace se_tpu = ::stream_executor::tpu;
33 using stream_executor::port::Status;
34 }  // namespace
35 
ConstructExecutableInfo(const XLA_TpuProgram * xla_tpu_program)36 TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo(
37     const XLA_TpuProgram* xla_tpu_program) {
38   VLOG(1) << "ConstructExecutableInfo";
39   TpuSerializedProto serialized_executable_info = {};
40   StatusHelper status;
41   OpsApiFn()->TpuProgram_GetExecutableInfoFn(
42       xla_tpu_program, &serialized_executable_info, status.c_status);
43   TPUExecutableInfoProto executable_info;
44   if (status.ok()) {
45     executable_info = se_tpu::DeserializeProto<TPUExecutableInfoProto>(
46         serialized_executable_info);
47     StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
48   }
49   return executable_info;
50 }
51 
ConstructHostTransferInfo(const XLA_TpuProgram * xla_tpu_program)52 TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo(
53     const XLA_TpuProgram* xla_tpu_program) {
54   VLOG(1) << "ConstructHostTransferInfo";
55   TpuSerializedProto serialized_host_transfer_info = {};
56   StatusHelper status;
57   OpsApiFn()->TpuProgram_GetHostTransferInfoFn(
58       xla_tpu_program, &serialized_host_transfer_info, status.c_status);
59   TPUHostTransferInfoProto host_transfer_info;
60   if (status.ok()) {
61     host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
62         serialized_host_transfer_info);
63     StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
64   }
65   return host_transfer_info;
66 }
67 
ConstructHloMetadata(const XLA_TpuProgram * xla_tpu_program)68 xla::HloProto TpuProgramGroup::ConstructHloMetadata(
69     const XLA_TpuProgram* xla_tpu_program) {
70   VLOG(1) << "ConstructHloMetadata";
71   TpuSerializedProto serialized_hlo_metadata = {};
72   StatusHelper status;
73   OpsApiFn()->TpuProgram_GetHloMetadataFn(
74       xla_tpu_program, &serialized_hlo_metadata, status.c_status);
75   xla::HloProto hlo_metadata;
76   if (status.ok()) {
77     hlo_metadata =
78         se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
79     StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
80   }
81   return hlo_metadata;
82 }
83 
Initialize(absl::Span<XLA_TpuProgram * const> xla_tpu_programs)84 void TpuProgramGroup::Initialize(
85     absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
86   CHECK_GT(xla_tpu_programs.size(), 0);
87   CHECK_EQ(program_count(), 0) << "Reinitialization of an existing "
88                                   "`TpuProgramGroup` instance is prohibited.";
89   set_tpu_programs(xla_tpu_programs);
90 
91   std::vector<bool> may_modify_variables_array(tpu_programs_.size(), false);
92   std::vector<TPUExecutableInfoProto> executable_infos(tpu_programs_.size());
93   std::vector<TPUHostTransferInfoProto> host_transfer_infos(
94       tpu_programs_.size());
95   std::vector<xla::HloProto> hlo_metadatas(tpu_programs_.size());
96   for (size_t i = 0; i < tpu_programs_.size(); ++i) {
97     const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i];
98     bool may_modify_variables;
99     OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program,
100                                                    &may_modify_variables);
101     may_modify_variables_array[i] = may_modify_variables;
102     executable_infos[i] = ConstructExecutableInfo(xla_tpu_program);
103     host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program);
104     hlo_metadatas[i] = ConstructHloMetadata(xla_tpu_program);
105   }
106 
107   may_modify_variables_ = may_modify_variables_array;
108   executable_infos_ = executable_infos;
109   host_transfer_infos_ = host_transfer_infos;
110   hlo_metadatas_ = hlo_metadatas;
111   RefreshHloMetadatasPtrs();
112 }
113 
has_sharding_program() const114 bool TpuProgramGroup::has_sharding_program() const {
115   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
116     if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
117       return false;
118     }
119   }
120   return true;
121 }
122 
program_count() const123 size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
124 
program_size() const125 int64_t TpuProgramGroup::program_size() const {
126   int64_t total_size = 0;
127   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
128     total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
129   }
130   return total_size;
131 }
132 
LogProgramMemorySummary()133 bool TpuProgramGroup::LogProgramMemorySummary() {
134   bool success = true;
135   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
136     success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
137   }
138   return success;
139 }
140 
UnloadAndDestroyPrograms()141 void TpuProgramGroup::UnloadAndDestroyPrograms() {
142   for (XLA_TpuProgram* tpu_program : tpu_programs_) {
143     StatusHelper status;
144     OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status);
145     auto s = status.status();
146     if (!s.ok()) {
147       LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString();
148     }
149   }
150   tpu_programs_.clear();
151 }
152 
TpuProgramGroup(TpuProgramGroup && other)153 TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
154     : may_modify_variables_(std::move(other.may_modify_variables_)),
155       tpu_programs_(std::move(other.tpu_programs_)),
156       executable_infos_(std::move(other.executable_infos_)),
157       host_transfer_infos_(std::move(other.host_transfer_infos_)),
158       hlo_metadatas_(std::move(other.hlo_metadatas_)) {
159   RefreshHloMetadatasPtrs();
160 }
161 
set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas)162 void TpuProgramGroup::set_hlo_metadatas(
163     absl::Span<const xla::HloProto> hlo_metadatas) {
164   hlo_metadatas_.resize(hlo_metadatas.size());
165   for (size_t i = 0; i < hlo_metadatas.size(); ++i) {
166     hlo_metadatas_[i] = hlo_metadatas[i];
167   }
168   RefreshHloMetadatasPtrs();
169 }
170 
hlo_metadatas() const171 absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
172   return hlo_metadatas_ptrs_;
173 }
174 
hlo_metadata(int index) const175 const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const {
176   CHECK_GE(index, 0);
177   CHECK_LT(index, hlo_metadatas_ptrs_.size());
178   return hlo_metadatas_ptrs_[index];
179 }
180 
RefreshHloMetadatasPtrs()181 void TpuProgramGroup::RefreshHloMetadatasPtrs() {
182   hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
183   for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
184     hlo_metadatas_ptrs_.push_back(&hlo_metadata_internal_);
185   }
186 }
187 
LogCompilationStats(const TpuCompilationCacheKey & key,absl::Duration duration)188 Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
189                                             absl::Duration duration) {
190   // A placeholder for tracking compilation statistics for future work. The
191   // implementation can be pushing into some external storage for analytics.
192   return Status::OK();
193 }
194 
may_modify_variables_list() const195 const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
196   return may_modify_variables_;
197 }
198 
set_may_modify_variables(const std::vector<bool> & may_modify_variables)199 void TpuProgramGroup::set_may_modify_variables(
200     const std::vector<bool>& may_modify_variables) {
201   may_modify_variables_ = may_modify_variables;
202 }
203 
may_modify_variables(int index) const204 bool TpuProgramGroup::may_modify_variables(int index) const {
205   CHECK_GE(index, 0);
206   CHECK_LT(index, tpu_programs_.size());
207   bool may_modify_variables;
208   OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
209                                                  &may_modify_variables);
210   return may_modify_variables;
211 }
212 
tpu_programs() const213 const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
214   return tpu_programs_;
215 }
216 
tpu_program(int index) const217 const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const {
218   CHECK_GE(index, 0);
219   CHECK_LT(index, tpu_programs_.size());
220   return tpu_programs_[index];
221 }
222 
set_tpu_programs(absl::Span<XLA_TpuProgram * const> tpu_programs)223 void TpuProgramGroup::set_tpu_programs(
224     absl::Span<XLA_TpuProgram* const> tpu_programs) {
225   tpu_programs_.resize(tpu_programs.size());
226   for (size_t i = 0; i < tpu_programs.size(); ++i) {
227     tpu_programs_[i] = tpu_programs[i];
228   }
229 }
230 
executable_info(int index) const231 const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
232     int index) const {
233   CHECK_GE(index, 0);
234   CHECK_LT(index, executable_infos_.size());
235   return executable_infos_[index];
236 }
237 
host_transfer_info(int index) const238 const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
239     int index) const {
240   CHECK_GE(index, 0);
241   CHECK_LT(index, host_transfer_infos_.size());
242   return host_transfer_infos_[index];
243 }
244 
245 /*static*/
CompileAndBuild(const TpuCompilationRequestProto & compilation_request,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)246 Status TpuProgramGroup::CompileAndBuild(
247     const TpuCompilationRequestProto& compilation_request,
248     const XLA_TpuMeshState* mesh_state,
249     TpuProgramGroupInterface* tpu_program_group_interface) {
250   se_tpu::SerializedProto serialized_compilation_request =
251       se_tpu::SerializeProto(compilation_request);
252   auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
253     se_tpu::SerializedProto_Free(serialized_compilation_request);
254   });
255   size_t count = 0;
256   XLA_TpuProgram** xla_tpu_programs = nullptr;
257   StatusHelper status;
258   OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
259                                            mesh_state, &xla_tpu_programs,
260                                            &count, status.c_status);
261   if (!status.ok()) {
262     VLOG(1) << "Run CompileAndBuild failed.";
263     return status.status();
264   }
265 
266   // SPMD could return 1 result for all partitions.
267   TF_RET_CHECK(count == 1 ||
268                count == compilation_request.metadata().num_cores_per_replica());
269 
270   VLOG(1) << "Initialize TpuProgramGroup.";
271   TpuProgramGroup* tpu_program_group =
272       tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
273   tpu_program_group->Initialize(
274       absl::MakeConstSpan(&xla_tpu_programs[0], count));
275   OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
276   return status.status();
277 }
278 
279 /*static*/
CompileAndBuild(const xrt::XLAComputation & xrt_computation_proto,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)280 Status TpuProgramGroup::CompileAndBuild(
281     const xrt::XLAComputation& xrt_computation_proto,
282     const XLA_TpuMeshState* mesh_state,
283     TpuProgramGroupInterface* tpu_program_group_interface) {
284   se_tpu::SerializedProto serialized_compilation_request =
285       se_tpu::SerializeProto(xrt_computation_proto);
286   auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
287     se_tpu::SerializedProto_Free(serialized_compilation_request);
288   });
289   size_t count = 0;
290   XLA_TpuProgram** xla_tpu_programs = nullptr;
291   StatusHelper status;
292   OpsApiFn()->TpuCompile_XrtCompileAndBuildFn(serialized_compilation_request,
293                                               mesh_state, &xla_tpu_programs,
294                                               &count, status.c_status);
295   if (!status.ok()) {
296     VLOG(1) << "Run CompileAndBuild failed.";
297     return status.status();
298   }
299 
300   // SPMD could return 1 result for all partitions.
301   int num_cores_per_replica =
302       xrt_computation_proto.config().num_cores_per_replica()
303           ? xrt_computation_proto.config().num_cores_per_replica()
304           : 1;
305   TF_RET_CHECK(count == 1 || count == num_cores_per_replica);
306   VLOG(1) << "Initialize TpuProgramGroup.";
307   TpuProgramGroup* tpu_program_group =
308       tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
309   tpu_program_group->Initialize(
310       absl::MakeConstSpan(&xla_tpu_programs[0], count));
311   OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
312   return status.status();
313 }
314 
tpu_programs(TpuProgramShardingType sharding_type) const315 std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
316     TpuProgramShardingType sharding_type) const {
317   std::vector<XLA_TpuProgram*> tpu_programs;
318   tpu_programs.reserve(tpu_programs_.size());
319   for (size_t i = 0; i < tpu_programs_.size(); ++i) {
320     if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
321       tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn(
322           tpu_programs_[i], sharding_type));
323       CHECK_NE(tpu_programs[i], nullptr);
324     }
325   }
326   return tpu_programs;
327 }
328 
DeserializeFromRpcResponseProtos(const std::vector<TpuSerializedProto> & rpc_response_protos)329 Status TpuProgramGroup::DeserializeFromRpcResponseProtos(
330     const std::vector<TpuSerializedProto>& rpc_response_protos) {
331   std::vector<XLA_TpuProgram*> tpu_programs;
332   tpu_programs.resize(rpc_response_protos.size());
333 
334   for (size_t i = 0; i < rpc_response_protos.size(); ++i) {
335     StatusHelper status;
336     auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn();
337     OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
338         rpc_response_protos[i], xla_tpu_program, status.c_status);
339     if (!status.status().ok()) {
340       OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program);
341       return status.status();
342     }
343     tpu_programs[i] = xla_tpu_program;
344   }
345 
346   Initialize(tpu_programs);
347   return Status::OK();
348 }
349 
SerializeExecutable(int index,TpuExecutableSerializedProto * executable) const350 Status TpuProgramGroup::SerializeExecutable(
351     int index, TpuExecutableSerializedProto* executable) const {
352   CHECK_GE(index, 0);
353   CHECK_LT(index, tpu_programs_.size());
354   StatusHelper status;
355   OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index],
356                                                   executable, status.c_status);
357   return status.status();
358 }
359 
SerializeCompilerMetadata(int index,CompilerMetadataSerializedProto * compiler_metadata) const360 Status TpuProgramGroup::SerializeCompilerMetadata(
361     int index, CompilerMetadataSerializedProto* compiler_metadata) const {
362   CHECK_GE(index, 0);
363   CHECK_LT(index, tpu_programs_.size());
364   StatusHelper status;
365   OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn(
366       tpu_programs_[index], compiler_metadata, status.c_status);
367   return status.status();
368 }
369 }  // namespace tpu
370 }  // namespace tensorflow
371