1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_module.h"
23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/tests/test_utils.h"
27 #include "tensorflow/compiler/xla/xla.pb.h"
28 
29 namespace xla {
30 namespace {
31 
32 using absl::StrCat;
33 using ::testing::HasSubstr;
34 
35 using HloGraphDumperTest = HloTestBase;
36 
TestName()37 string TestName() {
38   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
39 }
40 
TEST_F(HloGraphDumperTest,NestedFusion)41 TEST_F(HloGraphDumperTest, NestedFusion) {
42   HloComputation::Builder b("b");
43 
44   // Build param0 + param1 + param2 + param3 + param4.
45   auto shape = ShapeUtil::MakeShape(F32, {10, 100});
46   std::vector<HloInstruction*> params;
47   for (int i = 0; i <= 4; ++i) {
48     params.push_back(b.AddInstruction(
49         HloInstruction::CreateParameter(i, shape, StrCat("param", i))));
50   }
51   std::vector<HloInstruction*> sums;
52   sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
53       shape, HloOpcode::kAdd, params[0], params[1])));
54   for (int i = 0; i <= 2; ++i) {
55     sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
56         shape, HloOpcode::kAdd, sums[i], params[i + 2])));
57   }
58   HloModuleConfig config;
59   HloModule m(TestName(), config);
60   m.AddEntryComputation(b.Build());
61   HloComputation* root_computation = m.entry_computation();
62 
63   // Fuse into fusion(param0 + param1 + param2 + param3 + param4).
64   auto* outer_fusion = root_computation->CreateFusionInstruction(
65       {sums[3], sums[2], sums[1], sums[0]}, HloInstruction::FusionKind::kLoop);
66 
67   // Fusing invalidates the pointers in sums -- the instructions are cloned when
68   // they're moved to the new computation.  Get the updated pointers to sums.
69   std::vector<HloInstruction*> fused_sums;
70   for (auto* instr : outer_fusion->fused_instructions_computation()
71                          ->MakeInstructionPostOrder()) {
72     if (instr->opcode() == HloOpcode::kAdd) {
73       fused_sums.push_back(instr);
74     }
75   }
76 
77   // Fuse into fusion(fusion(param0 + param1 + param2) + param3 + param4).
78   auto* inner_fusion =
79       outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
80           {fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop);
81 
82   // Generate the graph; all nodes should be present.
83   TF_ASSERT_OK_AND_ASSIGN(
84       string graph, RenderGraph(*root_computation, /*label=*/"", DebugOptions(),
85                                 RenderedGraphFormat::kDot));
86   for (const HloComputation* computation :
87        {root_computation,  //
88         inner_fusion->fused_instructions_computation(),
89         outer_fusion->fused_instructions_computation()}) {
90     for (const HloInstruction* instruction : computation->instructions()) {
91       EXPECT_THAT(graph, HasSubstr(instruction->name()));
92     }
93   }
94 
95   // Dump a neighborhood around one of the inner sum nodes.  We don't really
96   // care that the outer nodes are omitted -- whether they are or not is based
97   // fiddly heuristics -- but we do care that the node we asked for is printed.
98   const HloInstruction* inner_sum = nullptr;
99   for (const HloInstruction* instruction :
100        inner_fusion->fused_instructions_computation()->instructions()) {
101     if (instruction->opcode() == HloOpcode::kAdd) {
102       inner_sum = instruction;
103       break;
104     }
105   }
106   ASSERT_NE(inner_sum, nullptr);
107   TF_ASSERT_OK_AND_ASSIGN(string neighborhood_graph,
108                           RenderNeighborhoodAround(*inner_sum, /*radius=*/1,
109                                                    RenderedGraphFormat::kDot));
110   EXPECT_THAT(neighborhood_graph, HasSubstr(inner_sum->name()));
111 }
112 
TEST_F(HloGraphDumperTest, Constant) {
  // A computation whose root is a constant should render the graph label but
  // omit the constant node itself.
  HloComputation::Builder builder("b");
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-42)));
  constant->SetAndSanitizeName("i_am_a_constant_root_instruction");
  HloModuleConfig config;
  HloModule module(TestName(), config);
  HloComputation* entry = module.AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(
      string graph, RenderGraph(*entry, /*label=*/"an_empty_graph",
                                DebugOptions(), RenderedGraphFormat::kDot));
  // The label must appear, but the constant root must not be drawn.
  EXPECT_THAT(graph, HasSubstr("an_empty_graph"));
  EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
}
127 
TEST_F(HloGraphDumperTest, TupleConstant) {
  // A tuple-shaped constant feeding a get-tuple-element should render with its
  // tuple shape spelled out.
  Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})});
  HloComputation::Builder builder("b");
  HloInstruction* tuple_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(tuple_shape)));
  HloInstruction* element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(ShapeUtil::MakeShape(F32, {3, 2}),
                                            tuple_constant, 0));

  HloModuleConfig config;
  HloModule module(TestName(), config);
  HloComputation* entry = module.AddEntryComputation(builder.Build(element));
  TF_ASSERT_OK_AND_ASSIGN(
      string graph, RenderGraph(*entry, /*label=*/"tuple_constant",
                                DebugOptions(), RenderedGraphFormat::kDot));
  EXPECT_THAT(graph, HasSubstr("tuple_constant"));
  EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])"));
}
146 
TEST_F(HloGraphDumperTest, Compare) {
  // A compare instruction should render its comparison direction in the graph.
  const char* hlo_string = R"(
    HloModule comp

    ENTRY comp {
      param.0 = f32[10] parameter(0)
      param.1 = f32[10] parameter(1)
      ROOT lt = pred[10] compare(param.0, param.1), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Label was "tuple_constant", copy-pasted from the TupleConstant test above;
  // use a label matching this test so rendered artifacts are identifiable.
  TF_ASSERT_OK_AND_ASSIGN(
      string graph,
      RenderGraph(*module->entry_computation(), /*label=*/"compare",
                  DebugOptions(), RenderedGraphFormat::kDot));
  EXPECT_THAT(graph, HasSubstr("direction=LT"));
}
164 
165 }  // anonymous namespace
166 }  // namespace xla
167