1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25
26 namespace xla {
27 namespace {
28
29 class HloPassPipelineTest : public HloTestBase {
30 protected:
ParseModuleGroup(absl::Span<const string> hlo_strings)31 StatusOr<HloModuleGroup> ParseModuleGroup(
32 absl::Span<const string> hlo_strings) {
33 HloModuleGroup group(TestName());
34 for (const string& hlo_string : hlo_strings) {
35 TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
36 ParseAndReturnVerifiedModule(hlo_string));
37 group.push_back(std::move(module));
38 }
39 return std::move(group);
40 }
41 };
42
43 // A module pass which renames instructions named 'foo' to 'bar'.
44 class FooToBarModulePass : public HloModulePass {
name() const45 absl::string_view name() const override { return "foo2bar"; }
46
Run(HloModule * module)47 StatusOr<bool> Run(HloModule* module) override {
48 bool changed = false;
49 for (HloComputation* computation : module->computations()) {
50 for (HloInstruction* instruction : computation->instructions()) {
51 if (instruction->name() == "foo") {
52 instruction->SetAndSanitizeName("bar");
53 changed = true;
54 }
55 }
56 }
57 return changed;
58 }
59 };
60
61 // A module group pass which renames instructions named 'baz' to 'qux'.
62 class BazToQuxModuleGroupPass : public HloModuleGroupPass {
name() const63 absl::string_view name() const override { return "baz2qux"; }
64
RunOnModuleGroup(HloModuleGroup * module_group)65 StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
66 bool changed = false;
67 for (HloModule* module : module_group->modules()) {
68 for (HloComputation* computation : module->computations()) {
69 for (HloInstruction* instruction : computation->instructions()) {
70 if (instruction->name() == "baz") {
71 instruction->SetAndSanitizeName("qux");
72 changed = true;
73 }
74 }
75 }
76 }
77 return changed;
78 }
79 };
80
81 // An invariant checker pass which returns an error if there exists an
82 // instruction named 'bar'.
83 class BarBlowerUpper : public HloModulePass {
name() const84 absl::string_view name() const override { return "bar-blower-upper"; }
85
Run(HloModule * module)86 StatusOr<bool> Run(HloModule* module) override {
87 for (HloComputation* computation : module->computations()) {
88 for (HloInstruction* instruction : computation->instructions()) {
89 if (instruction->name() == "bar") {
90 return InternalError("Module has instruction named bar");
91 }
92 }
93 }
94 return false;
95 }
96 };
97
TEST_F(HloPassPipelineTest, ModulePassChanged) {
  // A module pass that renames an instruction must make the pipeline report a
  // change, and the rename must be visible on the module afterwards.
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<FooToBarModulePass>();

  const string hlo_text = R"(
HloModule ModulePassChanged

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloInstruction* entry_root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_EQ(entry_root->name(), "foo");

  TF_ASSERT_OK_AND_ASSIGN(bool pipeline_changed,
                          pipeline.Run(hlo_module.get()));
  EXPECT_TRUE(pipeline_changed);
  EXPECT_EQ(entry_root->name(), "bar");
}
120
TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
  // Test an HLO module pass which does not change a module: the pipeline must
  // report no change, and the module's instructions must keep their names.
  const string module_str = R"(
HloModule ModulePassUnchanged

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT blahblah = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<FooToBarModulePass>();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->name(), "blahblah");
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
  EXPECT_FALSE(changed);
  // The `changed` flag alone does not prove the module was left alone, so
  // check the root's name directly, mirroring ModulePassChanged.
  EXPECT_EQ(root->name(), "blahblah");
}
140
TEST_F(HloPassPipelineTest, MixedPipeline) {
  // Test a pipeline with both a module pass and a module group pass.
  // Module 0 has a 'baz' root, which BazToQuxModuleGroupPass renames.
  // Module 1 has a 'foo' root, which FooToBarModulePass renames.
  // The module names match their index in the group so failures are easy to
  // attribute. Previously the names were swapped (.1 for index 0, .0 for
  // index 1).
  const string module_0_str = R"(
HloModule MixedPipeline.0

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT baz = f32[] multiply(a, b)
}
)";
  const string module_1_str = R"(
HloModule MixedPipeline.1

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT foo = f32[] multiply(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
                          ParseModuleGroup({module_0_str, module_1_str}));

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<BazToQuxModuleGroupPass>();
  pipeline.AddPass<FooToBarModulePass>();

  HloInstruction* root0 =
      module_group.module(0).entry_computation()->root_instruction();
  HloInstruction* root1 =
      module_group.module(1).entry_computation()->root_instruction();
  EXPECT_EQ(root0->name(), "baz");
  EXPECT_EQ(root1->name(), "foo");

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          pipeline.RunOnModuleGroup(&module_group));
  EXPECT_TRUE(changed);

  EXPECT_EQ(root0->name(), "qux");
  EXPECT_EQ(root1->name(), "bar");
}
183
TEST_F(HloPassPipelineTest, InvariantChecker) {
  const string hlo_text = R"(
HloModule InvariantChecker

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Builds a fresh pipeline holding only the BarBlowerUpper invariant checker
  // (no transformation passes) and runs it on `hlo_module`.
  auto run_checker_only_pipeline = [&]() {
    HloPassPipeline checker_pipeline(TestName());
    checker_pipeline.AddInvariantChecker<BarBlowerUpper>();
    return checker_pipeline.Run(hlo_module.get());
  };

  {
    // Checker only: must succeed, since no instruction is named 'bar' yet.
    TF_ASSERT_OK_AND_ASSIGN(bool changed, run_checker_only_pipeline());
    EXPECT_FALSE(changed);
  }

  {
    // foo2bar plus the checker: the rename introduces a 'bar' instruction, so
    // the checker must fail and the error must name foo2bar as the culprit.
    HloPassPipeline rename_pipeline(TestName());
    rename_pipeline.AddInvariantChecker<BarBlowerUpper>();
    rename_pipeline.AddPass<FooToBarModulePass>();

    const Status rename_status = rename_pipeline.Run(hlo_module.get()).status();
    ASSERT_IS_NOT_OK(rename_status);
    EXPECT_THAT(rename_status.error_message(),
                ::testing::HasSubstr("Module has instruction named bar"));
    EXPECT_THAT(rename_status.error_message(),
                ::testing::HasSubstr("Failed after foo2bar"));
  }

  // Checker only, once more: the previous pipeline left a 'bar' instruction
  // in the module, so this run must now fail at pipeline start.
  const Status recheck_status = run_checker_only_pipeline().status();
  ASSERT_IS_NOT_OK(recheck_status);
  EXPECT_THAT(recheck_status.error_message(),
              ::testing::HasSubstr("Module has instruction named bar"));
  EXPECT_THAT(recheck_status.error_message(),
              ::testing::HasSubstr("Failed after pipeline-start"));
}
234
TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
  // Adding a module group pass and then invoking the single-module Run entry
  // point is expected to produce an error rather than run the pass.
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<BazToQuxModuleGroupPass>();

  const string hlo_text = R"(
HloModule ModuleGroupPassOnModule

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  const Status run_status = pipeline.Run(hlo_module.get()).status();
  ASSERT_IS_NOT_OK(run_status);
  EXPECT_THAT(
      run_status.error_message(),
      ::testing::HasSubstr("Module group pass cannot be run on a module"));
}
257
258 } // namespace
259 } // namespace xla
260