1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/client/client.h"
22 #include "tensorflow/compiler/xla/client/client_library.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/padding.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/local_service.h"
29 #include "tensorflow/compiler/xla/service/service.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/test_helpers.h"
37 
38 namespace xla {
39 namespace {
40 
41 constexpr int64 kPointerSize = 8;
42 
ShapeSize(const Shape & shape)43 int64 ShapeSize(const Shape& shape) {
44   return ShapeUtil::ByteSizeOf(shape, kPointerSize);
45 }
46 
47 // This test suite tests the HLO cost analysis by first building a computation
48 // using the client computation builder and running the HloCostAnalysis that
49 // returns the number of floating point and transcendental operations in the
50 // graph. We test both individual HLO operations as well as a mixed graph.
class HloCostAnalysisTest : public ::testing::Test {
 protected:
  // The constructor pre-builds several tiny XlaComputations that individual
  // tests hand to higher-order ops (Map, Reduce, ReduceWindow, Scatter,
  // SelectAndScatter). Any failure while building CHECK-fails immediately.
  HloCostAnalysisTest()
      : client_(ClientLibrary::LocalClientOrDie()),
        // Accessing service instance is required for the unit tests to enable
        // whitebox accesses to the user computation built from the client,
        // as shown in the BuildHloGraph functions below.
        service_(static_cast<Service*>(ClientLibrary::GetXlaService(
            static_cast<LocalClient*>(client_)->platform()))) {
    // Create a computation for a unary user function: x => exp(x + 0.5)
    {
      XlaBuilder builder("add_and_exp");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto half = ConstantR0<float>(&builder, 0.5);
      Exp(Add(x, half));
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      add_and_exp_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary user function: (x, y) => x + y
    {
      XlaBuilder builder("add");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
      Add(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      add_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
    {
      XlaBuilder builder("sigmoid");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto one = ConstantR0<float>(&builder, 1.0);
      Div(one, Add(one, Exp(Neg(x))));
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      sigmoid_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary max function: (x, y) => max (x, y)
    {
      XlaBuilder builder("max");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
      Max(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      max_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary GT function: (x, y) => x > y
    {
      XlaBuilder builder("gt");
      auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
      Gt(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      gt_ = computation_status.ConsumeValueOrDie();
    }
  }

  // Build HLO graph from the given builder and return the HLO module.
  // CHECK-fails (fatal to the test) if building the computation, the module
  // config, or the module proto round-trip fails.
  std::unique_ptr<HloModule> BuildHloGraph(XlaBuilder* builder) {
    auto computation_status = builder->Build();
    TF_CHECK_OK(computation_status.status());
    auto computation = computation_status.ConsumeValueOrDie();
    auto config = HloModule::CreateModuleConfigFromProto(computation.proto(),
                                                         DebugOptions())
                      .ConsumeValueOrDie();
    return HloModule::CreateFromProto(computation.proto(), config)
        .ConsumeValueOrDie();
  }

  // Raw pointers obtained from the ClientLibrary singleton; the fixture never
  // deletes them (no destructor is defined).
  Client* client_;
  Service* service_;

  // User computations used for higher order operations (e.g., Map, Reduce).
  XlaComputation add_;
  XlaComputation add_and_exp_;
  XlaComputation sigmoid_;
  XlaComputation max_;
  XlaComputation gt_;
};
138 
TEST_F(HloCostAnalysisTest, MatrixMultiply) {
  // Build a [10,5] x [5,30] dot.
  XlaBuilder b("matrix_multiply");
  auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
  auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
  Dot(lhs, rhs);

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // 10 * 30 output elements, each reducing over 5 elements: 1500 FMAs at
  // two flops per FMA.
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5);

  // No transcendental functions appear in a plain dot.
  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30));
}
160 
TEST_F(HloCostAnalysisTest, DotGeneral) {
  XlaBuilder b("matrix_multiply");
  auto lhs =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs");
  auto rhs =
      Parameter(&b, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs");
  // Contract over both trailing lhs dims and both leading rhs dims.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // 10 * 30 output elements, each reducing over 5 * 5 contracted elements:
  // 7500 FMAs at two flops per FMA.
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5);

  // No transcendental functions appear in a dot.
  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 10 * 30));
}
189 
TEST_F(HloCostAnalysisTest, DotGeneral2) {
  XlaBuilder b("matrix_multiply");
  auto lhs =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs");
  auto rhs =
      Parameter(&b, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs");
  // One contracting pair and one batch pair of dimensions.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(2);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // 5 batches of 10 * 30 output elements, each reducing over 5 elements:
  // 7500 FMAs at two flops per FMA.
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5);

  // No transcendental functions appear in a dot.
  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is sum of inputs and the batched [5,10,30] output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 5 * 10 * 30));
}
218 
TEST_F(HloCostAnalysisTest, DotGeneral3) {
  XlaBuilder b("matrix_multiply");
  auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
  auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
  // Empty dimension numbers: no contracting or batch dims, so this is an
  // outer product with a rank-4 [10,5,5,30] result.
  DotDimensionNumbers dnums;
  DotGeneral(lhs, rhs, dnums);

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // One FMA (two flops) per element of the [10,5,5,30] output.
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5);

  // No transcendental functions appear in a dot.
  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is sum of inputs and output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 + 5 * 30 + 5 * 5 * 10 * 30));
}
241 
TEST_F(HloCostAnalysisTest, Map) {
  // Map x => exp(x + 0.5) over a length-10 vector.
  XlaBuilder b("map");
  auto input = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "in");
  Map(&b, {input}, add_and_exp_, {0});

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // The add contributes 10 flops and the exp 10 transcendental ops; bytes
  // accessed covers the 10-float input and 10-float output.
  EXPECT_EQ(analysis.flop_count(), 10);
  EXPECT_EQ(analysis.transcendental_count(), 10);
  EXPECT_EQ(analysis.bytes_accessed(), 80);
}
258 
TEST_F(HloCostAnalysisTest, Convolution) {
  XlaBuilder b("convolution");
  auto input = Parameter(
      &b, 0,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
                                 /*x_dim=*/20}),
      "input");
  auto kernel = Parameter(
      &b, 1,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
                                 /*x_dim=*/3}),
      "kernel");
  Conv(input, kernel, {1, 1}, Padding::kValid);

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // Valid padding yields a [1,1,8,18] output; each output element takes a
  // 3x3 window of FMAs, and each FMA is two flops.
  EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3);

  // Bytes accessed is sum of the input, the kernel, and the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
}
287 
TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
  // Depthwise-style convolution: 120 feature groups of one channel each.
  XlaBuilder b("convolution");
  auto input = Parameter(
      &b, 0,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10,
                                 /*x_dim=*/20}),
      "input");
  auto kernel = Parameter(
      &b, 1,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3,
                                 /*x_dim=*/3}),
      "kernel");
  Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120);

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // The [1,120,8,18] output takes a 3x3 window of FMAs per element, and
  // each FMA is two flops.
  EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3);

  // Bytes accessed is sum of the input, the kernel, and the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));
}
316 
TEST_F(HloCostAnalysisTest, Reduce) {
  // Sum-reduce a [10,20] array over its second dimension.
  XlaBuilder b("reduce");
  auto input =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  Reduce(input, ConstantR0<float>(&b, 0.0f), add_, {1});

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // Input size minus output size equals the number of additions performed.
  EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10);
}
333 
TEST_F(HloCostAnalysisTest, ReduceWindow) {
  // Reduce a [10,20] array with a non-overlapping 4x5 window.
  XlaBuilder b("reduce_window");
  auto input =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  ReduceWindow(input, ConstantR0<float>(&b, 0), add_, {4, 5}, {4, 5},
               Padding::kValid);

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // Each of the [2,4] output elements reduces a [4,5] window, which takes
  // 4*5 - 1 additions.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1));
}
350 
TEST_F(HloCostAnalysisTest, SelectAndScatter) {
  XlaBuilder b("select_and_scatter");
  auto operand =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  auto source =
      Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
  // Select with gt_ over 4x5 windows, scatter with add_.
  SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, source,
                   ConstantR0<float>(&b, 0), add_);

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // Each of the [2,4] source elements selects within a [4,5] window
  // (4*5 - 1 comparisons) and then applies the scatter computation (+1).
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));
}
370 
TEST_F(HloCostAnalysisTest, Broadcast) {
  XlaBuilder builder("broadcast");
  Broadcast(ConstantR0<float>(&builder, 42), {10, 7});
  auto module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));
  // Broadcast only replicates data; it performs no arithmetic.
  EXPECT_EQ(analysis.flop_count(), 0);
}
380 
381 // Calculates the computation cost of a graph with more than one HLO node.
TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
  XlaBuilder b("fully_connected_forward");
  auto input =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
  auto weight =
      Parameter(&b, 1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
  auto bias = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {20}), "bias");
  // sigmoid(input * weight + bias)
  Map(&b, {Add(Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // The matmul contributes 1000 FMAs (2000 flops), the bias add 200 flops,
  // and the mapped sigmoid 3 flops plus one transcendental (exp) per each
  // of the 200 elements.
  EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
  EXPECT_EQ(analysis.transcendental_count(), 200);
}
403 
TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
  // A 1x1 convolution over 64 input/output features is arithmetically the
  // same work as a 64x64 matmul; both analyses should agree on flops.
  HloCostAnalysis conv_analysis(ShapeSize);
  {
    XlaBuilder b("conv_looking_matmul");
    auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
                         "input");
    auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
                         "weights");
    Conv(lhs, rhs, {1, 1}, Padding::kSame);
    auto module = BuildHloGraph(&b);
    ASSERT_IS_OK(module->entry_computation()->root_instruction()->Accept(
        &conv_analysis));
  }

  HloCostAnalysis matmul_analysis(ShapeSize);
  {
    XlaBuilder b("matmul");
    auto lhs =
        Parameter(&b, 0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
    auto rhs =
        Parameter(&b, 1, ShapeUtil::MakeShape(F32, {64, 64}), "weights");
    Dot(lhs, rhs);
    auto module = BuildHloGraph(&b);
    ASSERT_IS_OK(module->entry_computation()->root_instruction()->Accept(
        &matmul_analysis));
  }

  EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}
