1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <vector>
18
19 #include "absl/algorithm/container.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/platform_util.h"
30 #include "tensorflow/compiler/xla/service/stream_pool.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/regexp.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/platform/types.h"
39
40 namespace xla {
41 namespace {
42
// Test fixture for the HLO profiling tests below. It adds no members or
// overrides of its own; everything comes from ClientLibraryTestBase.
class HloProfileTest : public ClientLibraryTestBase {};
44
45 struct ParsedProfileOutputLine {
46 int64 cycles;
47 string cycles_percentage;
48 double usec;
49 string flops;
50 string trops;
51 string bytes_per_sec;
52 string bytes_per_cycle;
53 string opcode;
54 };
55
HasFlops(const ParsedProfileOutputLine & parsed_line)56 ::testing::AssertionResult HasFlops(
57 const ParsedProfileOutputLine& parsed_line) {
58 if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) {
59 return ::testing::AssertionSuccess()
60 << "'flops' field present in " << parsed_line.opcode << ": '"
61 << parsed_line.flops << "'";
62 }
63
64 return ::testing::AssertionFailure()
65 << "'flops' field absent in " << parsed_line.opcode << ": '"
66 << parsed_line.flops << "'";
67 }
68
HasTrops(const ParsedProfileOutputLine & parsed_line)69 ::testing::AssertionResult HasTrops(
70 const ParsedProfileOutputLine& parsed_line) {
71 if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) {
72 return ::testing::AssertionSuccess()
73 << "'trops' field present in " << parsed_line.opcode << ": '"
74 << parsed_line.trops << "'";
75 }
76
77 return ::testing::AssertionFailure()
78 << "'trops' field absent in " << parsed_line.opcode << ": '"
79 << parsed_line.trops << "'";
80 }
81
ParseOneProfileOutputLine(const string & line,bool expect_hlo,absl::flat_hash_map<string,ParsedProfileOutputLine> * parsed_results,absl::Span<const absl::string_view> opcodes_to_ignore={})82 Status ParseOneProfileOutputLine(
83 const string& line, bool expect_hlo,
84 absl::flat_hash_map<string, ParsedProfileOutputLine>* parsed_results,
85 absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
86 string separator = "[^:]*:: +";
87 string match_percentage = R"(\d+\.\d*% +\d+Σ)";
88 string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))";
89 string match_usecs = "([0-9.]+) usec";
90 string match_flops = "([^ ]*)";
91 string match_trops = "([^ ]*)";
92 string match_bytes_per_sec = "([0-9.TGMKi]*)(?:B/s)?";
93 string match_bytes_per_cycle = "([0-9.TGMKi]*)(?:B/cycle)?";
94
95 // The underlined part is what we're trying to match with match_opcode:
96 //
97 // %dot33 = f32[256,256]{1,0} dot(...)
98 // ^^^
99
100 string match_opcode = expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*"
101 : "(\\[total\\])( \\[entry\\])?";
102 string regexp_pattern = absl::StrCat(
103 " +", match_cycles, separator, match_usecs, separator, match_flops,
104 separator, match_trops, separator, match_bytes_per_sec, separator,
105 match_bytes_per_cycle, separator, match_opcode);
106
107 ParsedProfileOutputLine parsed_line;
108 bool matched = RE2::FullMatch(
109 line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
110 &parsed_line.usec, &parsed_line.flops, &parsed_line.trops,
111 &parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle,
112 &parsed_line.opcode);
113 if (!matched) {
114 return tensorflow::errors::InvalidArgument(
115 "Input did not match regexp. Input: ", line,
116 ", Regexp: ", regexp_pattern);
117 }
118
119 if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
120 InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
121 }
122
123 return Status::OK();
124 }
125
IsExtraMetricProfileOutputLine(const string & line)126 bool IsExtraMetricProfileOutputLine(const string& line) {
127 return RE2::FullMatch(line, "Extra metric \\S+: \\d+");
128 }
129
130 // Returns void so that we can ASSERT.
ExecuteAndFetchProfile(string * profile_output,LocalClient * client,const XlaComputation & computation,const Shape & lhs_arg_shape,const Shape & rhs_arg_shape)131 void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
132 const XlaComputation& computation,
133 const Shape& lhs_arg_shape,
134 const Shape& rhs_arg_shape) {
135 LocalService* service = ClientLibrary::GetXlaService(client->platform());
136 Backend* backend = service->mutable_backend();
137 se::StreamExecutor* executor = backend->default_stream_executor();
138 se::DeviceMemoryAllocator* allocator = backend->memory_allocator();
139 auto* transfer_manager = backend->transfer_manager();
140 TF_ASSERT_OK_AND_ASSIGN(
141 StreamPool::Ptr stream_ptr,
142 backend->BorrowStream(backend->default_device_ordinal()));
143
144 TF_ASSERT_OK_AND_ASSIGN(
145 ScopedShapedBuffer lhs_arg,
146 transfer_manager->AllocateScopedShapedBuffer(
147 lhs_arg_shape, allocator, backend->default_device_ordinal()));
148 TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
149 stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
150
151 TF_ASSERT_OK_AND_ASSIGN(
152 ScopedShapedBuffer rhs_arg,
153 transfer_manager->AllocateScopedShapedBuffer(
154 rhs_arg_shape, allocator, backend->default_device_ordinal()));
155 TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
156 stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
157
158 ExecutableBuildOptions build_options;
159 build_options.mutable_debug_options()->set_xla_hlo_profile(true);
160 TF_ASSERT_OK_AND_ASSIGN(
161 auto local_executables,
162 client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape},
163 build_options));
164
165 Executable* executable = local_executables[0]->executable();
166 HloExecutionProfile hlo_execution_profile(
167 &executable->hlo_profile_printer_data(),
168 &executable->hlo_profile_index_map());
169
170 ExecutableRunOptions exec_run_options;
171 exec_run_options.set_stream(stream_ptr.get());
172 exec_run_options.set_allocator(backend->memory_allocator());
173 exec_run_options.set_intra_op_thread_pool(
174 backend->eigen_intra_op_thread_pool_device());
175 ServiceExecutableRunOptions run_options(exec_run_options,
176 /*borrow_stream=*/nullptr);
177 std::vector<const ShapedBuffer*> args = {&lhs_arg, &rhs_arg};
178 TF_ASSERT_OK_AND_ASSIGN(
179 auto execution_result,
180 executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile));
181 TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
182 (void)execution_result;
183
184 *profile_output = hlo_execution_profile.ToString(
185 executor->GetDeviceDescription().clock_rate_ghz());
186
187 XLA_VLOG_LINES(4, *profile_output);
188 }
189
// Profiles tanh(lhs + rhs) and checks the first three rows of the entry
// computation's profile: the "[total]" summary plus one row each for `add`
// and `tanh`.
XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
  const int64 m = 256, k = 256, n = 256;
  // The operands are added elementwise, so both shapes must agree; m, k and n
  // are all equal, which keeps {m, k} and {k, n} identical.
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});

  TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
                          PlatformUtil::GetDefaultPlatform());
  TF_ASSERT_OK_AND_ASSIGN(LocalClient * client,
                          ClientLibrary::GetOrCreateLocalClient(platform));

  XlaBuilder builder(TestName());
  Tanh(Add(Parameter(&builder, 0, lhs_shape, "dot_lhs"),
           Parameter(&builder, 1, rhs_shape, "dot_rhs")));

  TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());

  string profile_output;
  // The helper reports errors through ASSERT_*; propagate a fatal failure
  // instead of parsing a profile that was never produced.
  ASSERT_NO_FATAL_FAILURE(ExecuteAndFetchProfile(
      &profile_output, client, computation, lhs_shape, rhs_shape));
  VLOG(4) << "Profile Output:\n" << profile_output;
  std::vector<string> profile_output_lines =
      absl::StrSplit(profile_output, '\n');

  absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;

  size_t line_no = 0;

  // Skip extra metrics, without running past the end of the output.
  while (line_no < profile_output_lines.size() &&
         IsExtraMetricProfileOutputLine(profile_output_lines[line_no])) {
    ++line_no;
  }

  ++line_no;  // Skip 'Execution profile for ....'

  // Summary row.
  ASSERT_LT(line_no, profile_output_lines.size());
  TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
                                         /*expect_hlo=*/false,
                                         &parsed_profile_lines));

  // Two HLO rows follow the summary.
  ASSERT_LT(line_no, profile_output_lines.size());
  TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
                                         /*expect_hlo=*/true,
                                         &parsed_profile_lines));

  ASSERT_LT(line_no, profile_output_lines.size());
  TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++],
                                         /*expect_hlo=*/true,
                                         &parsed_profile_lines));

  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile,
                          MaybeFind(parsed_profile_lines, "[total]"));
  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine add_profile,
                          MaybeFind(parsed_profile_lines, "add"));
  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile,
                          MaybeFind(parsed_profile_lines, "tanh"));

  EXPECT_GT(total_profile.cycles, 0);
  EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ");

  EXPECT_TRUE(HasFlops(total_profile));
  EXPECT_TRUE(HasTrops(total_profile));

  EXPECT_GT(total_profile.cycles, add_profile.cycles);
  EXPECT_NE(add_profile.cycles_percentage, "0.00%");
  EXPECT_NE(add_profile.cycles_percentage, "100.00%");

  // `add` is expected to report flops but no transcendental ops.
  EXPECT_TRUE(HasFlops(add_profile));
  EXPECT_FALSE(HasTrops(add_profile));

  EXPECT_GT(total_profile.cycles, tanh_profile.cycles);
  EXPECT_NE(tanh_profile.cycles_percentage, "0.00%");
  EXPECT_NE(tanh_profile.cycles_percentage, "100.00%");

  // `tanh` is expected to report transcendental ops but no flops.
  EXPECT_FALSE(HasFlops(tanh_profile));
  EXPECT_TRUE(HasTrops(tanh_profile));
}
267
// Profiles a computation containing a while loop whose body squares a matrix
// elementwise, then checks the section of the profile that belongs to the
// while body: its "[total]" row and its `multiply` row.
XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
  const int64 size = 256;
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size});
  // Loop state: (iteration counter, matrix).
  Shape while_result_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), matrix_shape});

  TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
                          PlatformUtil::GetDefaultPlatform());
  TF_ASSERT_OK_AND_ASSIGN(LocalClient * client,
                          ClientLibrary::GetOrCreateLocalClient(platform));

  // Loop condition: keep iterating while 5 > iteration.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto state = Parameter(&builder, 0, while_result_shape, "state");
    auto iteration = GetTupleElement(state, 0);
    Gt(ConstantR0<int32>(&builder, 5), iteration);
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Loop body: (iteration + 1, matrix * matrix).
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto state = Parameter(&builder, 0, while_result_shape, "state");
    auto matrix = GetTupleElement(state, 1);
    auto next_iteration =
        Add(GetTupleElement(state, 0), ConstantR0<int32>(&builder, 1));
    Tuple(&builder, {next_iteration, Mul(matrix, matrix)});
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  XlaBuilder builder(TestName());
  auto initial_while_state =
      Tuple(&builder, {ConstantR0<int32>(&builder, 0),
                       Parameter(&builder, 0, matrix_shape, "initial_value")});
  auto while_result = While(condition, body, initial_while_state);
  Add(GetTupleElement(while_result, 1),
      Parameter(&builder, 1, matrix_shape, "other_value"));

  TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());

  string profile_output;
  // The helper reports errors through ASSERT_*; propagate a fatal failure
  // instead of parsing a profile that was never produced.
  ASSERT_NO_FATAL_FAILURE(ExecuteAndFetchProfile(
      &profile_output, client, computation, matrix_shape, matrix_shape));
  SCOPED_TRACE(profile_output);

  std::vector<string> profile_output_lines =
      absl::StrSplit(profile_output, '\n');

  auto while_body_profile_start =
      absl::c_find_if(profile_output_lines, [](absl::string_view s) {
        return absl::StartsWith(s, "Execution profile for body");
      });

  ASSERT_NE(while_body_profile_start, profile_output_lines.end());

  auto while_body_profile_end =
      std::find_if(while_body_profile_start, profile_output_lines.end(),
                   [](absl::string_view s) {
                     return absl::StartsWith(s, "********** microseconds ");
                   });

  // Check for "not found" before adjusting the iterator: once decremented, an
  // end() result would point at the last line and no longer compare equal to
  // end().
  ASSERT_NE(while_body_profile_end, profile_output_lines.end());

  // We emit a blank line before the "microseconds report" line.
  --while_body_profile_end;

  absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;

  // The first row after the section header is the "[total]" summary; every
  // row after that describes an HLO instruction.
  const auto first_row = while_body_profile_start + 1;
  for (auto row = first_row; row != while_body_profile_end; ++row) {
    // There are multiple "get-tuple-element" instructions in the while body so
    // we ignore them -- we don't want parsed_profile_lines to be a multi-map.
    TF_ASSERT_OK(ParseOneProfileOutputLine(
        *row,
        /*expect_hlo=*/row != first_row, &parsed_profile_lines,
        {"get-tuple-element"}));
  }

  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
                          MaybeFind(parsed_profile_lines, "[total]"));
  TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile,
                          MaybeFind(parsed_profile_lines, "multiply"));

  EXPECT_GT(total_while_body_profile.cycles, 0);
  EXPECT_EQ(total_while_body_profile.opcode, "[total]");
  EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ");

  EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles);
  EXPECT_NE(multiply_profile.cycles_percentage, "0.00%");
  EXPECT_NE(multiply_profile.cycles_percentage, "100.00%");
}
360 } // namespace
361 } // namespace xla
362
// Returns a copy of the command line with two flags appended:
// --xla_hlo_profile and an --xla_disable_hlo_passes list.
//
//   argc, argv  the original command line; the first `argc` pointers are
//               copied as-is (the strings themselves are not duplicated).
//
// Returns {new argc, new argv}. The new array is heap-allocated and never
// freed, and, like the argv handed to main, it ends with a null pointer at
// index `new argc`.
static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) {
  constexpr int kNumAddedFlags = 2;

  // Intentional "leak". The extra slot holds the terminating null pointer so
  // the argv[argc] == nullptr convention still holds for the new array, and a
  // parser that rewrites argv in place has room to write its own terminator.
  char** new_argv = new char*[argc + kNumAddedFlags + 1];
  for (int i = 0; i < argc; ++i) {
    new_argv[i] = argv[i];
  }

  // We do it this way (as opposed to piping in a modified DebugOptions
  // instance) for better end-to-end integration testing.
  new_argv[argc] = strdup("--xla_hlo_profile");

  // Fusion can change the Hlo instructions that show up in the final Hlo
  // executable, so block it here. Also block the WhileLoopInvariantCodeMotion
  // pass, otherwise a while loop is transformed and we could not match the
  // original name in the ProfileWhileComputation test.
  new_argv[argc + 1] = strdup(
      "--xla_disable_hlo_passes=fusion,fusion_merger,multi_output_fusion,"
      "while-loop-invariant-code-motion");

  new_argv[argc + kNumAddedFlags] = nullptr;
  return {argc + kNumAddedFlags, new_argv};
}
383
main(int argc,char ** argv)384 GTEST_API_ int main(int argc, char** argv) {
385 std::vector<tensorflow::Flag> flag_list;
386 xla::AppendDebugOptionsFlags(&flag_list);
387 std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv);
388
389 auto usage = tensorflow::Flags::Usage(argv[0], flag_list);
390 if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
391 LOG(ERROR) << "\n" << usage;
392 return 2;
393 }
394
395 testing::InitGoogleTest(&argc, argv);
396 if (argc > 1) {
397 LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
398 return 2;
399 }
400 return RUN_ALL_TESTS();
401 }
402