1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for StructuredTensorSpec."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.framework import test_util
26from tensorflow.python.framework import type_spec
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops.ragged import ragged_factory_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.ops.structured import structured_tensor
31from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
32from tensorflow.python.ops.structured.structured_tensor import StructuredTensorSpec
33from tensorflow.python.platform import googletest
34
35
# TypeSpec constants for field types.  Each name encodes the spec kind
# (T = tensor_spec.TensorSpec, R = ragged_tensor.RaggedTensorSpec) followed by
# its shape, where N marks an unknown (None) dimension.
T_3 = tensor_spec.TensorSpec(shape=(3,))
T_1_2 = tensor_spec.TensorSpec(shape=(1, 2))
T_1_2_8 = tensor_spec.TensorSpec(shape=(1, 2, 8))
T_1_2_3_4 = tensor_spec.TensorSpec(shape=(1, 2, 3, 4))
T_2_3 = tensor_spec.TensorSpec(shape=(2, 3))
R_1_N = ragged_tensor.RaggedTensorSpec(shape=(1, None))
R_1_N_N = ragged_tensor.RaggedTensorSpec(shape=(1, None, None))
R_2_1_N = ragged_tensor.RaggedTensorSpec(shape=(2, 1, None))
45
46
47# pylint: disable=g-long-lambda
@test_util.run_all_in_graph_and_eager_modes
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
  """Tests for StructuredTensorSpec construction, serialization and batching."""

  # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
  # assertAllEqual etc to work with StructuredTensors.
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that `a` and `b` are equal, with StructuredTensor support.

    If neither argument is a StructuredTensor, this defers to the base
    `assertAllEqual`.  If both are, it compares their shapes (via `repr`) and
    their field-name sets, then recursively compares each field's value.

    Args:
      a: The first value to compare.
      b: The second value to compare.
      msg: Optional message reported on failure (also forwarded to the
        per-shape, per-field and recursive comparisons).

    Raises:
      ValueError: If exactly one of `a` and `b` is a StructuredTensor.
    """
    a_is_struct = isinstance(a, structured_tensor.StructuredTensor)
    b_is_struct = isinstance(b, structured_tensor.StructuredTensor)
    if not (a_is_struct or b_is_struct):
      return super(StructuredTensorSpecTest, self).assertAllEqual(a, b, msg)
    if not (a_is_struct and b_is_struct):
      # TODO(edloper) Add support for this once structured_factory_ops is added.
      raise ValueError('Not supported yet')

    self.assertEqual(repr(a.shape), repr(b.shape), msg)
    self.assertEqual(set(a.field_names()), set(b.field_names()), msg)
    for field in a.field_names():
      self.assertAllEqual(a.field_value(field), b.field_value(field), msg)

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts that two collections of tensors are elementwise equal.

    StructuredTensorSpec components are dicts keyed by field name (compare
    `testComponentSpecs`).  Zipping two dicts would only pair up -- and
    compare -- their keys, so dicts are instead compared key-by-key on their
    values.  Any other sequences are compared positionally.

    Args:
      list1: A dict or sequence of tensor-like values.
      list2: A dict or sequence of tensor-like values.
    """
    if isinstance(list1, dict) or isinstance(list2, dict):
      self.assertIsInstance(list1, dict)
      self.assertIsInstance(list2, dict)
      self.assertEqual(set(list1), set(list2))
      for key in list1:
        self.assertAllEqual(list1[key], list2[key])
      return
    self.assertLen(list1, len(list2))
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    spec1_fields = dict(a=T_1_2_3_4)
    spec1 = StructuredTensorSpec([1, 2, 3], spec1_fields)
    self.assertEqual(spec1._shape, (1, 2, 3))
    self.assertEqual(spec1._field_specs, spec1_fields)

    spec2_fields = dict(a=T_1_2, b=T_1_2_8, c=R_1_N, d=R_1_N_N, s=spec1)
    spec2 = StructuredTensorSpec([1, 2], spec2_fields)
    self.assertEqual(spec2._shape, (1, 2))
    self.assertEqual(spec2._field_specs, spec2_fields)

  @parameterized.parameters([
      (None, {}, r"StructuredTensor's shape must have known rank\."),
      ([], None, r'field_specs must be a dictionary\.'),
      ([], {1: tensor_spec.TensorSpec(None)},
       r'field_specs must be a dictionary with string keys\.'),
      ([], {'x': 0},
       r'field_specs must be a dictionary with TypeSpec values\.'),
  ])
  def testConstructionErrors(self, shape, field_specs, error):
    with self.assertRaisesRegex(TypeError, error):
      structured_tensor.StructuredTensorSpec(shape, field_specs)

  def testValueType(self):
    spec1 = StructuredTensorSpec([1, 2, 3], dict(a=T_1_2))
    self.assertEqual(spec1.value_type, StructuredTensor)

  @parameterized.parameters([
      (StructuredTensorSpec([1, 2, 3], {}),
       (tensor_shape.TensorShape([1, 2, 3]), {})),
      (StructuredTensorSpec([], {'a': T_1_2}),
       (tensor_shape.TensorShape([]), {'a': T_1_2})),
      (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}),
       (tensor_shape.TensorShape([1, 2]), {'a': T_1_2, 'b': R_1_N})),
  ])  # pyformat: disable
  def testSerialize(self, spec, expected):
    serialization = spec._serialize()
    # Note that we can only use assertEqual because none of the
    # StructuredTensor shapes in these cases include a None dimension.  A
    # TensorShape with a None dimension is never equal to another TensorShape.
    self.assertEqual(serialization, expected)

  @parameterized.parameters([
      (StructuredTensorSpec([1, 2, 3], {}), {}),
      (StructuredTensorSpec([], {'a': T_1_2}), {'a': T_1_2}),
      (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}),
       {'a': T_1_2, 'b': R_1_N}),
  ])  # pyformat: disable
  def testComponentSpecs(self, spec, expected):
    self.assertEqual(spec._component_specs, expected)

  @parameterized.parameters([
      {
          'shape': [],
          'fields': dict(x=[[1.0, 2.0]]),
          'field_specs': dict(x=T_1_2),
      },
      # TODO(edloper): Enable this test once we update StructuredTensorSpec
      # to contain the shared row partitions.
      #{
      #    'shape': [1, 2, 3],
      #    'fields': {},
      #    'field_specs': {},
      #},
      {
          'shape': [2],
          'fields': dict(
              a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
              b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
          'field_specs': dict(a=R_1_N, b=T_2_3),
      },
  ])  # pyformat: disable
  def testToFromComponents(self, shape, fields, field_specs):
    struct = StructuredTensor.from_fields(fields, shape)
    spec = StructuredTensorSpec(shape, field_specs)
    # The components of a StructuredTensor are expected to be its fields.
    actual_components = spec._to_components(struct)
    self.assertAllTensorsEqual(actual_components, fields)
    st_reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(struct, st_reconstructed)

  @parameterized.parameters([
      {
          'unbatched': StructuredTensorSpec([], {}),
          'batch_size': 5,
          'batched': StructuredTensorSpec([5], {}),
      },
      {
          'unbatched': StructuredTensorSpec([1, 2], {}),
          'batch_size': 5,
          'batched': StructuredTensorSpec([5, 1, 2], {}),
      },
      {
          'unbatched': StructuredTensorSpec([], dict(a=T_3, b=R_1_N)),
          'batch_size': 2,
          'batched': StructuredTensorSpec([2], dict(a=T_2_3, b=R_2_1_N)),
      }
  ])  # pyformat: disable
  def testBatchUnbatch(self, unbatched, batch_size, batched):
    self.assertEqual(unbatched._batch(batch_size), batched)
    self.assertEqual(batched._unbatch(), unbatched)

  @parameterized.parameters([
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields({'a': 1, 'b': [5, 6]}),
              StructuredTensor.from_fields({'a': 2, 'b': [7, 8]})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'a': [1, 2],
              'b': [[5, 6], [7, 8]]}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[3], fields={
                  'a': [1, 2, 3],
                  'b': [[5, 6], [6, 7], [7, 8]]}),
              StructuredTensor.from_fields(shape=[3], fields={
                  'a': [2, 3, 4],
                  'b': [[2, 2], [3, 3], [4, 4]]})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2, 3], fields={
              'a': [[1, 2, 3], [2, 3, 4]],
              'b': [[[5, 6], [6, 7], [7, 8]],
                    [[2, 2], [3, 3], [4, 4]]]}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[], fields={
                  'a': 1,
                  'b': StructuredTensor.from_fields({'x': [5]})}),
              StructuredTensor.from_fields(shape=[], fields={
                  'a': 2,
                  'b': StructuredTensor.from_fields({'x': [6]})})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'a': [1, 2],
              'b': StructuredTensor.from_fields(shape=[2], fields={
                  'x': [[5], [6]]})}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[], fields={
                  'Ragged3d': ragged_factory_ops.constant_value([[1, 2], [3]]),
                  'Ragged2d': ragged_factory_ops.constant_value([1]),
              }),
              StructuredTensor.from_fields(shape=[], fields={
                  'Ragged3d': ragged_factory_ops.constant_value([[1]]),
                  'Ragged2d': ragged_factory_ops.constant_value([2, 3]),
              })],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'Ragged3d': ragged_factory_ops.constant_value(
                  [[[1, 2], [3]], [[1]]]),
              'Ragged2d': ragged_factory_ops.constant_value([[1], [2, 3]]),
          }),
          'use_only_batched_spec': True,
      },
  ])  # pyformat: disable
  def testBatchUnbatchValues(self, unbatched, batch_size, batched,
                             use_only_batched_spec=False):
    batched = batched()  # Deferred init because it creates tensors.
    unbatched = unbatched()  # Deferred init because it creates tensors.

    # Test batching: encode each unbatched value with the unbatched spec,
    # stack the per-value tensor lists, and decode with the batched spec.
    # When use_only_batched_spec is set, the unbatched spec is derived by
    # unbatching the batched value's spec instead of from unbatched[0].
    if use_only_batched_spec:
      unbatched_spec = type_spec.type_spec_from_value(batched)._unbatch()
    else:
      unbatched_spec = type_spec.type_spec_from_value(unbatched[0])
    unbatched_tensor_lists = [unbatched_spec._to_tensor_list(st)
                              for st in unbatched]
    batched_tensor_list = [array_ops.stack(tensors)
                           for tensors in zip(*unbatched_tensor_lists)]
    actual_batched = unbatched_spec._batch(batch_size)._from_tensor_list(
        batched_tensor_list)
    self.assertTrue(
        unbatched_spec._batch(batch_size).is_compatible_with(actual_batched))
    self.assertAllEqual(actual_batched, batched)

    # Test unbatching: encode the batched value as a batched tensor list,
    # unstack each tensor, and decode each slice with the unbatched spec.
    batched_spec = type_spec.type_spec_from_value(batched)
    batched_tensor_list = batched_spec._to_batched_tensor_list(batched)
    unbatched_tensor_lists = zip(
        *[array_ops.unstack(tensor) for tensor in batched_tensor_list])
    actual_unbatched = [
        batched_spec._unbatch()._from_tensor_list(tensor_list)
        for tensor_list in unbatched_tensor_lists]
    self.assertLen(actual_unbatched, len(unbatched))
    for st in actual_unbatched:
      self.assertTrue(batched_spec._unbatch().is_compatible_with(st))
    for (actual, expected) in zip(actual_unbatched, unbatched):
      self.assertAllEqual(actual, expected)
268
269
if __name__ == '__main__':
  # Script entry point: delegates to googletest.main() to run this module.
  googletest.main()
272