433 
// Fusion tests build HLO directly (no client builder), so they use the
// HloTestBase fixture instead of HloCostAnalysisTest.
using FusionCostAnalysis = HloTestBase;

TEST_F(FusionCostAnalysis, LoopFusion) {
  // Do this 4 times with different per-second rates to test the computation of
  // bottleneck time on fusion nodes.
  for (int i = 0; i < 4; ++i) {
    Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});

    // Fuse all instructions in complicated expression:
    //
    //   add = Add(C1, C2)
    //   clamp = Clamp(C2, add, add)
    //   exp = Exp(add)
    //   mul = Mul(exp, C3)
    //   sub = Sub(mul, clamp)
    //   tuple = Tuple({sub, sub, mul, C1})
    HloComputation::Builder builder(TestName());
    auto c1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
            /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
    auto c2 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
            /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
    auto c3 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
            /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
    auto clamp = builder.AddInstruction(
        HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
    auto mul = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
    auto sub = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
    // NOTE(review): this tuple is constructed but never passed to
    // AddInstruction, so it is not part of the computation or the fusion
    // below (sub remains the root). Presumably intentional — confirm.
    auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});

    auto module = CreateNewVerifiedModule();
    auto* computation = module->AddEntryComputation(builder.Build());
    auto* fusion = computation->CreateFusionInstruction(
        {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

    // The time given these rates at i == 0 is exactly even among the properties
    // at 1.0 seconds. For other values, one of the rates is slower so that it
    // becomes the bottleneck.
    HloCostAnalysis fusion_analysis(ShapeSize);
    fusion_analysis.set_flops_per_second(16 * (i == 1 ? 1 / 2.0 : 1.0));
    fusion_analysis.set_transcendentals_per_second(4 *
                                                   (i == 2 ? 1 / 4.0 : 1.0));
    fusion_analysis.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0));
    ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

    // add/clamp/mul/sub over 2x2 elements give 16 flops; exp over 2x2 gives
    // 4 transcendentals.
    EXPECT_EQ(fusion_analysis.flop_count(), 16);
    EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
    constexpr int64 bytes_accessed = sizeof(float) * 4 * 2 * 2;
    static_assert(bytes_accessed == 64, "");
    EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);

    // At i == 0 every property takes 1.0s; for i > 0 the slowed-down rate
    // (by 2, 4, or 8 on top of the i == 0 baseline) is the bottleneck.
    EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
  }
}
496 
TEST_F(FusionCostAnalysis, NoLayout) {
  Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  // Instructions within a fused op may have no layout.
  Shape shape_without_layout = shape_with_layout;
  shape_without_layout.clear_layout();

  HloComputation::Builder builder(TestName());
  auto rank4_const = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
  auto vector_const = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));

  // The broadcast inside the fusion carries the layout-less shape.
  auto bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape_without_layout, vector_const, {1}));
  auto sum = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_with_layout, HloOpcode::kAdd, rank4_const, bcast));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {sum, bcast}, HloInstruction::FusionKind::kLoop);

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  // One add per element of the 2*3*4*5 output; nothing transcendental.
  EXPECT_EQ(fusion_analysis.flop_count(), 120);
  EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
}
525 
TEST_F(HloCostAnalysisTest, TupleCost) {
  HloCostAnalysis analysis(ShapeSize);
  {
    XlaBuilder b("tuple");
    auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {123}), "x");
    auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {42}), "y");
    Tuple(&b, {x, y});
    auto module = BuildHloGraph(&b);

    ASSERT_IS_OK(
        module->entry_computation()->root_instruction()->Accept(&analysis));
  }

  // Forming a tuple does no arithmetic; it only touches one pointer per
  // tuple element.
  EXPECT_EQ(analysis.flop_count(), 0);
  EXPECT_EQ(analysis.transcendental_count(), 0);
  EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}
