# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for tpu_sharding helpers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_sharding


class ShardingTest(test.TestCase):
  """Unit tests for `tpu_sharding.ShardingPolicy`."""

  def testFreeze(self):
    """Tests that freezing a policy applies default values."""
    # An unconfigured policy picks up the module defaults when frozen.
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    self.assertEqual(p1.number_of_shards,
                     tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
    self.assertEqual(p1.shard_dimension, tpu_sharding._DEFAULT_SHARD_DIMENSION)
    # Explicitly set values survive freezing unchanged.
    p2 = tpu_sharding.ShardingPolicy()
    p2.set_number_of_shards(17)
    p2.set_shard_dimension(23)
    p2.freeze()
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)

  def testFrozen(self):
    """Tests that frozen policies can't be changed."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    with self.assertRaises(ValueError):
      p1.set_number_of_shards(17)
    with self.assertRaises(ValueError):
      p1.set_shard_dimension(22)

  def testStr(self):
    """Tests the string representation."""
    p1 = tpu_sharding.ShardingPolicy()
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    # Setting only the shard count is not enough; the policy still prints as
    # unset until the shard dimension is also known.
    p1.set_number_of_shards(17)
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_shard_dimension(8)
    self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)")

  def testMerge(self):
    """Tests that merging works."""
    # Merging a fully specified policy into an empty one copies both fields.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p1.set_shard_dimension(23)
    p2 = tpu_sharding.ShardingPolicy()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)
    # Merging a policy that only sets the dimension overrides just that field.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_shard_dimension(12)
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # Merging a consistent policy into a frozen one leaves it unchanged.
    p2.freeze()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # A conflicting shard count cannot be merged into a frozen policy.
    p1.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      p2.merge(p1)
    # A matching shard count with a conflicting dimension is also rejected.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p2.merge(p1)
    p1.set_shard_dimension(2)
    with self.assertRaises(ValueError):
      p2.merge(p1)

  def testGetShardedShape(self):
    """Tests getting a sharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    # Dimension 1 (size 9) is split three ways.
    self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3])
    p.freeze()
    with self.assertRaises(ValueError):
      p.set_shard_dimension(0)
    # A valid shard index yields the same per-shard shape.
    self.assertEqual(p.get_sharded_shape([4, 9], shard_index=0), [4, 3])
    # Shard indices outside [0, number_of_shards) are rejected.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=4)
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=-1)
    with self.assertRaises(TypeError):
      _ = p.get_sharded_shape("not_a_shape")
    # Unknown-rank shapes cannot be sharded.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape(tensor_shape.TensorShape(None))
    # A shape without the shard dimension cannot be sharded.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4])
    # The shard dimension (10) is not divisible by the shard count (3). No
    # shard_index is passed, so that an invalid index cannot mask this check.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 10])

  def testGetUnpartitionedShape(self):
    """Tests getting an unpartitioned shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    p.set_number_of_partitions(4)
    # Dimension 1 (size 5) is multiplied by the 4 partitions.
    self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
    p.freeze()
    # An unknown size along the shard dimension is rejected.
    with self.assertRaises(ValueError):
      _ = p.get_unpartitioned_shape([3, None])

  def testGetUnshardedShape(self):
    """Tests getting an unsharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(2)
    p.set_shard_dimension(1)
    # Two [4, 3] shards concatenate to [4, 6] along dimension 1.
    self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
    # The number of shard shapes must match number_of_shards.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
    # All shard shapes must agree.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 2]])
    with self.assertRaises(TypeError):
      _ = p.get_unsharded_shape([[4, 3], "not_a_shape"])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([None, [4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[2], [4, 3]])

  def testScalar(self):
    """Tests sharding and unsharding scalars."""
    # The default frozen policy leaves a scalar (rank-0) shape unchanged.
    p = tpu_sharding.ShardingPolicy()
    p.freeze()
    self.assertEqual(p.get_sharded_shape([]), [])
    self.assertEqual(p.get_unsharded_shape([[]]), [])


if __name__ == "__main__":
  test.main()