1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""take-while dataset transformation."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import tensor_spec
23from tensorflow.python.ops import gen_experimental_dataset_ops
24from tensorflow.python.util.tf_export import tf_export
25
26
class _TakeWhileDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """Dataset that ends iteration once `predicate` evaluates to false."""

  def __init__(self, input_dataset, predicate):
    """Creates the take-while dataset; see `take_while()` for details.

    Args:
      input_dataset: The dataset that `predicate` is wrapped against and whose
        variant tensor is fed to the take-while op.
      predicate: A function that must return a scalar boolean tensor.

    Raises:
      ValueError: If the output structure of the wrapped `predicate` is not
        compatible with a scalar `tf.bool` tensor spec.
    """
    # Assigned before the base-class constructor runs because it is read below
    # while wrapping the predicate and building the dataset op.
    self._input_dataset = input_dataset

    predicate_wrapper = dataset_ops.StructuredFunctionWrapper(
        predicate,
        "tf.data.experimental.take_while()",
        dataset=self._input_dataset)

    # Reject predicates whose result is anything but a scalar boolean.
    predicate_output = predicate_wrapper.output_structure
    scalar_bool_spec = tensor_spec.TensorSpec([], dtypes.bool)
    if not predicate_output.is_compatible_with(scalar_bool_spec):
      raise ValueError("`predicate` must return a scalar boolean tensor.")

    # Only a validated predicate is kept on the instance.
    self._predicate = predicate_wrapper

    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    captured_inputs = self._predicate.function.captured_inputs
    variant_tensor = gen_experimental_dataset_ops.take_while_dataset(
        input_variant,
        other_arguments=captured_inputs,
        predicate=self._predicate.function,
        **self._flat_structure)
    super(_TakeWhileDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns a single-element list holding the wrapped predicate."""
    return [self._predicate]
53
54
@tf_export("data.experimental.take_while")
def take_while(predicate):
  """Creates a transformation that ends iteration when `predicate` is false.

  Args:
    predicate: A function that maps a nested structure of tensors (an element
      of the dataset the transformation is applied to) to a scalar `tf.bool`
      tensor.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in a `_TakeWhileDataset` driven by `predicate`."""
    truncated_dataset = _TakeWhileDataset(dataset, predicate)
    return truncated_dataset

  return _apply_fn
73