1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for testing serializable datasets."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import numpy as np
24
25from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.data.ops import iterator_ops
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.ops import lookup_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import gfile
35from tensorflow.python.platform import test
36from tensorflow.python.training import checkpoint_management
37from tensorflow.python.training import saver as saver_lib
38from tensorflow.python.util import nest
39
40
def remove_variants(get_next_op):
  """Replaces variant tensors in a nested structure with empty tuples.

  Every component of `get_next_op` that is an `ops.Tensor` whose dtype is
  `dtypes.variant` is swapped for `()`; all other components are passed
  through untouched, and the nesting of the structure is preserved.

  Args:
    get_next_op: A (possibly nested) structure of iterator outputs.

  Returns:
    A structure with the same nesting as `get_next_op` in which variant
    tensors have been replaced by `()`, so that `sess.run` will execute.
  """
  # TODO(b/72408568): Remove this once session.run can get
  # variant tensors.

  def _strip_variant(component):
    is_variant = (
        isinstance(component, ops.Tensor) and
        component.dtype == dtypes.variant)
    return () if is_variant else component

  return nest.map_structure(_strip_variant, get_next_op)
53
54
class DatasetSerializationTestBase(test.TestCase):
  """Base class for testing serializable datasets.

  Subclasses supply 0-argument dataset factories (`ds_fn`) to the `verify_*`
  methods, which build an iterator, run it for some number of steps, save the
  iterator state through a `Saver` checkpoint in the test's temp directory,
  restore it in a fresh graph/session, and compare the produced elements
  against an uninterrupted run.
  """

  def tearDown(self):
    """Deletes checkpoint files written by the test, then runs base cleanup."""
    self._delete_ckpt()
    # Chain to the parent so that its own per-test cleanup is not skipped.
    super(DatasetSerializationTestBase, self).tearDown()

  # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
  # (deprecated) saveable `SparseTensorSliceDataset`, once the API
  # `from_sparse_tensor_slices()` and related tests are deleted.
  def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
    """Runs the core tests.

    Args:
      ds_fn1: 0-argument function that returns a Dataset.
      ds_fn2: 0-argument function that returns a Dataset different from
        ds_fn1. If None, verify_restore_in_modified_graph test is not run.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    # NOTE: We disable all default optimizations in serialization tests in order
    # to test the actual dataset in question.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False

    def ds_fn1_no_opt():
      return ds_fn1().with_options(options)

    self.verify_unused_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_fully_used_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_exhausted_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_init_before_restore(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_multiple_breaks(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_reset_restored_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_restore_in_empty_graph(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    if ds_fn2:

      def ds_fn2_no_opt():
        return ds_fn2().with_options(options)

      self.verify_restore_in_modified_graph(
          ds_fn1_no_opt,
          ds_fn2_no_opt,
          num_outputs,
          sparse_tensors=sparse_tensors)

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    # A single break point at 0 checkpoints the iterator before any element
    # has been produced.
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self, ds_fn, num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced but does not check for an
    exhausted iterator, i.e., one from which an OutOfRange error has been
    returned.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one which has returned an OutOfRange error.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """
    # Run the iterator to exhaustion; `gen_outputs` saves a checkpoint at the
    # end by default.
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    # Restore from that checkpoint and expect no further elements.
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertEqual(len(actual), 0)

  def verify_init_before_restore(self,
                                 ds_fn,
                                 num_outputs,
                                 sparse_tensors=False,
                                 verify_exhausted=True):
    """Verifies that restoring into an already initialized iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs),
        num_outputs,
        init_before_restore=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      num_breaks: The number of break points. These are uniformly spread in
        [0, num_outputs] both inclusive.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    # Only fall back to the default when the caller did not pass a value, so
    # that an explicit break point of 0 is honoured.
    if break_point is None:
      break_point = num_outputs // 2

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        # Re-initializing after the restore should rewind the iterator, so
        # the full `num_outputs` elements are expected again.
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_restore_in_modified_graph(self,
                                       ds_fn1,
                                       ds_fn2,
                                       num_outputs,
                                       break_point=None,
                                       sparse_tensors=False,
                                       verify_exhausted=True):
    """Attempts to restore an iterator in a modified graph.

    Builds an input pipeline using ds_fn1, runs it for `break_point` steps
    and saves a checkpoint. Then builds a new graph using ds_fn2, restores
    the checkpoint from ds_fn1 and verifies that the restore is successful.

    Args:
      ds_fn1: See `run_core_tests`.
      ds_fn2: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    # Only fall back to the default when the caller did not pass a value, so
    # that an explicit break point of 0 is honoured.
    if break_point is None:
      break_point = num_outputs // 2

    # Skip `break_point` items and store the remaining produced from ds_fn1
    # in `expected`.
    self.gen_outputs(
        ds_fn1, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)
    expected = self.gen_outputs(
        ds_fn1, [],
        num_outputs - break_point,
        ckpt_saved=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Generate `break_point` items from ds_fn1 and save checkpoint.
    self.gen_outputs(
        ds_fn1, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Build graph for ds_fn2 but load checkpoint for ds_fn1.
    with ops.Graph().as_default() as g:
      _, get_next_op, saver = self._build_graph(
          ds_fn2, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        for _ in range(num_outputs - break_point):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    self.match(expected, actual)

  def verify_restore_in_empty_graph(self,
                                    ds_fn,
                                    num_outputs,
                                    break_point=None,
                                    sparse_tensors=False,
                                    verify_exhausted=True):
    """Attempts to restore an iterator in an empty graph.

    Builds an input pipeline using ds_fn, runs it for `break_point` steps
    and saves a checkpoint. Then builds a new empty graph, restores
    the checkpoint from ds_fn and verifies that the restore is successful.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    # Only fall back to the default when the caller did not pass a value, so
    # that an explicit break point of 0 is honoured.
    if break_point is None:
      break_point = num_outputs // 2

    # Skip `break_point` items and store the remaining produced from ds_fn
    # in `expected`.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs - break_point,
        ckpt_saved=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Generate `break_point` items from ds_fn and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Build an empty graph but load checkpoint for ds_fn.
    with ops.Graph().as_default() as g:
      get_next_op, saver = self._build_empty_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        for _ in range(num_outputs - break_point):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      error: Declared error when trying to save iterator.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """
    # Only fall back to the default when the caller did not pass a value, so
    # that an explicit break point of 0 is honoured.
    if break_point is None:
      break_point = num_outputs // 2

    with ops.Graph().as_default() as g:
      init_op, get_next_op, saver = self._build_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        for _ in range(break_point):
          sess.run(get_next_op)
        with self.assertRaises(error):
          self._save(sess, saver)

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             init_before_restore=False,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       with stopping at break points.

    Deep matches outputs from 1 and 2.

    Args:
      ds_fn: See `gen_outputs`.
      break_points: See `gen_outputs`.
      num_outputs: See `gen_outputs`.
      init_before_restore: See `gen_outputs`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        init_before_restore=init_before_restore,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        init_before_restore=init_before_restore,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    self.match(expected, actual)

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  init_before_restore=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, we produce outputs till `break_point` number of items
        have been produced and then checkpoint the state. The current graph
        and session are destroyed and a new graph and session are used to
        produce outputs till next checkpoint or till `num_outputs` elements
        have been produced. `break_point` must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists. If False, we build the
        graph from ds_fn.
      init_before_restore: Whether init should be called before saver.restore.
        This is just so that we can verify that restoring an already initialized
        iterator works.
      sparse_tensors:  Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved each break point but not at the
        end. Note that checkpoints overwrite each other so there is always only
        a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
    """
    outputs = []

    def get_ops():
      # Reads the enclosing `ckpt_saved`, which is flipped to True below once
      # the first checkpoint has been written.
      if ckpt_saved:
        saver = self._import_meta_graph()
        init_op, get_next_op = self._get_iterator_ops_from_collection(
            ds_fn, sparse_tensors=sparse_tensors)
      else:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
      return init_op, get_next_op, saver

    # One segment per break point plus a final segment up to `num_outputs`;
    # each segment runs in its own graph and session.
    for i in range(len(break_points) + 1):
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = get_ops()
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          if ckpt_saved:
            if init_before_restore:
              self._initialize(init_op, sess)
            self._restore(saver, sess)
          else:
            self._initialize(init_op, sess)
          start = break_points[i - 1] if i > 0 else 0
          end = break_points[i] if i < len(break_points) else num_outputs
          num_iters = end - start
          for _ in range(num_iters):
            outputs.append(sess.run(get_next_op))
          # Exhaustion can only be checked on the final segment.
          if i == len(break_points) and verify_exhausted:
            with self.assertRaises(errors.OutOfRangeError):
              sess.run(get_next_op)
          if save_checkpoint_at_end or i < len(break_points):
            self._save(sess, saver)
            ckpt_saved = True

    return outputs

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches shape and values of `expected` and `actual`.
    Handles scalars, numpy arrays and other python sequence containers
    e.g. list, dict.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
    # Convert arrays to (nested) lists so they are compared element-wise by
    # the recursive branch below.
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_sequence(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        # Compare dicts key by key in sorted key order.
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    else:
      self.assertEqual(expected, actual)

  def does_not_match(self, expected, actual):
    """Asserts that `match(expected, actual)` fails.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if the two structures match.
    """
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` break points in [0, num_outputs].

    Args:
      num_outputs: Upper end (inclusive) of the range of break points.
      num_samples: Number of break points to generate.

    Returns:
      A 1-D integer numpy array of `num_samples` evenly spaced values from 0
      to `num_outputs`.
    """
    return np.linspace(0, num_outputs, num_samples, dtype=int)

  def _build_graph(self, ds_fn, sparse_tensors=False):
    """Builds a saveable initializable iterator over `ds_fn()`.

    The iterator's saveable is added to the `SAVEABLE_OBJECTS` collection and
    its ops are recorded in the "iterator_ops" collection.

    Args:
      ds_fn: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Returns:
      A tuple `(init_op, get_next, saver)`.
    """
    iterator = dataset_ops.make_initializable_iterator(ds_fn())

    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _build_empty_graph(self, ds_fn, sparse_tensors=False):
    """Builds a saveable iterator from the structure of `ds_fn()` only.

    The iterator is created from the dataset's output types, shapes and
    classes; it is not bound to a dataset, so no initializer is returned.

    Args:
      ds_fn: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Returns:
      A tuple `(get_next, saver)`.
    """
    iterator = iterator_ops.Iterator.from_structure(
        self._get_output_types(ds_fn),
        output_shapes=self._get_output_shapes(ds_fn),
        output_classes=self._get_output_classes(ds_fn))
    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    saver = saver_lib.Saver(allow_empty=True)
    return get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    """Records `init_op` and the flattened `get_next` in "iterator_ops".

    `init_op` is always stored first, followed by the components of
    `get_next`. Each SparseTensor contributes three entries: indices, values
    and dense_shape.

    Args:
      init_op: The iterator initializer.
      get_next: The (possibly nested) iterator output.
      ds_fn: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
    """
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
    # do not support tuples we flatten the tensors and restore the shape in
    # `_get_iterator_ops_from_collection`.
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    """Rebuilds `(init_op, get_next)` from the "iterator_ops" collection.

    This is the inverse of `_add_iterator_ops_to_collection`.

    Args:
      ds_fn: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Returns:
      A tuple `(init_op, get_next)` where `get_next` is packed back into the
      structure of the dataset's output types.
    """
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
    get_next_list = []
    # Index 0 holds `init_op`; the output components start at index 1.
    i = 1
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)

  def _get_output_types(self, ds_fn):
    """Returns the legacy output types of `ds_fn()`, built in a scratch graph."""
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_types(ds_fn())

  def _get_output_shapes(self, ds_fn):
    """Returns the legacy output shapes of `ds_fn()`, built in a scratch graph."""
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_shapes(ds_fn())

  def _get_output_classes(self, ds_fn):
    """Returns the legacy output classes of `ds_fn()`, built in a scratch graph."""
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_classes(ds_fn())

  def _ckpt_path(self):
    """Returns the checkpoint path prefix inside the test's temp directory."""
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    """Returns the latest checkpoint found in the test's temp directory."""
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    """Saves the state of `sess` with `saver` at the checkpoint path prefix."""
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    """Runs the tables initializer, then restores the latest checkpoint."""
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    """Initializes global variables and tables, then runs `init_op`."""
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    """Imports the saved meta graph into the current graph.

    Returns:
      The value returned by `saver_lib.import_meta_graph` for the ".meta"
      file next to the checkpoint path prefix.
    """
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    """Removes every file whose path starts with the checkpoint prefix."""
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    # Iterate explicitly: `map` is lazy in Python 3, so a bare
    # `map(gfile.Remove, files)` would never delete anything.
    for ckpt_file in gfile.Glob(pattern):
      gfile.Remove(ckpt_file)
709