1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <set>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
20 
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/util.h"
28 
29 namespace xla {
30 namespace {
31 
MakeArray(absl::Span<const int64> dimensions,absl::Span<const int64> contents)32 Array<int64> MakeArray(absl::Span<const int64> dimensions,
33                        absl::Span<const int64> contents) {
34   Array<int64> a(dimensions);
35   std::copy(contents.begin(), contents.end(), a.begin());
36   return a;
37 }
38 
39 class HloShardingTest : public HloTestBase {};
40 
TEST_F(HloShardingTest, Replicate) {
  // A replicated sharding is tile-maximal, reports each queried device as
  // used, equals another replicated sharding, validates against a small
  // array shape, and has no unique device.
  const HloSharding replicated = HloSharding::Replicate();
  EXPECT_TRUE(replicated.IsReplicated());
  EXPECT_TRUE(replicated.IsTileMaximal());
  for (const int64 device : {0, 65535}) {
    EXPECT_TRUE(replicated.UsesDevice(device));
  }

  EXPECT_EQ(HloSharding::Replicate(), replicated);

  EXPECT_IS_OK(replicated.Validate(ShapeUtil::MakeShape(U32, {4}),
                                   /*num_devices=*/2));
  EXPECT_FALSE(replicated.HasUniqueDevice());
}
55 
TEST_F(HloShardingTest, DevicePlacement) {
  // A sharding pinned to one device is tile-maximal but not replicated, uses
  // only that device, and reports it as its unique device.
  constexpr int64 kDevice = 5;
  const HloSharding on_device = HloSharding::AssignDevice(kDevice);
  EXPECT_FALSE(on_device.IsReplicated());
  EXPECT_TRUE(on_device.IsTileMaximal());
  EXPECT_FALSE(on_device.UsesDevice(0));
  EXPECT_TRUE(on_device.UsesDevice(kDevice));
  EXPECT_EQ(kDevice, on_device.GetUniqueDevice());

  EXPECT_NE(HloSharding::Replicate(), on_device);

  const Shape shape = ShapeUtil::MakeShape(U32, {4});
  // Device 5 validates with 6 available devices but not with 5.
  EXPECT_IS_OK(on_device.Validate(shape, /*num_devices=*/kDevice + 1));
  EXPECT_IS_NOT_OK(on_device.Validate(shape, /*num_devices=*/kDevice));

  // For a non-tuple shape the tree is a single leaf holding the sharding.
  const ShapeTree<HloSharding> tree = on_device.GetAsShapeTree(shape);
  EXPECT_EQ(tree.element({}), on_device);
  EXPECT_TRUE(tree.IsLeaf({}));
}
77 
TEST_F(HloShardingTest, Tile) {
  {
    // Test should fail because of a duplicate tile assignment.
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}),
                                       /*num_devices=*/4));
  }

  {
    // Test should fail because more devices are used than `num_devices`.
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
                                       /*num_devices=*/2));
  }

  {
    // Test should pass: a 2x2 tile assignment over rank-2 shapes.
    Shape shape = ShapeUtil::MakeShape(U32, {4, 5});
    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
    EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}),
                                   /*num_devices=*/5));
    // Also validate `shape` itself, since the tile offsets below are computed
    // for it.
    EXPECT_IS_OK(sharding.Validate(shape, /*num_devices=*/5));

    EXPECT_EQ(0, sharding.DeviceForTileIndex({0, 0}));
    EXPECT_EQ(3, sharding.DeviceForTileIndex({0, 1}));
    EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0}));
    EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1}));

    // Over {4, 5}, the expected tile offsets step by 2 rows and 3 columns.
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0),
              (std::vector<int64>{0, 0}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3),
              (std::vector<int64>{0, 3}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2),
              (std::vector<int64>{2, 0}));
    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1),
              (std::vector<int64>{2, 3}));

    EXPECT_FALSE(sharding.HasUniqueDevice());
  }
}
117 
// Tests that SingleTuple supports an empty tuple shape.
TEST_F(HloShardingTest, EmptySingleTuple) {
  // Even with no tuple elements, a single sharding must be extractable.
  const Shape empty_tuple = ShapeUtil::MakeTupleShape({});
  const HloSharding sharding =
      HloSharding::SingleTuple(empty_tuple, HloSharding::AssignDevice(0));
  EXPECT_TRUE(sharding.ExtractSingleSharding().has_value());
}
124 
TEST_F(HloShardingTest, NestedTuple) {
  // Shape under test: (f32[], (f32[3]), f32[4, 6]), which has three leaves.
  const Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {}),
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
      ShapeUtil::MakeShape(F32, {4, 6}),
  });

  const HloSharding tiled = HloSharding::Tile(Array<int64>({{0, 1}}));

  // Build a TUPLE OpSharding proto with one entry per leaf, in leaf order.
  OpSharding proto;
  proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
  for (const HloSharding& leaf :
       {HloSharding::Replicate(), HloSharding::AssignDevice(0), tiled}) {
    *proto.add_tuple_shardings() = leaf.ToProto();
  }
  const HloSharding tuple_sharding =
      HloSharding::FromProto(proto).ConsumeValueOrDie();

  // Each leaf of the shape tree receives the matching proto entry.
  const ShapeTree<HloSharding> per_leaf =
      tuple_sharding.GetAsShapeTree(nested_tuple_shape);
  EXPECT_EQ(per_leaf.element({0}), HloSharding::Replicate());
  EXPECT_EQ(per_leaf.element({1, 0}), HloSharding::AssignDevice(0));
  EXPECT_EQ(per_leaf.element({2}), tiled);

  EXPECT_IS_OK(tuple_sharding.Validate(nested_tuple_shape, /*num_devices=*/5));
  // Validation must fail when the tuple element count does not match...
  EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeTupleShape({}),
                                           /*num_devices=*/5));
  // ...and when the shape is not a tuple at all.
  EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeShape(F32, {}),
                                           /*num_devices=*/5));
}
156 
TEST_F(HloShardingTest,Hash)157 TEST_F(HloShardingTest, Hash) {
158   auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
159     if (a.Hash() != b.Hash()) {
160       return false;
161     }
162     return a == b;
163   };
164 
165   {
166     HloSharding sharding1 = HloSharding::Replicate();
167     HloSharding sharding2 = HloSharding::Replicate();
168     EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
169   }
170 
171   {
172     HloSharding sharding1 = HloSharding::AssignDevice(1);
173     HloSharding sharding2 = HloSharding::AssignDevice(1);
174     EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
175   }
176 
177   {
178     HloSharding sharding1 = HloSharding::AssignDevice(1);
179     HloSharding sharding2 = HloSharding::AssignDevice(2);
180     EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
181   }
182 
183   {
184     HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
185     HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
186     EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
187   }
188 
189   HloSharding default_sharding = HloSharding::Replicate();
190   {
191     ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
192                                       default_sharding);
193     HloSharding sharding1 = HloSharding::Replicate();
194     HloSharding sharding2 = HloSharding::Tuple(shape_tree);
195     EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
196   }
197 
198   {
199     ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
200                                       default_sharding);
201     HloSharding sharding1 = HloSharding::Tuple(shape_tree);
202     HloSharding sharding2 = HloSharding::Tuple(shape_tree);
203     EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
204   }
205 
206   {
207     ShapeTree<HloSharding> shape_tree1(
208         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
209         default_sharding);
210     *shape_tree1.mutable_element({0}) = HloSharding::Replicate();
211     ShapeTree<HloSharding> shape_tree2(
212         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
213         default_sharding);
214     *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
215     HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
216     HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
217     EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
218   }
219 
220   {
221     ShapeTree<HloSharding> shape_tree1(
222         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
223         default_sharding);
224     *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0);
225     ShapeTree<HloSharding> shape_tree2(
226         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
227         default_sharding);
228     *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
229     HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
230     HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
231     EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
232   }
233 }
234 
TEST_F(HloShardingTest, ToStringReplicatedTest) {
  // Replication prints as a single keyword inside braces.
  EXPECT_EQ(HloSharding::Replicate().ToString(), "{replicated}");
}
239 
TEST_F(HloShardingTest, ToStringAssignDeviceTest) {
  // A single-device sharding prints as "maximal" with the device id.
  const std::string expected = "{maximal device=7}";
  EXPECT_EQ(HloSharding::AssignDevice(7).ToString(), expected);
}
244 
TEST_F(HloShardingTest, ToStringTiledTest) {
  // A 2x1x2 device array prints its dimensions followed by the device ids.
  const Array3D<int64> devices({{{2, 3}}, {{5, 7}}});
  EXPECT_EQ(HloSharding::Tile(devices).ToString(),
            "{devices=[2,1,2]2,3,5,7}");
}
250 
TEST_F(HloShardingTest, ToStringTupleTest) {
  // A tuple sharding prints each element's sharding, comma-separated, inside
  // an outer pair of braces.
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
                                 ShapeUtil::MakeShape(U32, {7, 25}),
                                 ShapeUtil::MakeShape(S32, {9, 11})});
  const std::vector<HloSharding> element_shardings = {
      HloSharding::Replicate(), HloSharding::Tile(Array2D<int64>({{3, 5}})),
      HloSharding::AssignDevice(3)};
  const HloSharding sharding =
      HloSharding::Tuple(tuple_shape, element_shardings);
  EXPECT_EQ(sharding.ToString(),
            "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}");
}
261 
TEST_F(HloShardingTest, OstreamTest) {
  // Streaming a sharding with operator<< must produce the expected tiled text,
  // and must agree with ToString() so the two printing paths cannot diverge.
  HloSharding sharding =
      HloSharding::Tile(Array4D<int64>({{{{0, 1}, {2, 3}}}}));
  std::ostringstream oss;  // Requires <sstream>, included at the top of file.
  oss << sharding;
  EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}");
  EXPECT_EQ(oss.str(), sharding.ToString());
}
269 
TEST_F(HloShardingTest, ParseHloString) {
  // Round-trips `sharding` through its ToString() text and checks that
  // ParseSharding yields an equal sharding.
  auto check = [](const HloSharding& sharding) {
    TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding,
                            ParseSharding(sharding.ToString()));
    EXPECT_EQ(sharding, parsed_sharding);
  };
  check(HloSharding::Replicate());
  check(HloSharding::AssignDevice(2));
  check(HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})));
  // Empty tuple. One sharding is required for empty tuples, as we need to be
  // able to assign sharding to them, even though they have no leaves.
  check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}),
                           {HloSharding::Replicate()}));
  {
    // Non-nested tuple.
    auto tuple_shape =
        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
                                   ShapeUtil::MakeShape(F32, {3, 5, 7}),
                                   ShapeUtil::MakeShape(F32, {3, 7})});
    check(HloSharding::Tuple(
        tuple_shape, {HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
                      HloSharding::Replicate(), HloSharding::AssignDevice(1)}));
  }
  {
    // Nested tuple: (f32[3,1,5,7], (f32[3,5,7], f32[3,7])), three leaves.
    auto tuple_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
         ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}),
                                    ShapeUtil::MakeShape(F32, {3, 7})})});
    std::vector<HloSharding> leaf_shardings = {
        HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
        HloSharding::Replicate(), HloSharding::AssignDevice(1)};
    ShapeTree<HloSharding> sharding_tree(tuple_shape, HloSharding::Replicate());
    // Assign leaf_shardings to sharding_tree leaves in the order leaves()
    // visits them. The index is bounds-checked so a mismatch between the
    // number of leaves and leaf_shardings fails the test instead of reading
    // past the end of the vector.
    size_t next_leaf = 0;
    for (auto& index_to_sharding : sharding_tree.leaves()) {
      ASSERT_LT(next_leaf, leaf_shardings.size());
      index_to_sharding.second = leaf_shardings[next_leaf++];
    }
    ASSERT_EQ(next_leaf, leaf_shardings.size());
    check(HloSharding::Tuple(sharding_tree));
  }
}
311 
312 }  // namespace
313 }  // namespace xla
314