1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 
26 namespace xla {
27 namespace {
28 
29 class HloPassPipelineTest : public HloTestBase {
30  protected:
ParseModuleGroup(absl::Span<const string> hlo_strings)31   StatusOr<HloModuleGroup> ParseModuleGroup(
32       absl::Span<const string> hlo_strings) {
33     HloModuleGroup group(TestName());
34     for (const string& hlo_string : hlo_strings) {
35       TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
36                           ParseAndReturnVerifiedModule(hlo_string));
37       group.push_back(std::move(module));
38     }
39     return std::move(group);
40   }
41 };
42 
43 // A module pass which renames instructions named 'foo' to 'bar'.
44 class FooToBarModulePass : public HloModulePass {
name() const45   absl::string_view name() const override { return "foo2bar"; }
46 
Run(HloModule * module)47   StatusOr<bool> Run(HloModule* module) override {
48     bool changed = false;
49     for (HloComputation* computation : module->computations()) {
50       for (HloInstruction* instruction : computation->instructions()) {
51         if (instruction->name() == "foo") {
52           instruction->SetAndSanitizeName("bar");
53           changed = true;
54         }
55       }
56     }
57     return changed;
58   }
59 };
60 
61 // A module group pass which renames instructions named 'baz' to 'qux'.
62 class BazToQuxModuleGroupPass : public HloModuleGroupPass {
name() const63   absl::string_view name() const override { return "baz2qux"; }
64 
RunOnModuleGroup(HloModuleGroup * module_group)65   StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
66     bool changed = false;
67     for (HloModule* module : module_group->modules()) {
68       for (HloComputation* computation : module->computations()) {
69         for (HloInstruction* instruction : computation->instructions()) {
70           if (instruction->name() == "baz") {
71             instruction->SetAndSanitizeName("qux");
72             changed = true;
73           }
74         }
75       }
76     }
77     return changed;
78   }
79 };
80 
81 // An invariant checker pass which returns an error if there exists an
82 // instruction named 'bar'.
83 class BarBlowerUpper : public HloModulePass {
name() const84   absl::string_view name() const override { return "bar-blower-upper"; }
85 
Run(HloModule * module)86   StatusOr<bool> Run(HloModule* module) override {
87     for (HloComputation* computation : module->computations()) {
88       for (HloInstruction* instruction : computation->instructions()) {
89         if (instruction->name() == "bar") {
90           return InternalError("Module has instruction named bar");
91         }
92       }
93     }
94     return false;
95   }
96 };
97 
TEST_F(HloPassPipelineTest, ModulePassChanged) {
  // A module pass that renames an instruction must make the pipeline report
  // that the module changed, and the rename must be visible afterwards.
  const string hlo_text = R"(
HloModule ModulePassChanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* entry_root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_EQ(entry_root->name(), "foo");

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<FooToBarModulePass>();
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(hlo_module.get()));
  EXPECT_TRUE(changed);
  EXPECT_EQ(entry_root->name(), "bar");
}
120 
TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
  // The module has no instruction named 'foo', so the renaming pass has
  // nothing to do and the pipeline must report no change.
  const string hlo_text = R"(
HloModule ModulePassUnchanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT blahblah = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<FooToBarModulePass>();
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(hlo_module.get()));
  EXPECT_FALSE(changed);
}
140 
TEST_F(HloPassPipelineTest, MixedPipeline) {
  // Test a pipeline with both a module pass and a module group pass, run on a
  // group of two modules: the group pass must rename 'baz' in module 0 and the
  // module pass must rename 'foo' in module 1.
  // Each module's HLO name matches its position in the group, so failure
  // output points at the right module.
  const string module_0_str = R"(
HloModule MixedPipeline.0

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT baz = f32[] multiply(a, b)
}
)";
  const string module_1_str = R"(
HloModule MixedPipeline.1

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
                          ParseModuleGroup({module_0_str, module_1_str}));

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<BazToQuxModuleGroupPass>();
  pipeline.AddPass<FooToBarModulePass>();

  HloInstruction* root0 =
      module_group.module(0).entry_computation()->root_instruction();
  HloInstruction* root1 =
      module_group.module(1).entry_computation()->root_instruction();
  EXPECT_EQ(root0->name(), "baz");
  EXPECT_EQ(root1->name(), "foo");

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          pipeline.RunOnModuleGroup(&module_group));
  EXPECT_TRUE(changed);

  EXPECT_EQ(root0->name(), "qux");
  EXPECT_EQ(root1->name(), "bar");
}
183 
TEST_F(HloPassPipelineTest, InvariantChecker) {
  // Exercises invariant checkers. The expected error messages show that a
  // checker failure is annotated with where it ran ("pipeline-start" or after
  // a named pass). The three phases below share one module and must run in
  // this order, because phase 2 renames 'foo' to 'bar'.
  const string hlo_text = R"(
HloModule InvariantChecker

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloModule* const target = hlo_module.get();

  // Phase 1: checker only. Nothing is named 'bar' yet, so the pipeline
  // succeeds and reports no change.
  {
    HloPassPipeline checker_only(TestName());
    checker_only.AddInvariantChecker<BarBlowerUpper>();

    TF_ASSERT_OK_AND_ASSIGN(bool changed, checker_only.Run(target));
    EXPECT_FALSE(changed);
  }

  // Phase 2: rename 'foo' to 'bar', then the checker must reject the result
  // and attribute the failure to the foo2bar pass.
  {
    HloPassPipeline renaming(TestName());
    renaming.AddInvariantChecker<BarBlowerUpper>();
    renaming.AddPass<FooToBarModulePass>();

    const Status run_status = renaming.Run(target).status();
    ASSERT_IS_NOT_OK(run_status);
    EXPECT_THAT(run_status.error_message(),
                ::testing::HasSubstr("Module has instruction named bar"));
    EXPECT_THAT(run_status.error_message(),
                ::testing::HasSubstr("Failed after foo2bar"));
  }

  // Phase 3: checker only again. The module now contains 'bar', so the
  // checker fails before any pass runs.
  {
    HloPassPipeline checker_only(TestName());
    checker_only.AddInvariantChecker<BarBlowerUpper>();

    const Status run_status = checker_only.Run(target).status();
    ASSERT_IS_NOT_OK(run_status);
    EXPECT_THAT(run_status.error_message(),
                ::testing::HasSubstr("Module has instruction named bar"));
    EXPECT_THAT(run_status.error_message(),
                ::testing::HasSubstr("Failed after pipeline-start"));
  }
}
234 
TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
  // A pipeline containing a module group pass must refuse to run on a single
  // module and say why.
  const string hlo_text = R"(
HloModule ModuleGroupPassOnModule

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<BazToQuxModuleGroupPass>();

  const Status run_status = pipeline.Run(hlo_module.get()).status();
  ASSERT_IS_NOT_OK(run_status);
  EXPECT_THAT(
      run_status.error_message(),
      ::testing::HasSubstr("Module group pass cannot be run on a module"));
}
257 
258 }  // namespace
259 }  // namespace xla
260