# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Base class for testing serializable datasets."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import numpy as np
24
25from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.ops import lookup_ops
32from tensorflow.python.ops import variables
33from tensorflow.python.ops.ragged import ragged_tensor_value
34from tensorflow.python.platform import gfile
35from tensorflow.python.platform import test
36from tensorflow.python.training import checkpoint_management
37from tensorflow.python.training import saver as saver_lib
38from tensorflow.python.util import nest
39
40
def remove_variants(get_next_op):
  """Remove variants from a nest structure, so sess.run will execute."""
  # TODO(b/72408568): Remove this once session.run can get variant tensors.

  def _remove_variant(x):
    if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
      return ()
    else:
      return x

  return nest.map_structure(_remove_variant, get_next_op)


class DatasetSerializationTestBase(test.TestCase):
  """Base class for testing serializable datasets."""

  def tearDown(self):
    self._delete_ckpt()
    super(DatasetSerializationTestBase, self).tearDown()

  # TODO(b/72657739): Remove the sparse_tensors argument, which is to test the
  # (deprecated) saveable `SparseTensorSliceDataset`, once the API
  # `from_sparse_tensor_slices()` and related tests are deleted.
  def run_core_tests(self, ds_fn, num_outputs, sparse_tensors=False):
    """Runs the core tests.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether the dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    # NOTE: We disable all default optimizations in serialization tests in
    # order to test the actual dataset in question.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False

    def ds_fn_no_opt():
      return ds_fn().with_options(options)

    self.verify_unused_iterator(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_fully_used_iterator(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_exhausted_iterator(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_multiple_breaks(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_reset_restored_iterator(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self, ds_fn, num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced, but does not check for an
    exhausted iterator, i.e., one that has raised an `OutOfRangeError`.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if the test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one that has raised an `OutOfRangeError`.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertEqual(len(actual), 0)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      num_breaks: The number of break points. These are spread uniformly
        across [0, num_outputs], both endpoints inclusive.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to `num_outputs // 2`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if break_point is None else break_point

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        self._restore(saver, sess)
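        # Re-initializing after the restore should reset the iterator to the
        # start of the dataset.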
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      error: The error expected when trying to save the iterator.
      break_point: Break point. Optional. Defaults to `num_outputs // 2`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """

    break_point = num_outputs // 2 if break_point is None else break_point
    with ops.Graph().as_default() as g:
      init_op, get_next_op, saver = self._build_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        for _ in range(break_point):
          sess.run(get_next_op)
        with self.assertRaises(error):
          self._save(sess, saver)

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *while* stopping at break points.

    The outputs from 1 and 2 are then deep-matched.

    Args:
      ds_fn: See `gen_outputs`.
      break_points: See `gen_outputs`.
      num_outputs: See `gen_outputs`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    self.match(expected, actual)

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, we produce outputs until `break_point` items have been
        produced, then checkpoint the state. The current graph and session are
        destroyed, and a new graph and session are used to produce outputs
        until the next break point or until `num_outputs` elements have been
        produced. Each `break_point` must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists.
      sparse_tensors: Whether the dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved at each break point but not at
        the end. Note that checkpoints overwrite each other, so there is always
        only a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
    """
    outputs = []

    def get_ops():
      if ckpt_saved:
        saver = self._import_meta_graph()
        init_op, get_next_op = self._get_iterator_ops_from_collection(
            ds_fn, sparse_tensors=sparse_tensors)
      else:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
      return init_op, get_next_op, saver

    for i in range(len(break_points) + 1):
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = get_ops()
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
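          # When a checkpoint exists, run init_op first and then let the
          # restore overwrite the iterator's state with the checkpointed one.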
          if ckpt_saved:
            self._initialize(init_op, sess)
            self._restore(saver, sess)
          else:
            self._initialize(init_op, sess)
          start = break_points[i - 1] if i > 0 else 0
          end = break_points[i] if i < len(break_points) else num_outputs
          num_iters = end - start
          for _ in range(num_iters):
            outputs.append(sess.run(get_next_op))
          if i == len(break_points) and verify_exhausted:
            with self.assertRaises(errors.OutOfRangeError):
              sess.run(get_next_op)
          if save_checkpoint_at_end or i < len(break_points):
            self._save(sess, saver)
            ckpt_saved = True

    return outputs

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches the shapes and values of `expected` and `actual`.
    Handles scalars, NumPy arrays, and Python sequence containers (e.g. list,
    dict), as well as SparseTensorValue and RaggedTensorValue.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_sequence(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.match((expected.indices, expected.values, expected.dense_shape),
                 (actual.indices, actual.values, actual.dense_shape))
    elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
      self.match((expected.values, expected.row_splits),
                 (actual.values, actual.row_splits))
    else:
      self.assertEqual(expected, actual)

  def does_not_match(self, expected, actual):
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` break points in [0, num_outputs]."""
    return np.linspace(0, num_outputs, num_samples, dtype=int)

  def _build_graph(self, ds_fn, sparse_tensors=False):
    iterator = dataset_ops.make_initializable_iterator(ds_fn())

    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
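    # allow_empty=True since the test graphs may contain no variables; the
    # iterator state is covered by the saveable object registered above.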
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple, e.g. in TensorSliceDataset. Since collections
    # do not support tuples, we flatten the tensors and restore the structure
    # in `_get_iterator_ops_from_collection`.
    if sparse_tensors:  # Specific to deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # Specific to deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
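    # all_ops[0] is the init_op; the get_next tensors start at index 1, with
    # each SparseTensor flattened into three consecutive entries.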
    get_next_list = []
    i = 1
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)

  def _get_output_types(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_types(ds_fn())

  def _get_output_shapes(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_shapes(ds_fn())

  def _get_output_classes(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_classes(ds_fn())

  def _ckpt_path(self):
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    # `map` is lazy in Python 3, so iterate explicitly to actually delete.
    for f in files:
      gfile.Remove(f)