1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Tests for tpu_function helpers."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.platform import test
25from tensorflow.python.tpu import tpu_sharding
26
27
class ShardingTest(test.TestCase):
  """Unit tests for `tpu_sharding.ShardingPolicy`."""

  def testFreeze(self):
    """Tests that freezing a policy applies default values."""
    # An unconfigured policy picks up the module-level defaults on freeze.
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    self.assertEqual(p1.number_of_shards,
                     tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
    self.assertEqual(p1.shard_dimension, tpu_sharding._DEFAULT_SHARD_DIMENSION)
    # Explicitly configured values are kept as-is when the policy is frozen.
    p2 = tpu_sharding.ShardingPolicy()
    p2.set_number_of_shards(17)
    p2.set_shard_dimension(23)
    p2.freeze()
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)

  def testFrozen(self):
    """Tests that frozen policies can't be changed."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    with self.assertRaises(ValueError):
      p1.set_number_of_shards(17)
    with self.assertRaises(ValueError):
      p1.set_shard_dimension(22)

  def testStr(self):
    """Tests the string representation."""
    # The policy reports itself as unset until BOTH the shard count and the
    # shard dimension have been configured.
    p1 = tpu_sharding.ShardingPolicy()
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_number_of_shards(17)
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_shard_dimension(8)
    self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)")

  def testMerge(self):
    """Tests that merging works."""
    # Merging into an empty policy copies every field set on the source.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p1.set_shard_dimension(23)
    p2 = tpu_sharding.ShardingPolicy()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)
    # A source with only shard_dimension set overrides just that field.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_shard_dimension(12)
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # Merging values that agree with a frozen policy leaves it unchanged...
    p2.freeze()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # ...but a conflicting shard count cannot be merged into it.
    p1.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      p2.merge(p1)
    # Likewise a conflicting shard dimension; the matching shard count alone
    # merges cleanly first.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p2.merge(p1)
    p1.set_shard_dimension(2)
    with self.assertRaises(ValueError):
      p2.merge(p1)

  def testGetShardedShape(self):
    """Tests getting a sharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    # Dimension 1 (size 9) is split evenly across the 3 shards.
    self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3])
    p.freeze()
    # A frozen policy rejects further configuration.
    with self.assertRaises(ValueError):
      p.set_shard_dimension(0)
    # An in-range shard_index is accepted and yields the per-shard shape.
    self.assertEqual(p.get_sharded_shape([4, 9], shard_index=2), [4, 3])
    # shard_index outside [0, number_of_shards) is rejected.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=4)
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=-1)
    with self.assertRaises(TypeError):
      _ = p.get_sharded_shape("not_a_shape")
    # A shape of unknown rank cannot be sharded.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape(tensor_shape.TensorShape(None))
    # A shard dimension that is not divisible by the number of shards is
    # rejected. No shard_index is passed, so the error can only come from
    # the divisibility check rather than from shard_index validation.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 10])

  def testGetUnpartitionedShape(self):
    """Tests getting an unpartitioned shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    p.set_number_of_partitions(4)
    # The shard dimension (size 5) is scaled by the 4 partitions.
    self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
    p.freeze()
    # An unknown size along the shard dimension cannot be unpartitioned.
    with self.assertRaises(ValueError):
      _ = p.get_unpartitioned_shape([3, None])

  def testGetUnshardedShape(self):
    """Tests getting an unsharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(2)
    p.set_shard_dimension(1)
    # Two [4, 3] shards combine along dimension 1 into [4, 6].
    self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
    # The number of shard shapes must equal number_of_shards.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
    # All shard shapes must agree with each other.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 2]])
    with self.assertRaises(TypeError):
      _ = p.get_unsharded_shape([[4, 3], "not_a_shape"])
    # Missing shapes and shapes of mismatched rank are rejected.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([None, [4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[2], [4, 3]])

  def testScalar(self):
    """Tests sharding and unsharding scalars."""
    # With a default (frozen) policy, scalar shapes round-trip unchanged.
    p = tpu_sharding.ShardingPolicy()
    p.freeze()
    self.assertEqual(p.get_sharded_shape([]), [])
    self.assertEqual(p.get_unsharded_shape([[]]), [])
146
147
if __name__ == "__main__":
  # Entry point when this file is executed directly: run the test cases
  # defined above through TensorFlow's test runner.
  test.main()
150