1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers to traverse the Dataset dependency structure."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from six.moves import queue as Queue  # pylint: disable=redefined-builtin
21
22from tensorflow.python.framework import dtypes
23
24
def obtain_all_variant_tensor_ops(dataset):
  """Given an input dataset, finds all dataset ops used for construction.

  A series of transformations would have created this dataset with each
  transformation including zero or more Dataset ops, each producing a dataset
  variant tensor. This method outputs all of them.

  The graph is walked breadth-first over op inputs, starting at the op that
  produces `dataset._variant_tensor`. Every op is visited at most once, so an
  op reachable along several paths (for example an upstream dataset shared by
  two branches, or one consumed twice by the same op) is reported only once.

  Args:
    dataset: Dataset to find variant tensors for.

  Returns:
    A list of variant_tensor producing dataset ops used to construct this
    dataset, without duplicates, in breadth-first discovery order.
  """
  root_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  all_variant_tensor_ops = []
  bfs_q = Queue.Queue()
  bfs_q.put(root_op)
  # Ops are marked visited when they are *enqueued*, not when dequeued:
  # otherwise an op reachable from two not-yet-processed parents would be
  # queued (and reported) twice. A set keeps the membership test O(1) instead
  # of a linear scan over every op seen so far.
  visited = {root_op}
  while not bfs_q.empty():
    op = bfs_q.get()
    # We look for all ops that produce variant tensors as output. This is a bit
    # of overkill but the other dataset _inputs() traversal strategies can't
    # cover the case of function inputs that capture dataset variants.
    # TODO(b/120873778): Make this more efficient.
    if op.outputs[0].dtype == dtypes.variant:
      all_variant_tensor_ops.append(op)
    for i in op.inputs:
      input_op = i.op
      if input_op not in visited:
        visited.add(input_op)
        bfs_q.put(input_op)
  return all_variant_tensor_ops
57