1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ignore_errors dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.ops import gen_experimental_dataset_ops
22from tensorflow.python.util.tf_export import tf_export
23from tensorflow.python.compat import compat
24
25
@tf_export("data.experimental.ignore_errors")
def ignore_errors(log_warning=False):
  """Creates a `Dataset` from another `Dataset` and silently ignores any errors.

  Use this transformation to produce a dataset that contains the same elements
  as the input, but drops any elements that caused an error (optionally logging
  each ignored error, see `log_warning`). For example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])

  # Computing `tf.debugging.check_numerics(1. / 0.)` will raise an
  # InvalidArgumentError.
  dataset = dataset.map(lambda x: tf.debugging.check_numerics(1. / x, "error"))

  # Using `ignore_errors()` will drop the element that causes an error.
  dataset = dataset.apply(
      tf.data.experimental.ignore_errors())  # ==> {1., 0.5, 0.25}
  ```

  Args:
    log_warning: (Optional.) A Python `bool` indicating whether ignored errors
      should be logged to stderr. Defaults to `False`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Wrap the input dataset; the actual error-skipping happens in the op
    # built by `_IgnoreErrorsDataset`.
    return _IgnoreErrorsDataset(dataset, log_warning)

  return _apply_fn
58
59
class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that skips input elements whose computation raised an error."""

  def __init__(self, input_dataset, log_warning):
    """Builds the `IgnoreErrorsDataset` op on top of `input_dataset`.

    See `tf.data.experimental.ignore_errors()` for details.

    Args:
      input_dataset: The `Dataset` whose element errors should be ignored.
      log_warning: Whether ignored errors should be logged.
    """
    self._input_dataset = input_dataset
    # The `log_warning` attr is only forwarded to the op once the
    # forward-compatibility window has expired, or when logging was explicitly
    # requested (the usual `compat.forward_compatible` pattern for new op
    # attrs; presumably so default graphs stay loadable by older binaries).
    forward_log_warning = compat.forward_compatible(2020, 8, 26) or log_warning
    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    op_kwargs = dict(self._flat_structure)
    if forward_log_warning:
      op_kwargs["log_warning"] = log_warning
    variant_tensor = gen_experimental_dataset_ops.ignore_errors_dataset(
        input_variant, **op_kwargs)
    super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)
78