1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
17
18 #include "tensorflow/compiler/xla/service/hlo.pb.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25
26 namespace xla {
27
28 namespace {
29
30 namespace op = ::xla::testing::opcode_matchers;
31
32 class HloModuleGroupTest : public HloTestBase {
33 protected:
34 HloModuleGroupTest() = default;
35 };
36
// Exercises a group that wraps exactly one module: indexed access to the
// module, a round trip through the proto representation using the module's own
// config, and handing the modules back out via ConsumeModules().
TEST_F(HloModuleGroupTest, SingleModule) {
  const string hlo_text = R"(
HloModule simple_module

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseHloString(hlo_text));
  HloModuleGroup group(std::move(parsed));

  // The entry computation is expected to list its instructions as
  // parameter, parameter, add — both before and after the proto round trip.
  const auto expected_entry =
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add());

  EXPECT_EQ(group.modules().size(), 1);
  EXPECT_THAT(group.module(0).entry_computation()->instructions(),
              expected_entry);

  TF_ASSERT_OK_AND_ASSIGN(
      HloModuleGroup round_tripped,
      HloModuleGroup::CreateFromProto(group.ToProto(),
                                      {group.module(0).config()}));
  EXPECT_EQ(round_tripped.modules().size(), 1);
  EXPECT_THAT(round_tripped.module(0).entry_computation()->instructions(),
              expected_entry);

  // Consuming hands ownership of the module to the caller and leaves the
  // group empty.
  const std::vector<std::unique_ptr<HloModule>> consumed =
      group.ConsumeModules();
  EXPECT_EQ(consumed.size(), 1);
  EXPECT_EQ(group.modules().size(), 0);
}
68
// Exercises a group built from a span of two modules. Checks that both modules
// are reachable by index with the expected entry computations. Also checks that
// a round trip through the proto representation (passing each module's own
// config, in order) reproduces the same modules in the same slots.
TEST_F(HloModuleGroupTest, MultipleModules) {
  const string text_0 = R"(
HloModule module0

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  const string text_1 = R"(
HloModule module1

ENTRY %entry (a: f32[]) -> f32[] {
  ROOT %a = f32[] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
                          ParseHloString(text_0));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
                          ParseHloString(text_1));
  std::vector<std::unique_ptr<HloModule>> modules;
  modules.push_back(std::move(module_0));
  modules.push_back(std::move(module_1));
  HloModuleGroup group(TestName(), absl::MakeSpan(modules));

  // ASSERT (not EXPECT) so that a size mismatch stops the test before the
  // module(1) accesses below could index past the end.
  ASSERT_EQ(group.modules().size(), 2);
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));

  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
                          HloModuleGroup::CreateFromProto(
                              group.ToProto(), {group.module(0).config(),
                                                group.module(1).config()}));
  // Verify the deserialized modules' contents, not just their count, matching
  // what the SingleModule test checks for its round trip.
  ASSERT_EQ(group_copy.modules().size(), 2);
  EXPECT_THAT(
      group_copy.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group_copy.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));
}
107
// Builds a group incrementally: start from a named, empty group and append
// modules one at a time with push_back(). The modules must then be reachable by
// index in the order they were appended.
TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
  constexpr char kTwoParamAddText[] = R"(
HloModule module0

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  constexpr char kIdentityText[] = R"(
HloModule module1

ENTRY %entry (a: f32[]) -> f32[] {
  ROOT %a = f32[] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> add_module,
                          ParseHloString(kTwoParamAddText));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> identity_module,
                          ParseHloString(kIdentityText));

  HloModuleGroup group(TestName());
  group.push_back(std::move(add_module));
  group.push_back(std::move(identity_module));

  EXPECT_EQ(group.modules().size(), 2);
  // Slot 0 holds the first module appended, slot 1 the second.
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));
}
140
141 // Tests that the order of companion instructions in the companion set doesn't
142 // change across runs.
TEST_F(HloModuleGroupTest,ModuleGroupCompanionOrder)143 TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
144 // A simple while loop template for core i sending to core i+1.
145 constexpr char text[] = R"(
146 HloModule module_%d
147
148 while_cond {
149 ROOT p = pred[] constant(true)
150 }
151
152 while_body {
153 param = s32[] parameter(0)
154 token.s = token[] after-all()
155 token.r = token[] after-all()
156 send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
157 send-done = token[] send-done(send), channel_id=%d
158 recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
159 ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
160 }
161
162 ENTRY entry {
163 while_init = s32[] constant(1)
164 ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
165 }
166 )";
167
168 // Try creating the module and the metadata kTrialCount times and check the
169 // companion instructions remain in the same order.
170 const int64 kTrialCount = 5;
171 const int64 kDeviceCount = 10;
172 std::vector<int64> companion_order;
173
174 for (int64 t = 0; t < kTrialCount; ++t) {
175 HloModuleGroup group(TestName());
176 for (int64 i = 0; i < kDeviceCount; ++i) {
177 const int64 send_channel = i;
178 const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
179 TF_ASSERT_OK_AND_ASSIGN(
180 std::unique_ptr<HloModule> module,
181 ParseHloString(absl::StrFormat(text, i, send_channel, send_channel,
182 recv_channel, recv_channel)));
183 group.push_back(std::move(module));
184 }
185 ASSERT_EQ(group.modules().size(), kDeviceCount);
186
187 TF_ASSERT_OK_AND_ASSIGN(auto metadata,
188 HloModuleGroupMetadata::Build(group.modules()));
189 ASSERT_EQ(metadata->companion_sets().size(), 1);
190
191 std::vector<int64> module_ids;
192 for (HloInstruction* companion : *metadata->companion_sets()[0]) {
193 module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
194 }
195
196 if (t == 0) {
197 companion_order = module_ids;
198 } else {
199 EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
200 }
201 }
202 }
203
204 } // namespace
205
206 } // namespace xla
207