1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <set>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
20
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/util.h"
28
29 namespace xla {
30 namespace {
31
MakeArray(absl::Span<const int64> dimensions,absl::Span<const int64> contents)32 Array<int64> MakeArray(absl::Span<const int64> dimensions,
33 absl::Span<const int64> contents) {
34 Array<int64> a(dimensions);
35 std::copy(contents.begin(), contents.end(), a.begin());
36 return a;
37 }
38
39 class HloShardingTest : public HloTestBase {};
40
TEST_F(HloShardingTest, Replicate) {
  const HloSharding replicated = HloSharding::Replicate();

  // A replicated sharding is tile-maximal and reports every probed device id
  // (0 and 65535 here) as used.
  EXPECT_TRUE(replicated.IsReplicated());
  EXPECT_TRUE(replicated.IsTileMaximal());
  EXPECT_TRUE(replicated.UsesDevice(0));
  EXPECT_TRUE(replicated.UsesDevice(65535));

  // Independently constructed replicated shardings compare equal.
  EXPECT_EQ(HloSharding::Replicate(), replicated);

  EXPECT_IS_OK(replicated.Validate(ShapeUtil::MakeShape(U32, {4}),
                                   /*num_devices=*/2));
  // Replication spans all devices, so there is no single unique device.
  EXPECT_FALSE(replicated.HasUniqueDevice());
}
55
TEST_F(HloShardingTest, DevicePlacement) {
  const HloSharding on_device_5 = HloSharding::AssignDevice(5);
  const Shape u32_vector = ShapeUtil::MakeShape(U32, {4});

  // A maximal sharding pinned to device 5 only.
  EXPECT_FALSE(on_device_5.IsReplicated());
  EXPECT_TRUE(on_device_5.IsTileMaximal());
  EXPECT_FALSE(on_device_5.UsesDevice(0));
  EXPECT_TRUE(on_device_5.UsesDevice(5));
  EXPECT_EQ(5, on_device_5.GetUniqueDevice());

  EXPECT_NE(HloSharding::Replicate(), on_device_5);

  // Device id 5 is valid with 6 devices but out of range with 5.
  EXPECT_IS_OK(on_device_5.Validate(u32_vector, /*num_devices=*/6));
  EXPECT_IS_NOT_OK(on_device_5.Validate(u32_vector, /*num_devices=*/5));

  // For a non-tuple shape the shape tree is a single leaf holding the sharding.
  ShapeTree<HloSharding> tree = on_device_5.GetAsShapeTree(u32_vector);
  EXPECT_EQ(tree.element({}), on_device_5);
  EXPECT_TRUE(tree.IsLeaf({}));
}
77
TEST_F(HloShardingTest, Tile) {
  {
    // Test should fail because of a duplicate tile assignment (device 0
    // appears twice).
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}),
                                       /*num_devices=*/4));
  }

  {
    // Test should fail because the tile assignment references devices 2 and
    // 3, which are out of range when `num_devices` is 2.
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
                                       /*num_devices=*/2));
  }

  {
    // Test should pass.
    Shape shape = ShapeUtil::MakeShape(U32, {4, 5});
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
    // Validation is expected to succeed both for an unrelated F32[3,5] shape
    // and for `shape`, the shape whose tile offsets are checked below.
    EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}),
                                   /*num_devices=*/5));
    EXPECT_IS_OK(sharding.Validate(shape, /*num_devices=*/5));

    // The 2x2 tile assignment maps tile indices to devices in the order the
    // contents were given to MakeArray.
    EXPECT_EQ(0, sharding.DeviceForTileIndex({0, 0}));
    EXPECT_EQ(3, sharding.DeviceForTileIndex({0, 1}));
    EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0}));
    EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1}));

    // For a 4x5 shape, tile offsets advance by 2 rows and 3 columns (each
    // dimension size divided by the tile count, rounded up).
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0),
              (std::vector<int64>{0, 0}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3),
              (std::vector<int64>{0, 3}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2),
              (std::vector<int64>{2, 0}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1),
              (std::vector<int64>{2, 3}));

    EXPECT_FALSE(sharding.HasUniqueDevice());
  }
}
117
118 // Tests that empty tuple is supported.
TEST_F(HloShardingTest,EmptySingleTuple)119 TEST_F(HloShardingTest, EmptySingleTuple) {
120 HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}),
121 HloSharding::AssignDevice(0));
122 EXPECT_TRUE(sharding.ExtractSingleSharding());
123 }
124
TEST_F(HloShardingTest, NestedTuple) {
  // Shape under test: (f32[], (f32[3]), f32[4, 6]).
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  const Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
      scalar_shape,
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
      ShapeUtil::MakeShape(F32, {4, 6}),
  });

  // Build the tuple sharding from its proto form: one entry per leaf, in the
  // same order the leaves are checked below.
  const HloSharding tiled_sharding = HloSharding::Tile(Array<int64>({{0, 1}}));
  OpSharding proto;
  proto.set_type(OpSharding::TUPLE);
  for (const HloSharding& leaf : {HloSharding::Replicate(),
                                  HloSharding::AssignDevice(0),
                                  tiled_sharding}) {
    *proto.add_tuple_shardings() = leaf.ToProto();
  }
  const HloSharding tuple_sharding =
      HloSharding::FromProto(proto).ConsumeValueOrDie();

  // Every leaf of the shape tree carries its corresponding element sharding.
  const ShapeTree<HloSharding> shape_tree =
      tuple_sharding.GetAsShapeTree(nested_tuple_shape);
  EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
  EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
  EXPECT_EQ(shape_tree.element({2}), tiled_sharding);

  EXPECT_IS_OK(tuple_sharding.Validate(nested_tuple_shape, /*num_devices=*/5));
  // Must fail: the tuple element count does not match.
  EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeTupleShape({}),
                                           /*num_devices=*/5));
  // Must fail: the shape is not a tuple.
  EXPECT_IS_NOT_OK(tuple_sharding.Validate(scalar_shape, /*num_devices=*/5));
}
156
TEST_F(HloShardingTest, Hash) {
  // True only if the two shardings hash identically AND compare equal. The
  // hash is checked first, so equal shardings with different hashes fail.
  auto hash_compare_equal = [](const HloSharding& lhs,
                               const HloSharding& rhs) {
    return lhs.Hash() == rhs.Hash() && lhs == rhs;
  };

  // Non-tuple shardings.
  EXPECT_TRUE(hash_compare_equal(HloSharding::Replicate(),
                                 HloSharding::Replicate()));
  EXPECT_TRUE(hash_compare_equal(HloSharding::AssignDevice(1),
                                 HloSharding::AssignDevice(1)));
  EXPECT_FALSE(hash_compare_equal(HloSharding::AssignDevice(1),
                                  HloSharding::AssignDevice(2)));
  {
    const HloSharding tiled_a =
        HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
    const HloSharding tiled_b =
        HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
    EXPECT_TRUE(hash_compare_equal(tiled_a, tiled_b));
  }

  const HloSharding replicated = HloSharding::Replicate();

  // Empty tuples differ from a plain replicated sharding but equal each other.
  {
    const ShapeTree<HloSharding> empty_tree(ShapeUtil::MakeTupleShape({}),
                                            replicated);
    EXPECT_FALSE(
        hash_compare_equal(replicated, HloSharding::Tuple(empty_tree)));
    EXPECT_TRUE(hash_compare_equal(HloSharding::Tuple(empty_tree),
                                   HloSharding::Tuple(empty_tree)));
  }

  // Single-element tuples over f32[4]: equality follows the element sharding.
  const Shape one_element_tuple =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})});
  auto wrap_in_tuple = [&](const HloSharding& element) {
    ShapeTree<HloSharding> tree(one_element_tuple, replicated);
    *tree.mutable_element({0}) = element;
    return HloSharding::Tuple(tree);
  };
  EXPECT_FALSE(hash_compare_equal(wrap_in_tuple(HloSharding::Replicate()),
                                  wrap_in_tuple(HloSharding::AssignDevice(0))));
  EXPECT_TRUE(hash_compare_equal(wrap_in_tuple(HloSharding::AssignDevice(0)),
                                 wrap_in_tuple(HloSharding::AssignDevice(0))));
}
234
TEST_F(HloShardingTest, ToStringReplicatedTest) {
  // A replicated sharding prints without any device list.
  const std::string text = HloSharding::Replicate().ToString();
  EXPECT_EQ(text, "{replicated}");
}
239
TEST_F(HloShardingTest, ToStringAssignDeviceTest) {
  // A single-device sharding prints as "maximal" along with its device id.
  const std::string text = HloSharding::AssignDevice(7).ToString();
  EXPECT_EQ(text, "{maximal device=7}");
}
244
TEST_F(HloShardingTest, ToStringTiledTest) {
  // A tiled sharding prints the tile-grid dimensions, then the device ids in
  // the order they appear in the nested initializer.
  const Array3D<int64> devices({{{2, 3}}, {{5, 7}}});
  const std::string text = HloSharding::Tile(devices).ToString();
  EXPECT_EQ(text, "{devices=[2,1,2]2,3,5,7}");
}
250
TEST_F(HloShardingTest, ToStringTupleTest) {
  // A tuple sharding prints each element's sharding, comma-separated, inside
  // an outer pair of braces.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {7, 25}),
       ShapeUtil::MakeShape(S32, {9, 11})});
  const std::vector<HloSharding> elements = {
      HloSharding::Replicate(), HloSharding::Tile(Array2D<int64>({{3, 5}})),
      HloSharding::AssignDevice(3)};
  const HloSharding sharding = HloSharding::Tuple(tuple_shape, elements);
  EXPECT_EQ(sharding.ToString(),
            "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}");
}
261
// Checks that streaming a sharding with operator<< yields the expected text,
// and that it matches ToString(). std::ostringstream comes from <sstream>,
// which this file now includes directly.
TEST_F(HloShardingTest, OstreamTest) {
  HloSharding sharding =
      HloSharding::Tile(Array4D<int64>({{{{0, 1}, {2, 3}}}}));
  std::ostringstream oss;
  oss << sharding;
  EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}");
  // The stream operator and ToString() should render identically.
  EXPECT_EQ(oss.str(), sharding.ToString());
}
269
TEST_F(HloShardingTest, ParseHloString) {
  // Round-trips `sharding` through ToString() and ParseSharding(), and checks
  // that the parsed sharding equals the original.
  auto check = [](const HloSharding& sharding) {
    TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding,
                            ParseSharding(sharding.ToString()));
    EXPECT_EQ(sharding, parsed_sharding);
  };
  check(HloSharding::Replicate());
  check(HloSharding::AssignDevice(2));
  check(HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})));
  // Empty tuple. One sharding is required for empty tuples, as we need to be
  // able to assign sharding to them, even though they have no leaves.
  check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}),
                           {HloSharding::Replicate()}));
  {
    // Non-nested tuple.
    auto tuple_shape =
        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
                                   ShapeUtil::MakeShape(F32, {3, 5, 7}),
                                   ShapeUtil::MakeShape(F32, {3, 7})});
    check(HloSharding::Tuple(
        tuple_shape, {HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
                      HloSharding::Replicate(), HloSharding::AssignDevice(1)}));
  }
  {
    // Nested tuple.
    auto tuple_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}),
                                    ShapeUtil::MakeShape(F32, {3, 7})})});
    std::vector<HloSharding> leaf_shardings = {
        HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
        HloSharding::Replicate(), HloSharding::AssignDevice(1)};
    ShapeTree<HloSharding> sharding_tree(tuple_shape, HloSharding::Replicate());
    // Assign leaf_shardings to sharding_tree leaves in leaf order. The
    // iterator is bounds-checked so that a mismatch between the tree's leaf
    // count and leaf_shardings.size() fails the test instead of reading past
    // the end of the vector.
    auto it = leaf_shardings.begin();
    for (auto& index_to_sharding : sharding_tree.leaves()) {
      ASSERT_TRUE(it != leaf_shardings.end());
      index_to_sharding.second = *it++;
    }
    // Every provided leaf sharding must have been consumed.
    EXPECT_TRUE(it == leaf_shardings.end());
    check(HloSharding::Tuple(sharding_tree));
  }
}
311
312 } // namespace
313 } // namespace xla
314