# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ignore_errors dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.ignore_errors")
def ignore_errors(log_warning=False):
  """Creates a `Dataset` from another `Dataset` and ignores any errors.

  Use this transformation to produce a dataset that contains the same elements
  as the input, but drops any elements that caused an error. For example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])

  # Computing `tf.debugging.check_numerics(1. / 0.)` will raise an
  # InvalidArgumentError.
  dataset = dataset.map(lambda x: tf.debugging.check_numerics(1. / x, "error"))

  # Using `ignore_errors()` will drop the element that causes an error.
  dataset = dataset.apply(tf.data.experimental.ignore_errors())
  # ==> {1., 0.5, 0.25}
  ```

  Args:
    log_warning: (Optional.) A Python `bool` indicating whether ignored
      errors should be logged to stderr. Defaults to `False`, in which case
      errors are dropped silently. The value is evaluated in Python and
      forwarded as an op attribute, so it must be a Python value rather than
      a tensor.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _IgnoreErrorsDataset(dataset, log_warning)

  return _apply_fn


class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that drops input elements whose computation raised an error."""

  def __init__(self, input_dataset, log_warning):
    """See `tf.data.experimental.ignore_errors()` for details."""
    # NOTE(review): `_input_dataset` is assigned before `_flat_structure` is
    # read, as in the original; presumably the base class derives the flat
    # structure from the input dataset -- confirm before reordering.
    self._input_dataset = input_dataset
    op_kwargs = dict(self._flat_structure)
    # Attach the `log_warning` attr only once the forward-compatibility window
    # has passed, or when logging was explicitly requested. Otherwise the attr
    # is omitted entirely, so the emitted op does not depend on it.
    if compat.forward_compatible(2020, 8, 26) or log_warning:
      op_kwargs["log_warning"] = log_warning
    variant_tensor = gen_experimental_dataset_ops.ignore_errors_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        **op_kwargs)
    super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)