/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

constexpr int64 kPointerSize = 8;

int64 ShapeSize(const Shape& shape) {
  return ShapeUtil::ByteSizeOf(shape, kPointerSize);
}

// This test suite tests the HLO cost analysis by first building a computation
// using the client computation builder and running the HloCostAnalysis that
// returns the number of floating point and transcendental operations in the
// graph. We test both individual HLO operations as well as a mixed graph.
class HloCostAnalysisTest : public ::testing::Test {
 protected:
  HloCostAnalysisTest()
      : client_(ClientLibrary::LocalClientOrDie()),
        // Accessing the service instance is required for the unit tests to
        // enable white-box access to the user computation built from the
        // client, as shown in the BuildHloGraph functions below.
        service_(static_cast<Service*>(ClientLibrary::GetXlaService(
            static_cast<LocalClient*>(client_)->platform()))) {
    // Create a computation for a unary user function: x => exp(x + 0.5)
    {
      XlaBuilder builder("add_and_exp");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto half = ConstantR0<float>(&builder, 0.5);
      Exp(Add(x, half));
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      add_and_exp_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary user function: (x, y) => x + y
    {
      XlaBuilder builder("add");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
      Add(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      add_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
    {
      XlaBuilder builder("sigmoid");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto one = ConstantR0<float>(&builder, 1.0);
      Div(one, Add(one, Exp(Neg(x))));
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      sigmoid_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary max function: (x, y) => max(x, y)
    {
      XlaBuilder builder("max");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
      Max(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      max_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary GT function: (x, y) => x > y
    {
      XlaBuilder builder("gt");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
      Gt(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      gt_ = computation_status.ConsumeValueOrDie();
    }
  }

  // Build HLO graph from the given builder and return the HLO module.
  std::unique_ptr<HloModule> BuildHloGraph(XlaBuilder* builder) {
    auto computation_status = builder->Build();
    TF_CHECK_OK(computation_status.status());
    auto computation = computation_status.ConsumeValueOrDie();
    auto config = HloModule::CreateModuleConfigFromProto(computation.proto(),
                                                         DebugOptions())
                      .ConsumeValueOrDie();
    return HloModule::CreateFromProto(computation.proto(), config)
        .ConsumeValueOrDie();
  }

  Client* client_;
  Service* service_;

  // User computations used for higher order operations (e.g., Map, Reduce).
  XlaComputation add_;
  XlaComputation add_and_exp_;
  XlaComputation sigmoid_;
  XlaComputation max_;
  XlaComputation gt_;
};

TEST_F(HloCostAnalysisTest, MatrixMultiply) {
  XlaBuilder builder("matrix_multiply");
  auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
  auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
  Dot(lhs, rhs);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Check the number of computations returned from the analysis (1500 FMAs).
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30));
}

TEST_F(HloCostAnalysisTest, DotGeneral) {
  XlaBuilder builder("matrix_multiply");
  auto lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs");
  auto rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs");
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Check the number of computations returned from the analysis (7500 FMAs).
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 10 * 30));
}

TEST_F(HloCostAnalysisTest, DotGeneral2) {
  XlaBuilder builder("matrix_multiply");
  auto lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs");
  auto rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs");
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(2);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Check the number of computations returned from the analysis (7500 FMAs).
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 5 * 10 * 30));
}

TEST_F(HloCostAnalysisTest, DotGeneral3) {
  XlaBuilder builder("matrix_multiply");
  auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
  auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
  DotDimensionNumbers dnums;
  DotGeneral(lhs, rhs, dnums);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Check the number of computations returned from the analysis (7500 FMAs).
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 + 5 * 30 + 5 * 5 * 10 * 30));
}

TEST_F(HloCostAnalysisTest, Map) {
  XlaBuilder builder("map");
  auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in");
  Map(&builder, {input}, add_and_exp_, {0});

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // The add contributes 10 flops and the exp contributes 10 transcendental
  // ops.
  EXPECT_EQ(analysis.flop_count(), 10);
  EXPECT_EQ(analysis.transcendental_count(), 10);
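  // Bytes accessed covers the 10-element input and the 10-element output:
  // 2 * 10 * sizeof(float) = 80 bytes.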
  EXPECT_EQ(analysis.bytes_accessed(), 80);
}

TEST_F(HloCostAnalysisTest, Convolution) {
  XlaBuilder builder("convolution");
  auto input = Parameter(
      &builder, 0,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
                                 /*x_dim=*/20}),
      "input");
  auto kernel = Parameter(
      &builder, 1,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
                                 /*x_dim=*/3}),
      "kernel");
  Conv(input, kernel, {1, 1}, Padding::kValid);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // The output shape is [1x1x8x18]; each output element requires (3x3) FMAs,
  // and one FMA is 2 flops.
  EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
}

TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
  XlaBuilder builder("convolution");
  auto input = Parameter(
      &builder, 0,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10,
                                 /*x_dim=*/20}),
      "input");
  auto kernel = Parameter(
      &builder, 1,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3,
                                 /*x_dim=*/3}),
      "kernel");
  Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // The output shape is [1x120x8x18]; each output element requires (3x3) FMAs,
  // and one FMA is 2 flops.
  EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));
}

TEST_F(HloCostAnalysisTest, Reduce) {
  XlaBuilder builder("reduce");
  auto input =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  Reduce(input, ConstantR0<float>(&builder, 0.0f), add_, {1});

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Subtracting the output size from the input size gives the number of
  // reduction operations performed.
  EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10);
}

TEST_F(HloCostAnalysisTest, ReduceWindow) {
  XlaBuilder builder("reduce_window");
  auto input =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  ReduceWindow(input, ConstantR0<float>(&builder, 0), add_, {4, 5}, {4, 5},
               Padding::kValid);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Each of the [2x4] output elements is generated by reducing [4x5] elements.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1));
}

TEST_F(HloCostAnalysisTest, SelectAndScatter) {
  XlaBuilder builder("select_and_scatter");
  auto operand =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  auto source =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
  SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, source,
                   ConstantR0<float>(&builder, 0), add_);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Each of the [2x4] source elements computes its destination by reducing
  // [4x5] elements and then applying the scatter computation.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));
}

TEST_F(HloCostAnalysisTest, Broadcast) {
  XlaBuilder b("broadcast");
  Broadcast(ConstantR0<float>(&b, 42), {10, 7});
  auto hlo_module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
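  // Broadcasting only replicates data; it performs no floating-point work.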
  EXPECT_EQ(analysis.flop_count(), 0);
}

// Calculates the computation cost of a graph with more than one HLO node.
TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
  XlaBuilder builder("fully_connected_forward");
  auto input =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
  auto weight =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
  auto bias = Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {20}), "bias");
  // sigmoid(input * weight + bias)
  Map(&builder, {Add(Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // 1000 FMAs from matrix multiplication, 200 flops from bias addition,
  // 600 flops from sigmoid, and 200 transcendental ops from sigmoid.
  EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
  EXPECT_EQ(analysis.transcendental_count(), 200);
}

TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
  HloCostAnalysis conv_analysis(ShapeSize);
  {
    XlaBuilder builder("conv_looking_matmul");
    auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
                         "input");
    auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
                         "weights");
    Conv(lhs, rhs, {1, 1}, Padding::kSame);
    auto hlo_module = BuildHloGraph(&builder);
    ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
        &conv_analysis));
  }

  HloCostAnalysis matmul_analysis(ShapeSize);
  {
    XlaBuilder builder("matmul");
    auto lhs =
        Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
    auto rhs =
        Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64}), "weights");
    Dot(lhs, rhs);
    auto hlo_module = BuildHloGraph(&builder);
    ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
        &matmul_analysis));
  }

  EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}

using FusionCostAnalysis = HloTestBase;

TEST_F(FusionCostAnalysis, LoopFusion) {
  // Do this 4 times with different per-second rates to test the computation of
  // bottleneck time on fusion nodes.
  for (int i = 0; i < 4; ++i) {
    Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});

    // Fuse all instructions in complicated expression:
    //
    //   add = Add(C1, C2)
    //   clamp = Clamp(C2, add, add)
    //   exp = Exp(add)
    //   mul = Mul(exp, C3)
    //   sub = Sub(mul, clamp)
    //   tuple = Tuple({sub, sub, mul, C1})
    HloComputation::Builder builder(TestName());
    auto c1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
            /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
    auto c2 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
            /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
    auto c3 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
            /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
    auto clamp = builder.AddInstruction(
        HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
    auto mul = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
    auto sub = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
    auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});

    auto module = CreateNewVerifiedModule();
    auto* computation = module->AddEntryComputation(builder.Build());
    auto* fusion = computation->CreateFusionInstruction(
        {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

    // With these rates, at i == 0 every property takes exactly 1.0 second. For
    // the other values of i, one of the rates is slowed down so that the
    // corresponding property becomes the bottleneck.
    HloCostAnalysis fusion_analysis(ShapeSize);
    fusion_analysis.set_flops_per_second(16 * (i == 1 ? 1 / 2.0 : 1.0));
    fusion_analysis.set_transcendentals_per_second(4 *
                                                   (i == 2 ? 1 / 4.0 : 1.0));
    fusion_analysis.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0));
    ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

    EXPECT_EQ(fusion_analysis.flop_count(), 16);
    EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
    constexpr int64 bytes_accessed = sizeof(float) * 4 * 2 * 2;
    static_assert(bytes_accessed == 64, "");
    EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);

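    // The slowest property determines the optimal time, so the expected value
    // doubles with each i: 1, 2, 4, then 8 seconds.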
    EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
  }
}

TEST_F(FusionCostAnalysis, NoLayout) {
  Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  // Instructions within a fused op may have no layout.
  Shape shape_without_layout = shape_with_layout;
  shape_without_layout.clear_layout();

  HloComputation::Builder builder(TestName());
  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));

  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_with_layout, HloOpcode::kAdd, c1, broadcast));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {add, broadcast}, HloInstruction::FusionKind::kLoop);

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  EXPECT_EQ(fusion_analysis.flop_count(), 120);
  EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
}

TEST_F(HloCostAnalysisTest, TupleCost) {
  HloCostAnalysis analysis(ShapeSize);
  {
    XlaBuilder builder("tuple");
    auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
    auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
    Tuple(&builder, {x, y});
    auto hlo_module = BuildHloGraph(&builder);

    ASSERT_IS_OK(
        hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
  }

  EXPECT_EQ(analysis.flop_count(), 0);
  EXPECT_EQ(analysis.transcendental_count(), 0);
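  // The tuple instruction only writes its top-level pointer table (one pointer
  // per element); the element data itself is not touched.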
  EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}

using DomainCostAnalysis = HloTestBase;
TEST_F(DomainCostAnalysis, DomainCost) {
  HloCostAnalysis analysis(ShapeSize);

  HloComputation::Builder builder("domain");
  auto x = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {123}), "x"));
  auto y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
  auto domain = builder.AddInstruction(
      HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());

  EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
  ASSERT_IS_OK(domain->Accept(&analysis));

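  // A kDomain instruction is purely an annotation, so it is expected to be
  // free in every cost dimension.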
  EXPECT_EQ(analysis.flop_count(*domain), 0);
  EXPECT_EQ(analysis.transcendental_count(*domain), 0);
  EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
}

TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
  XlaBuilder builder("BaseDilatedConvolution");
  auto input = Parameter(
      &builder, 0,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
                                 /*x_dim=*/20}),
      "input");
  auto kernel = Parameter(
      &builder, 1,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
                                 /*x_dim=*/3}),
      "kernel");

  ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1},
                     /*padding=*/{{1, 1}, {1, 1}},
                     /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
                     XlaBuilder::CreateDefaultConvDimensionNumbers(2));

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 1472);
}

TEST_F(HloCostAnalysisTest, Slice) {
  // Test the analysis on a slice.
  XlaBuilder builder("slice");
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  Slice(x, {0}, {1}, {1});
  auto hlo_module = BuildHloGraph(&builder);

  // Run HLO cost analysis.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

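  // The slice produces a single f32 element, so one element is read and one is
  // written: 2 * sizeof(float) = 8 bytes.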
  EXPECT_EQ(analysis.bytes_accessed(), 8);
}

TEST_F(HloCostAnalysisTest, DynamicSlice) {
  // Test the analysis on a dynamic slice.
  XlaBuilder builder("dynamic-slice");
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  DynamicSlice(x, absl::Span<const XlaOp>({ConstantR0<int32>(&builder, 1)}),
               {1});
  auto hlo_module = BuildHloGraph(&builder);

  // Run HLO cost analysis.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

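  // As with Slice, only the single sliced f32 element is counted as read and
  // written: 8 bytes in total.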
  EXPECT_EQ(analysis.bytes_accessed(), 8);
}

TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) {
  // Test the analysis on a dynamic update slice.
  XlaBuilder builder("dynamic-update-slice");
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  DynamicUpdateSlice(x, ConstantR1<float>(&builder, {1.0}),
                     absl::Span<const XlaOp>({ConstantR0<int32>(&builder, 1)}));
  auto hlo_module = BuildHloGraph(&builder);

  // Run HLO cost analysis.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

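  // The single-element update is read and the corresponding operand element is
  // written: 2 * sizeof(float) = 8 bytes.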
  EXPECT_EQ(analysis.bytes_accessed(), 8);
}

TEST_F(HloCostAnalysisTest, Gather) {
  // Test the analysis on a gather.
  XlaBuilder builder("gather");
  Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});

  auto operand = Parameter(&builder, 0, operand_shape, "operand");
  auto indices = Parameter(&builder, 1, indices_shape, "indices");
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_collapsed_slice_dims(0);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  Gather(operand, indices, dim_numbers, {1, 3});

  auto hlo_module = BuildHloGraph(&builder);

  // Run HLO cost analysis.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

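  // The gathered [2x3] result is read from the operand and written to the
  // output (2 * 24 bytes), plus 8 bytes for the two S32 indices: 56 bytes.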
  EXPECT_EQ(analysis.bytes_accessed(), 56);
}

TEST_F(HloCostAnalysisTest, Scatter) {
  // Test the analysis on a scatter.
  XlaBuilder builder("scatter");
  Shape operand_shape = ShapeUtil::MakeShape(F32, {3, 3});
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
  Shape values_shape = ShapeUtil::MakeShape(F32, {2, 3});

  auto operand = Parameter(&builder, 0, operand_shape, "operand");
  auto indices = Parameter(&builder, 1, indices_shape, "indices");
  auto values = Parameter(&builder, 2, values_shape, "values");
  ScatterDimensionNumbers dim_numbers;
  dim_numbers.set_index_vector_dim(1);
  dim_numbers.add_update_window_dims(1);
  dim_numbers.add_inserted_window_dims(0);
  dim_numbers.add_scatter_dims_to_operand_dims(0);
  Scatter(operand, indices, values, add_, dim_numbers);

  auto hlo_module = BuildHloGraph(&builder);

  // Run HLO cost analysis.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

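  // Two S32 indices plus the [2x3] update window, which is both read and
  // written: 4 * (2 + 2 * 6) = 56 bytes.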
  EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 2 * (2 * 3)));
}
}  // namespace
}  // namespace xla