# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for testing serializable datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest


def remove_variants(get_next_op):
  """Removes variants from a nest structure, so sess.run will execute."""
  # TODO(b/72408568): Remove this once session.run can get variant tensors.

  def _remove_variant(x):
    if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
      return ()
    else:
      return x

  return nest.map_structure(_remove_variant, get_next_op)


class DatasetSerializationTestBase(test.TestCase):
  """Base class for testing serializable datasets."""
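
  # Typical usage in a subclass (a minimal sketch; the test class name, the
  # dataset, and the element count below are illustrative, not part of this
  # module):
  #
  #   class RangeDatasetSerializationTest(DatasetSerializationTestBase):
  #
  #     def testCore(self):
  #       self.run_core_tests(lambda: dataset_ops.Dataset.range(10), 10)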

  def tearDown(self):
    self._delete_ckpt()
    super(DatasetSerializationTestBase, self).tearDown()

  # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
  # (deprecated) saveable `SparseTensorSliceDataset`, once the API
  # `from_sparse_tensor_slices()` and related tests are deleted.
  def run_core_tests(self, ds_fn, num_outputs, sparse_tensors=False):
    """Runs the core tests.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    # NOTE: We disable all default optimizations in serialization tests in
    # order to test the actual dataset in question.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False

    def ds_fn_no_opt():
      return ds_fn().with_options(options)

    self.verify_unused_iterator(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_fully_used_iterator(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_exhausted_iterator(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_multiple_breaks(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_reset_restored_iterator(
        ds_fn_no_opt, num_outputs, sparse_tensors=sparse_tensors)

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self, ds_fn, num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced but does not check for an
    exhausted iterator, i.e., one from which an OutOfRange error has been
    returned.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one which has returned an OutOfRange error.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertEqual(len(actual), 0)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      num_breaks: The number of break points. These are uniformly spread in
        [0, num_outputs], both endpoints inclusive.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to `num_outputs // 2`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if break_point is None else break_point

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        self._restore(saver, sess)
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      error: The expected error that is raised when saving the iterator.
      break_point: Break point. Optional. Defaults to `num_outputs // 2`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """

    break_point = num_outputs // 2 if break_point is None else break_point
    with ops.Graph().as_default() as g:
      init_op, get_next_op, saver = self._build_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        for _ in range(break_point):
          sess.run(get_next_op)
        with self.assertRaises(error):
          self._save(sess, saver)
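
  # For example, a test for an iterator that is expected to be non-saveable
  # might check (a sketch; `ds_fn`, the count, and the error type are
  # illustrative):
  #
  #   self.verify_error_on_save(ds_fn, 20, errors.UnimplementedError)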

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *with* stopping at break points.

    Deep-matches the outputs from 1 and 2.

    Args:
      ds_fn: See `gen_outputs`.
      break_points: See `gen_outputs`.
      num_outputs: See `gen_outputs`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    self.match(expected, actual)

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, outputs are produced until `break_point` items have
        been generated, at which point the iterator state is checkpointed.
        The current graph and session are then destroyed, and a new graph and
        session are used to produce outputs until the next break point or
        until `num_outputs` elements have been produced. Each `break_point`
        must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved at each break point but not
        at the end. Note that checkpoints overwrite each other, so there is
        always only a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
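
    Example (a minimal sketch; `ds_fn` here stands for any 0-argument dataset
    function and the numbers are illustrative):
      # Run to completion in one go, then with a save/restore after the
      # fifth output, and check that both runs produce the same elements.
      expected = self.gen_outputs(ds_fn, [], 10)
      actual = self.gen_outputs(ds_fn, [5], 10)
      self.match(expected, actual)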
    """
    outputs = []

    def get_ops():
      if ckpt_saved:
        saver = self._import_meta_graph()
        init_op, get_next_op = self._get_iterator_ops_from_collection(
            ds_fn, sparse_tensors=sparse_tensors)
      else:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
      return init_op, get_next_op, saver

    for i in range(len(break_points) + 1):
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = get_ops()
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          if ckpt_saved:
            self._initialize(init_op, sess)
            self._restore(saver, sess)
          else:
            self._initialize(init_op, sess)
          start = break_points[i - 1] if i > 0 else 0
          end = break_points[i] if i < len(break_points) else num_outputs
          num_iters = end - start
          for _ in range(num_iters):
            outputs.append(sess.run(get_next_op))
          if i == len(break_points) and verify_exhausted:
            with self.assertRaises(errors.OutOfRangeError):
              sess.run(get_next_op)
          if save_checkpoint_at_end or i < len(break_points):
            self._save(sess, saver)
            ckpt_saved = True

    return outputs

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches shape and values of `expected` and `actual`.
    Handles scalars, numpy arrays, and other Python sequence containers
    (e.g. list, dict), as well as SparseTensorValue and RaggedTensorValue.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_sequence(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.match((expected.indices, expected.values, expected.dense_shape),
                 (actual.indices, actual.values, actual.dense_shape))
    elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
      self.match((expected.values, expected.row_splits),
                 (actual.values, actual.row_splits))
    else:
      self.assertEqual(expected, actual)

  def does_not_match(self, expected, actual):
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` break points in [0, num_outputs]."""
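    # For example, `gen_break_points(100, num_samples=10)` yields
    # [0, 11, 22, 33, 44, 55, 66, 77, 88, 100]: endpoints are included, and
    # intermediate values are truncated to int.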
    return np.linspace(0, num_outputs, num_samples, dtype=int)

  def _build_graph(self, ds_fn, sparse_tensors=False):
    iterator = dataset_ops.make_initializable_iterator(ds_fn())

    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
    # do not support tuples we flatten the tensors and restore the shape in
    # `_get_iterator_ops_from_collection`.
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
    get_next_list = []
    i = 1
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)

  def _get_output_types(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_types(ds_fn())

  def _get_output_shapes(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_shapes(ds_fn())

  def _get_output_classes(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_classes(ds_fn())

  def _ckpt_path(self):
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    # NOTE: `map` is lazy in Python 3, so `map(gfile.Remove, files)` would
    # never actually delete the files; iterate explicitly instead.
    for f in files:
      gfile.Remove(f)