# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for StructuredTensorSpec."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
from tensorflow.python.ops.structured.structured_tensor import StructuredTensorSpec
from tensorflow.python.platform import googletest


# TypeSpecs consts for fields types.  The name encodes the kind of spec
# (T = TensorSpec, R = RaggedTensorSpec) followed by its shape, with N standing
# for a None dimension.
T_3 = tensor_spec.TensorSpec([3])
T_1_2 = tensor_spec.TensorSpec([1, 2])
T_1_2_8 = tensor_spec.TensorSpec([1, 2, 8])
T_1_2_3_4 = tensor_spec.TensorSpec([1, 2, 3, 4])
T_2_3 = tensor_spec.TensorSpec([2, 3])
R_1_N = ragged_tensor.RaggedTensorSpec([1, None])
R_1_N_N = ragged_tensor.RaggedTensorSpec([1, None, None])
R_2_1_N = ragged_tensor.RaggedTensorSpec([2, 1, None])


# pylint: disable=g-long-lambda
@test_util.run_all_in_graph_and_eager_modes
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
  """Tests construction, serialization and (un)batching of the spec."""

  # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
  # assertAllEqual etc to work with StructuredTensors.
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that `a` equals `b`, with support for StructuredTensors.

    If neither argument is a StructuredTensor, this defers to the base class.
    If both are, their shapes, field names and field values are compared
    (field values recursively, through this same method).

    Args:
      a: The actual value.
      b: The expected value.
      msg: Optional message, forwarded to the base class implementation.  It
        is not used when comparing two StructuredTensors.

    Returns:
      Whatever the base class `assertAllEqual` returns when neither argument
      is a StructuredTensor; otherwise None.

    Raises:
      ValueError: If exactly one of `a` and `b` is a StructuredTensor.
    """
    if not (isinstance(a, structured_tensor.StructuredTensor) or
            isinstance(b, structured_tensor.StructuredTensor)):
      return super(StructuredTensorSpecTest, self).assertAllEqual(a, b, msg)
    if not (isinstance(a, structured_tensor.StructuredTensor) and
            isinstance(b, structured_tensor.StructuredTensor)):
      # TODO(edloper) Add support for this once structured_factory_ops is added.
      raise ValueError('Not supported yet')

    # Shapes are compared through their repr rather than directly; see the
    # note in testSerialize about TensorShapes that have a None dimension.
    self.assertEqual(repr(a.shape), repr(b.shape))
    self.assertEqual(set(a.field_names()), set(b.field_names()))
    for field in a.field_names():
      self.assertAllEqual(a.field_value(field), b.field_value(field))

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts that two collections of tensors are pairwise equal.

    Args:
      list1: The actual values: either a sequence, or a dict keyed by name.
      list2: The expected values, structured like `list1`.
    """
    self.assertLen(list1, len(list2))
    if isinstance(list1, dict) and isinstance(list2, dict):
      # Iterating over a dict yields its keys, so zipping two dicts would only
      # compare the field names and never look at the tensors.  Match the
      # values up by key instead.
      self.assertEqual(set(list1), set(list2))
      for key in list1:
        self.assertAllEqual(list1[key], list2[key])
      return
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    """Checks that the constructor stores the shape and the field specs."""
    spec1_fields = dict(a=T_1_2_3_4)
    spec1 = StructuredTensorSpec([1, 2, 3], spec1_fields)
    self.assertEqual(spec1._shape, (1, 2, 3))
    self.assertEqual(spec1._field_specs, spec1_fields)

    # A StructuredTensorSpec may itself be used as a field spec (field `s`).
    spec2_fields = dict(a=T_1_2, b=T_1_2_8, c=R_1_N, d=R_1_N_N, s=spec1)
    spec2 = StructuredTensorSpec([1, 2], spec2_fields)
    self.assertEqual(spec2._shape, (1, 2))
    self.assertEqual(spec2._field_specs, spec2_fields)

  @parameterized.parameters([
      (None, {}, r"StructuredTensor's shape must have known rank\."),
      ([], None, r'field_specs must be a dictionary\.'),
      ([], {1: tensor_spec.TensorSpec(None)},
       r'field_specs must be a dictionary with string keys\.'),
      ([], {'x': 0},
       r'field_specs must be a dictionary with TypeSpec values\.'),
  ])
  def testConstructionErrors(self, shape, field_specs, error):
    """Checks that invalid constructor arguments raise a TypeError."""
    with self.assertRaisesRegex(TypeError, error):
      structured_tensor.StructuredTensorSpec(shape, field_specs)

  def testValueType(self):
    """Checks that the spec reports StructuredTensor as its value type."""
    spec1 = StructuredTensorSpec([1, 2, 3], dict(a=T_1_2))
    self.assertEqual(spec1.value_type, StructuredTensor)

  @parameterized.parameters([
      (StructuredTensorSpec([1, 2, 3], {}),
       (tensor_shape.TensorShape([1, 2, 3]), {})),
      (StructuredTensorSpec([], {'a': T_1_2}),
       (tensor_shape.TensorShape([]), {'a': T_1_2})),
      (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}),
       (tensor_shape.TensorShape([1, 2]), {'a': T_1_2, 'b': R_1_N})),
  ])  # pyformat: disable
  def testSerialize(self, spec, expected):
    """Checks that `_serialize` returns the (shape, field_specs) pair."""
    serialization = spec._serialize()
    # Note that we can only use assertEqual because none of our cases include
    # a None dimension. A TensorShape with a None dimension is never equal
    # to another TensorShape.
    self.assertEqual(serialization, expected)

  @parameterized.parameters([
      (StructuredTensorSpec([1, 2, 3], {}), {}),
      (StructuredTensorSpec([], {'a': T_1_2}), {'a': T_1_2}),
      (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}),
       {'a': T_1_2, 'b': R_1_N}),
  ])  # pyformat: disable
  def testComponentSpecs(self, spec, expected):
    """Checks that `_component_specs` equals the expected field-spec dict."""
    self.assertEqual(spec._component_specs, expected)

  @parameterized.parameters([
      {
          'shape': [],
          'fields': dict(x=[[1.0, 2.0]]),
          'field_specs': dict(x=T_1_2),
      },
      # TODO(edloper): Enable this test once we update StructuredTensorSpec
      # to contain the shared row partitions.
      #{
      #    'shape': [1, 2, 3],
      #    'fields': {},
      #    'field_specs': {},
      #},
      {
          'shape': [2],
          'fields': dict(
              a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
              b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
          'field_specs': dict(a=R_1_N, b=T_2_3),
      },
  ])  # pyformat: disable
  def testToFromComponents(self, shape, fields, field_specs):
    """Round-trips a StructuredTensor through its spec's components."""
    # The expected components are the fields the tensor was built from.
    components = fields
    struct = StructuredTensor.from_fields(fields, shape)
    spec = StructuredTensorSpec(shape, field_specs)
    actual_components = spec._to_components(struct)
    self.assertAllTensorsEqual(actual_components, components)
    rt_reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(struct, rt_reconstructed)

  @parameterized.parameters([
      {
          'unbatched': StructuredTensorSpec([], {}),
          'batch_size': 5,
          'batched': StructuredTensorSpec([5], {}),
      },
      {
          'unbatched': StructuredTensorSpec([1, 2], {}),
          'batch_size': 5,
          'batched': StructuredTensorSpec([5, 1, 2], {}),
      },
      {
          'unbatched': StructuredTensorSpec([], dict(a=T_3, b=R_1_N)),
          'batch_size': 2,
          'batched': StructuredTensorSpec([2], dict(a=T_2_3, b=R_2_1_N)),
      }
  ])  # pyformat: disable
  def testBatchUnbatch(self, unbatched, batch_size, batched):
    """Checks that `_batch` and `_unbatch` are inverses on the given specs."""
    self.assertEqual(unbatched._batch(batch_size), batched)
    self.assertEqual(batched._unbatch(), unbatched)

  @parameterized.parameters([
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields({'a': 1, 'b': [5, 6]}),
              StructuredTensor.from_fields({'a': 2, 'b': [7, 8]})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'a': [1, 2],
              'b': [[5, 6], [7, 8]]}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[3], fields={
                  'a': [1, 2, 3],
                  'b': [[5, 6], [6, 7], [7, 8]]}),
              StructuredTensor.from_fields(shape=[3], fields={
                  'a': [2, 3, 4],
                  'b': [[2, 2], [3, 3], [4, 4]]})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2, 3], fields={
              'a': [[1, 2, 3], [2, 3, 4]],
              'b': [[[5, 6], [6, 7], [7, 8]],
                    [[2, 2], [3, 3], [4, 4]]]}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[], fields={
                  'a': 1,
                  'b': StructuredTensor.from_fields({'x': [5]})}),
              StructuredTensor.from_fields(shape=[], fields={
                  'a': 2,
                  'b': StructuredTensor.from_fields({'x': [6]})})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'a': [1, 2],
              'b': StructuredTensor.from_fields(shape=[2], fields={
                  'x': [[5], [6]]})}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[], fields={
                  'Ragged3d': ragged_factory_ops.constant_value([[1, 2], [3]]),
                  'Ragged2d': ragged_factory_ops.constant_value([1]),
              }),
              StructuredTensor.from_fields(shape=[], fields={
                  'Ragged3d': ragged_factory_ops.constant_value([[1]]),
                  'Ragged2d': ragged_factory_ops.constant_value([2, 3]),
              })],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'Ragged3d': ragged_factory_ops.constant_value(
                  [[[1, 2], [3]], [[1]]]),
              'Ragged2d': ragged_factory_ops.constant_value([[1], [2, 3]]),
          }),
          'use_only_batched_spec': True,
      },
  ])  # pyformat: disable
  def testBatchUnbatchValues(self, unbatched, batch_size, batched,
                             use_only_batched_spec=False):
    """Batches and unbatches actual values via their tensor lists.

    Args:
      unbatched: Zero-argument callable returning the list of unbatched
        StructuredTensors.
      batch_size: The number of values in `unbatched`.
      batched: Zero-argument callable returning the expected batched
        StructuredTensor.
      use_only_batched_spec: If true, the unbatched spec is obtained by
        unbatching the spec of the batched value, rather than from the first
        unbatched value.
    """
    batched = batched()  # Deferred init because it creates tensors.
    unbatched = unbatched()  # Deferred init because it creates tensors.

    # Test batching.
    if use_only_batched_spec:
      unbatched_spec = type_spec.type_spec_from_value(batched)._unbatch()
    else:
      unbatched_spec = type_spec.type_spec_from_value(unbatched[0])
    rebatched_spec = unbatched_spec._batch(batch_size)
    unbatched_tensor_lists = [unbatched_spec._to_tensor_list(st)
                              for st in unbatched]
    # Stack the i-th tensor of every unbatched value into one batched tensor.
    batched_tensor_list = [array_ops.stack(tensors)
                           for tensors in zip(*unbatched_tensor_lists)]
    actual_batched = rebatched_spec._from_tensor_list(batched_tensor_list)
    self.assertTrue(rebatched_spec.is_compatible_with(actual_batched))
    self.assertAllEqual(actual_batched, batched)

    # Test unbatching
    batched_spec = type_spec.type_spec_from_value(batched)
    element_spec = batched_spec._unbatch()
    batched_tensor_list = batched_spec._to_batched_tensor_list(batched)
    # Unstack every batched tensor and regroup the pieces per batch element.
    unbatched_tensor_lists = zip(
        *[array_ops.unstack(tensor) for tensor in batched_tensor_list])
    actual_unbatched = [
        element_spec._from_tensor_list(tensor_list)
        for tensor_list in unbatched_tensor_lists]
    self.assertLen(actual_unbatched, len(unbatched))
    for st in actual_unbatched:
      self.assertTrue(element_spec.is_compatible_with(st))
    for (actual, expected) in zip(actual_unbatched, unbatched):
      self.assertAllEqual(actual, expected)


if __name__ == '__main__':
  googletest.main()