1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo.pb.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 
26 namespace xla {
27 
28 namespace {
29 
30 namespace op = ::xla::testing::opcode_matchers;
31 
32 class HloModuleGroupTest : public HloTestBase {
33  protected:
34   HloModuleGroupTest() = default;
35 };
36 
// Tests building a group from a single module, round-tripping the group
// through its proto form, and releasing the module with ConsumeModules().
TEST_F(HloModuleGroupTest, SingleModule) {
  const string text = R"(
HloModule simple_module

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(text));
  HloModuleGroup group(std::move(module));

  // ASSERT rather than EXPECT: module(0) is accessed right after, so a size
  // mismatch must stop the test instead of indexing a too-small group.
  ASSERT_EQ(group.modules().size(), 1);
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));

  // Serialize to a proto and rebuild; the copy must have the same contents.
  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
                          HloModuleGroup::CreateFromProto(
                              group.ToProto(), {group.module(0).config()}));
  ASSERT_EQ(group_copy.modules().size(), 1);
  EXPECT_THAT(
      group_copy.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));

  // ConsumeModules() hands the modules to the caller and leaves the group
  // empty.
  std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();
  EXPECT_EQ(modules.size(), 1);
  EXPECT_EQ(group.modules().size(), 0);
}
68 
// Tests building a group from a span of modules and round-tripping the group
// through its proto form.
TEST_F(HloModuleGroupTest, MultipleModules) {
  const string text_0 = R"(
HloModule module0

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  const string text_1 = R"(
HloModule module1

ENTRY %entry (a: f32[]) -> f32[] {
  ROOT %a = f32[] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
                          ParseHloString(text_0));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
                          ParseHloString(text_1));
  std::vector<std::unique_ptr<HloModule>> modules;
  modules.push_back(std::move(module_0));
  modules.push_back(std::move(module_1));
  HloModuleGroup group(TestName(), absl::MakeSpan(modules));

  // ASSERT rather than EXPECT: module(0) and module(1) are accessed below.
  ASSERT_EQ(group.modules().size(), 2);
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));

  // The configs are passed in the same order as the modules in the group.
  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
                          HloModuleGroup::CreateFromProto(
                              group.ToProto(), {group.module(0).config(),
                                                group.module(1).config()}));
  ASSERT_EQ(group_copy.modules().size(), 2);
  // Check that the round trip preserves each module's entry computation, not
  // only the number of modules.
  EXPECT_THAT(
      group_copy.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group_copy.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));
}
107 
// Tests building a group incrementally with push_back(); the modules end up in
// the order they were added.
TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
  const string text_0 = R"(
HloModule module0

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  const string text_1 = R"(
HloModule module1

ENTRY %entry (a: f32[]) -> f32[] {
  ROOT %a = f32[] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
                          ParseHloString(text_0));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
                          ParseHloString(text_1));
  HloModuleGroup group(TestName());
  group.push_back(std::move(module_0));
  group.push_back(std::move(module_1));

  // ASSERT rather than EXPECT: module(0) and module(1) are accessed below.
  ASSERT_EQ(group.modules().size(), 2);
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));
}
140 
141 // Tests that the order of companion instructions in the companion set doesn't
142 // change across runs.
TEST_F(HloModuleGroupTest,ModuleGroupCompanionOrder)143 TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
144   // A simple while loop template for core i sending to core i+1.
145   constexpr char text[] = R"(
146 HloModule module_%d
147 
148 while_cond {
149   ROOT p = pred[] constant(true)
150 }
151 
152 while_body {
153   param = s32[] parameter(0)
154   token.s = token[] after-all()
155   token.r = token[] after-all()
156   send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
157   send-done = token[] send-done(send), channel_id=%d
158   recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
159   ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
160 }
161 
162 ENTRY entry {
163   while_init = s32[] constant(1)
164   ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
165 }
166 )";
167 
168   // Try creating the module and the metadata kTrialCount times and check the
169   // companion instructions remain in the same order.
170   const int64 kTrialCount = 5;
171   const int64 kDeviceCount = 10;
172   std::vector<int64> companion_order;
173 
174   for (int64 t = 0; t < kTrialCount; ++t) {
175     HloModuleGroup group(TestName());
176     for (int64 i = 0; i < kDeviceCount; ++i) {
177       const int64 send_channel = i;
178       const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
179       TF_ASSERT_OK_AND_ASSIGN(
180           std::unique_ptr<HloModule> module,
181           ParseHloString(absl::StrFormat(text, i, send_channel, send_channel,
182                                          recv_channel, recv_channel)));
183       group.push_back(std::move(module));
184     }
185     ASSERT_EQ(group.modules().size(), kDeviceCount);
186 
187     TF_ASSERT_OK_AND_ASSIGN(auto metadata,
188                             HloModuleGroupMetadata::Build(group.modules()));
189     ASSERT_EQ(metadata->companion_sets().size(), 1);
190 
191     std::vector<int64> module_ids;
192     for (HloInstruction* companion : *metadata->companion_sets()[0]) {
193       module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
194     }
195 
196     if (t == 0) {
197       companion_order = module_ids;
198     } else {
199       EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
200     }
201   }
202 }
203 
204 }  // namespace
205 
206 }  // namespace xla
207