1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test utilities for tf.data functionality."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import re
21
22from tensorflow.python import tf2
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.data.util import nest
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.ops import array_ops
30from tensorflow.python.platform import test
31
32
class DatasetTestBase(test.TestCase):
  """Base class for dataset tests.

  Provides helpers that iterate datasets the same way in graph and eager mode
  and compare the produced elements (including sparse components) against
  expected values.
  """

  @classmethod
  def setUpClass(cls):
    """Binds `dataset_ops.Dataset` to the API version selected by `tf2`."""
    super(DatasetTestBase, cls).setUpClass()
    # `dataset_ops.Dataset` is rebound at module level so that tests written
    # against `dataset_ops.Dataset` use DatasetV2 when `tf2.enabled()` and
    # DatasetV1 otherwise.
    if tf2.enabled():
      dataset_ops.Dataset = dataset_ops.DatasetV2
    else:
      dataset_ops.Dataset = dataset_ops.DatasetV1

  def assertSparseValuesEqual(self, a, b):
    """Asserts that two SparseTensors/SparseTensorValues are equal.

    Compares `indices`, `values` and `dense_shape` component-wise.
    """
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def getNext(self, dataset, requires_initialization=False):
    """Returns a callable that returns the next element of the dataset.

    Example use:
    ```python
    # In both graph and eager modes
    dataset = ...
    get_next = self.getNext(dataset)
    result = self.evaluate(get_next())
    ```

    Args:
      dataset: A dataset whose elements will be returned.
      requires_initialization: Indicates that when the test is executed in graph
        mode, it should use an initializable iterator to iterate through the
        dataset (e.g. when it contains stateful nodes). Defaults to False.
    Returns:
      A callable that returns the next element of `dataset`.
    """
    if context.executing_eagerly():
      iterator = iter(dataset)
      return iterator._next_internal  # pylint: disable=protected-access
    else:
      if requires_initialization:
        iterator = dataset_ops.make_initializable_iterator(dataset)
        self.evaluate(iterator.initializer)
      else:
        iterator = dataset_ops.make_one_shot_iterator(dataset)
      # In graph mode the same `get_next` op is re-evaluated on every call;
      # each `self.evaluate` of it advances the iterator.
      get_next = iterator.get_next()
      return lambda: get_next

  def _compareOutputToExpected(self, result_values, expected_values,
                               assert_items_equal):
    """Asserts that `result_values` matches `expected_values`.

    Args:
      result_values: List of evaluated dataset elements.
      expected_values: List of expected elements.
      assert_items_equal: If True, compare the lists ignoring order (nested
        sparse components are not supported in this mode). Otherwise compare
        element-by-element, checking nested structure and comparing sparse
        components with `assertSparseValuesEqual`.
    """
    if assert_items_equal:
      # TODO(shivaniagrawal): add support for nested elements containing sparse
      # tensors when needed.
      self.assertItemsEqual(result_values, expected_values)
      return
    self.assertEqual(len(result_values), len(expected_values))
    for result_element, expected_element in zip(result_values,
                                                expected_values):
      nest.assert_same_structure(result_element, expected_element)
      for result_value, expected_value in zip(
          nest.flatten(result_element), nest.flatten(expected_element)):
        if sparse_tensor.is_sparse(result_value):
          self.assertSparseValuesEqual(result_value, expected_value)
        else:
          self.assertAllEqual(result_value, expected_value)

  def assertDatasetProduces(self,
                            dataset,
                            expected_output=None,
                            expected_shapes=None,
                            expected_error=None,
                            requires_initialization=False,
                            num_test_iterations=1,
                            assert_items_equal=False,
                            expected_error_iter=1):
    """Asserts that a dataset produces the expected output / error.

    Args:
      dataset: A dataset to check for the expected output / error.
      expected_output: A list of elements that the dataset is expected to
        produce.
      expected_shapes: A list of TensorShapes which is expected to match
        output_shapes of dataset.
      expected_error: A tuple `(type, predicate)` identifying the expected error
        `dataset` should raise. The `type` should match the expected exception
        type, while `predicate` should either be 1) a unary function that inputs
        the raised exception and returns a boolean indicator of success or 2) a
        regular expression that is expected to match the error message
        partially.
      requires_initialization: Indicates that when the test is executed in graph
        mode, it should use an initializable iterator to iterate through the
        dataset (e.g. when it contains stateful nodes). Defaults to False.
      num_test_iterations: Number of times `dataset` will be iterated. Defaults
        to 1.
      assert_items_equal: Tests expected_output has (only) the same elements
        regardless of order.
      expected_error_iter: How many times to iterate before expecting an error,
        if an error is expected.
    """
    self.assertTrue(
        expected_error is not None or expected_output is not None,
        "Exactly one of expected_output or expected error should be provided.")
    if expected_error:
      self.assertTrue(
          expected_output is None,
          "Exactly one of expected_output or expected error should be provided."
      )
      with self.assertRaisesWithPredicateMatch(expected_error[0],
                                               expected_error[1]):
        get_next = self.getNext(
            dataset, requires_initialization=requires_initialization)
        for _ in range(expected_error_iter):
          self.evaluate(get_next())
      return
    if expected_shapes:
      self.assertEqual(expected_shapes,
                       dataset_ops.get_legacy_output_shapes(dataset))
    self.assertGreater(num_test_iterations, 0)
    for _ in range(num_test_iterations):
      get_next = self.getNext(
          dataset, requires_initialization=requires_initialization)
      result = [self.evaluate(get_next()) for _ in expected_output]
      self._compareOutputToExpected(result, expected_output, assert_items_equal)
      # Exhaustion is checked twice to verify the iterator keeps raising
      # OutOfRangeError on repeated calls once it has reached the end.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

  def assertDatasetsEqual(self, dataset1, dataset2):
    """Checks that datasets are equal. Supports both graph and eager mode.

    The two datasets must have mutually compatible structures and produce the
    same number of elements. Sparse components are compared exactly, string
    components with `assertAllEqual`, and all other components with
    `assertAllClose`.
    """
    self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
        dataset_ops.get_structure(dataset2)))
    self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
        dataset_ops.get_structure(dataset1)))
    flattened_types = nest.flatten(
        dataset_ops.get_legacy_output_types(dataset1))

    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)
    while True:
      try:
        op1 = self.evaluate(next1())
      except errors.OutOfRangeError:
        # dataset1 is exhausted; dataset2 must be exhausted at the same point.
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(next2())
        break
      op2 = self.evaluate(next2())

      op1 = nest.flatten(op1)
      op2 = nest.flatten(op2)
      self.assertEqual(len(op1), len(op2))
      for i, (value1, value2) in enumerate(zip(op1, op2)):
        if sparse_tensor.is_sparse(value1):
          self.assertSparseValuesEqual(value1, value2)
        elif flattened_types[i] == dtypes.string:
          # Strings cannot be compared with a numeric tolerance.
          self.assertAllEqual(value1, value2)
        else:
          self.assertAllClose(value1, value2)

  def assertDatasetsRaiseSameError(self,
                                   dataset1,
                                   dataset2,
                                   exception_class,
                                   replacements=None):
    """Checks that datasets raise the same error on the first get_next call.

    Args:
      dataset1: Dataset whose first element is expected to raise
        `exception_class`; its error message is the reference.
      dataset2: Dataset whose first element must raise `exception_class` with a
        message containing the (rewritten) reference message.
      exception_class: The expected exception type.
      replacements: Optional list of `(old, new, count)` tuples applied with
        `str.replace` to the reference message before comparison.

    Raises:
      ValueError: If `dataset1` does not raise `exception_class`.
    """
    if replacements is None:
      replacements = []
    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)
    try:
      self.evaluate(next1())
    except exception_class as e:
      # TF `OpError`s expose the bare error text as `.message`; other
      # exception types may not have that attribute, so fall back to `str(e)`.
      expected_message = getattr(e, "message", None)
      if expected_message is None:
        expected_message = str(e)
      for old, new, count in replacements:
        expected_message = expected_message.replace(old, new, count)
      # The (escaped) reference message must appear in dataset2's error
      # message.
      with self.assertRaisesRegexp(exception_class,
                                   re.escape(expected_message)):
        self.evaluate(next2())
    else:
      # Raised outside the `try` so that it cannot be swallowed when
      # `exception_class` is `ValueError` or one of its base classes.
      raise ValueError(
          "Expected dataset to raise an error of type %s, but it did not." %
          repr(exception_class))

  def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
    """Returns a singleton dataset with the given structure.

    Args:
      structure: `None` for a single zeros tensor, or an iterable of
        substructures, which produces a zipped tuple of nested datasets.
      shape: Shape of each leaf zeros tensor. Defaults to a scalar.
      dtype: Dtype of each leaf zeros tensor.
    """
    if shape is None:
      shape = []
    if structure is None:
      return dataset_ops.Dataset.from_tensors(
          array_ops.zeros(shape, dtype=dtype))
    return dataset_ops.Dataset.zip(
        tuple(
            self.structuredDataset(substructure, shape, dtype)
            for substructure in structure))

  def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
    """Returns an element with the given structure.

    Args:
      structure: `None` for a single zeros tensor, or an iterable of
        substructures, which produces a (nested) tuple of zeros tensors.
      shape: Shape of each leaf zeros tensor. Defaults to a scalar.
      dtype: Dtype of each leaf zeros tensor.
    """
    if shape is None:
      shape = []
    if structure is None:
      return array_ops.zeros(shape, dtype=dtype)
    return tuple(
        self.structuredElement(substructure, shape, dtype)
        for substructure in structure)
240