1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_module.h"
23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/tests/test_utils.h"
27 #include "tensorflow/compiler/xla/xla.pb.h"
28
29 namespace xla {
30 namespace {
31
32 using absl::StrCat;
33 using ::testing::HasSubstr;
34
35 using HloGraphDumperTest = HloTestBase;
36
TestName()37 string TestName() {
38 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
39 }
40
// Builds a chain of adds, fuses it into a loop fusion, then fuses part of the
// fused computation into a second, nested fusion. Checks that:
//  * every instruction of the entry, outer-fused and inner-fused computations
//    appears in the full DOT rendering, and
//  * a neighborhood rendering around an add inside the inner fusion contains
//    that add.
TEST_F(HloGraphDumperTest, NestedFusion) {
  HloComputation::Builder b("b");

  // Build param0 + param1 + param2 + param3 + param4.
  auto shape = ShapeUtil::MakeShape(F32, {10, 100});
  std::vector<HloInstruction*> params;
  for (int i = 0; i <= 4; ++i) {
    params.push_back(b.AddInstruction(
        HloInstruction::CreateParameter(i, shape, StrCat("param", i))));
  }
  // sums[0] = param0 + param1, and sums[i + 1] = sums[i] + param(i + 2).
  std::vector<HloInstruction*> sums;
  sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, params[0], params[1])));
  for (int i = 0; i <= 2; ++i) {
    sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, sums[i], params[i + 2])));
  }
  HloModuleConfig config;
  HloModule m(TestName(), config);
  m.AddEntryComputation(b.Build());
  HloComputation* root_computation = m.entry_computation();

  // Fuse into fusion(param0 + param1 + param2 + param3 + param4).
  auto* outer_fusion = root_computation->CreateFusionInstruction(
      {sums[3], sums[2], sums[1], sums[0]}, HloInstruction::FusionKind::kLoop);

  // Fusing invalidates the pointers in sums -- the instructions are cloned when
  // they're moved to the new computation. Get the updated pointers to sums.
  // Post order visits operands before users, and the adds form a chain, so
  // fused_sums[k] is the clone of sums[k].
  std::vector<HloInstruction*> fused_sums;
  for (auto* instr : outer_fusion->fused_instructions_computation()
                         ->MakeInstructionPostOrder()) {
    if (instr->opcode() == HloOpcode::kAdd) {
      fused_sums.push_back(instr);
    }
  }
  // All four adds were fused above. Check that before indexing so an
  // unexpected fusion result fails the test instead of reading out of bounds.
  ASSERT_EQ(fused_sums.size(), 4u);

  // Fuse into fusion(fusion(param0 + param1 + param2) + param3 + param4).
  auto* inner_fusion =
      outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
          {fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop);

  // Generate the graph; all nodes should be present.
  TF_ASSERT_OK_AND_ASSIGN(
      string graph, RenderGraph(*root_computation, /*label=*/"", DebugOptions(),
                                RenderedGraphFormat::kDot));
  for (const HloComputation* computation :
       {root_computation,  //
        inner_fusion->fused_instructions_computation(),
        outer_fusion->fused_instructions_computation()}) {
    for (const HloInstruction* instruction : computation->instructions()) {
      EXPECT_THAT(graph, HasSubstr(instruction->name()));
    }
  }

  // Dump a neighborhood around one of the inner sum nodes. We don't really
  // care that the outer nodes are omitted -- whether they are or not is based
  // on fiddly heuristics -- but we do care that the node we asked for is
  // printed.
  const HloInstruction* inner_sum = nullptr;
  for (const HloInstruction* instruction :
       inner_fusion->fused_instructions_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kAdd) {
      inner_sum = instruction;
      break;
    }
  }
  ASSERT_NE(inner_sum, nullptr);
  TF_ASSERT_OK_AND_ASSIGN(string neighborhood_graph,
                          RenderNeighborhoodAround(*inner_sum, /*radius=*/1,
                                                   RenderedGraphFormat::kDot));
  EXPECT_THAT(neighborhood_graph, HasSubstr(inner_sum->name()));
}
112
// Renders a computation whose only instruction is a scalar f32 constant root.
// The rendered graph must contain the requested label but must not contain
// the constant's name.
TEST_F(HloGraphDumperTest, Constant) {
  HloComputation::Builder builder("b");
  HloInstruction* constant_root = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-42)));
  constant_root->SetAndSanitizeName("i_am_a_constant_root_instruction");

  HloModuleConfig module_config;
  HloModule hlo_module(TestName(), module_config);
  HloComputation* entry = hlo_module.AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(
      string rendered,
      RenderGraph(*entry, /*label=*/"an_empty_graph", DebugOptions(),
                  RenderedGraphFormat::kDot));
  EXPECT_THAT(rendered, HasSubstr("an_empty_graph"));
  EXPECT_THAT(rendered, Not(HasSubstr("i_am_a_constant_root_instruction")));
}
127
// Renders a computation that extracts element 0 of a tuple-shaped constant.
// The rendered graph must contain the label and a constant node annotated
// with both tuple element shapes.
TEST_F(HloGraphDumperTest, TupleConstant) {
  const Shape f32_element = ShapeUtil::MakeShape(F32, {3, 2});
  const Shape s32_element = ShapeUtil::MakeShape(S32, {4, 5});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_element, s32_element});

  HloComputation::Builder builder("b");
  HloInstruction* tuple_literal = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(tuple_shape)));
  HloInstruction* first_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_element, tuple_literal, 0));

  HloModuleConfig module_config;
  HloModule hlo_module(TestName(), module_config);
  HloComputation* entry =
      hlo_module.AddEntryComputation(builder.Build(first_element));

  TF_ASSERT_OK_AND_ASSIGN(
      string rendered,
      RenderGraph(*entry, /*label=*/"tuple_constant", DebugOptions(),
                  RenderedGraphFormat::kDot));
  EXPECT_THAT(rendered, HasSubstr("tuple_constant"));
  EXPECT_THAT(rendered, HasSubstr("constant (f32[3,2], s32[4,5])"));
}
146
// Parses a module whose root is a compare instruction and checks that the
// comparison direction attribute is shown in the rendered DOT graph.
TEST_F(HloGraphDumperTest, Compare) {
  const char* hlo_string = R"(
HloModule comp

ENTRY comp {
param.0 = f32[10] parameter(0)
param.1 = f32[10] parameter(1)
ROOT lt = pred[10] compare(param.0, param.1), direction=LT
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // The label previously read "tuple_constant", copy-pasted from the
  // TupleConstant test; use a label that describes this graph.
  TF_ASSERT_OK_AND_ASSIGN(
      string graph,
      RenderGraph(*module->entry_computation(), /*label=*/"compare",
                  DebugOptions(), RenderedGraphFormat::kDot));
  EXPECT_THAT(graph, HasSubstr("direction=LT"));
}
164
165 } // anonymous namespace
166 } // namespace xla
167