543 
544 using DomainCostAnalysis = HloTestBase;
TEST_F(DomainCostAnalysis,DomainCost)545 TEST_F(DomainCostAnalysis, DomainCost) {
546   HloCostAnalysis analysis(ShapeSize);
547 
548   HloComputation::Builder builder("domain");
549   auto x = builder.AddInstruction(HloInstruction::CreateParameter(
550       0, ShapeUtil::MakeShape(F32, {123}), "x"));
551   auto y = builder.AddInstruction(
552       HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
553   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
554   auto domain = builder.AddInstruction(
555       HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
556 
557   auto hlo_module = CreateNewVerifiedModule();
558   hlo_module->AddEntryComputation(builder.Build());
559 
560   EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
561   ASSERT_IS_OK(domain->Accept(&analysis));
562 
563   EXPECT_EQ(analysis.flop_count(*domain), 0);
564   EXPECT_EQ(analysis.transcendental_count(*domain), 0);
565   EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
566 }
567 
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
  XlaBuilder b("BaseDilatedConvolution");
  auto input = Parameter(
      &b, 0,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
                                 /*x_dim=*/20}),
      "input");
  auto kernel = Parameter(
      &b, 1,
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
                                 /*x_dim=*/3}),
      "kernel");

  // Convolution with both base (lhs) and window (rhs) dilation plus padding.
  ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1},
                     /*padding=*/{{1, 1}, {1, 1}},
                     /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
                     XlaBuilder::CreateDefaultConvDimensionNumbers(2));

  // Run HLO cost analysis over the built graph.
  auto module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 1472);
}
594 
TEST_F(HloCostAnalysisTest, Slice) {
  // Slice one element out of a length-2 vector.
  XlaBuilder b("slice");
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  Slice(x, {0}, {1}, {1});
  auto module = BuildHloGraph(&b);

  // Run HLO cost analysis over the built graph.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  EXPECT_EQ(analysis.bytes_accessed(), 8);
}
609 
TEST_F(HloCostAnalysisTest, DynamicSlice) {
  // Dynamically slice one element out of a length-2 vector.
  XlaBuilder b("dynamic-slice");
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  auto start_index = ConstantR0<int32>(&b, 1);
  DynamicSlice(x, absl::Span<const XlaOp>({start_index}), {1});
  auto module = BuildHloGraph(&b);

  // Run HLO cost analysis over the built graph.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  EXPECT_EQ(analysis.bytes_accessed(), 8);
}
625 
TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) {
  // Dynamically overwrite one element of a length-2 vector.
  XlaBuilder b("dynamic-update-slice");
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  auto update = ConstantR1<float>(&b, {1.0});
  auto start_index = ConstantR0<int32>(&b, 1);
  DynamicUpdateSlice(x, update, absl::Span<const XlaOp>({start_index}));
  auto module = BuildHloGraph(&b);

  // Run HLO cost analysis over the built graph.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  EXPECT_EQ(analysis.bytes_accessed(), 8);
}
641 
TEST_F(HloCostAnalysisTest, Gather) {
  // Gather two 1x3 row slices from a 3x3 operand.
  XlaBuilder b("gather");
  auto operand =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {3, 3}), "operand");
  auto indices =
      Parameter(&b, 1, ShapeUtil::MakeShape(S32, {2}), "indices");
  GatherDimensionNumbers dnums;
  dnums.add_offset_dims(1);
  dnums.add_collapsed_slice_dims(0);
  dnums.add_start_index_map(0);
  dnums.set_index_vector_dim(1);
  Gather(operand, indices, dnums, {1, 3});

  auto module = BuildHloGraph(&b);

  // Run HLO cost analysis over the built graph.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  EXPECT_EQ(analysis.bytes_accessed(), 56);
}
666 
TEST_F(HloCostAnalysisTest, Scatter) {
  // Scatter two 1x3 rows of updates into a 3x3 operand using add_.
  XlaBuilder b("scatter");
  auto operand =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 3}), "operand");
  auto indices =
      Parameter(&b, 1, ShapeUtil::MakeShape(S32, {2}), "indices");
  auto values =
      Parameter(&b, 2, ShapeUtil::MakeShape(F32, {2, 3}), "values");
  ScatterDimensionNumbers dnums;
  dnums.set_index_vector_dim(1);
  dnums.add_update_window_dims(1);
  dnums.add_inserted_window_dims(0);
  dnums.add_scatter_dims_to_operand_dims(0);
  Scatter(operand, indices, values, add_, dnums);

  auto module = BuildHloGraph(&b);

  // Run HLO cost analysis over the built graph.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      module->entry_computation()->root_instruction()->Accept(&analysis));

  // Two indices, plus the 2x3 update windows read and written.
  EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 2 * (2 * 3)));
}
693 }  // namespace
694 }  // namespace xla
695