# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to traverse the Dataset dependency structure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes


def obtain_all_variant_tensor_ops(dataset):
  """Given an input dataset, finds all dataset ops used for construction.

  A series of transformations would have created this dataset with each
  transformation including zero or more Dataset ops, each producing a dataset
  variant tensor. This method outputs all of them.

  The graph is walked breadth-first, starting at the op that produces
  `dataset._variant_tensor` and following op inputs. Each op is visited at
  most once, so an op reachable along several paths (e.g. a shared upstream
  dataset) appears only once in the result, at its first BFS position.

  Args:
    dataset: Dataset to find variant tensors for.

  Returns:
    A list of unique variant_tensor producing dataset ops used to construct
    this dataset, in breadth-first order from the dataset's own op. Only the
    first output of each op is inspected when deciding whether it produces a
    variant.
  """
  all_variant_tensor_ops = []
  root_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  bfs_q = Queue.Queue()
  bfs_q.put(root_op)
  # Ops are marked as visited when they are *enqueued*, not when dequeued.
  # Marking on dequeue lets an op reachable by several paths be enqueued (and
  # therefore reported and re-traversed) several times. A set gives O(1)
  # membership tests instead of scanning a list.
  visited = {root_op}
  while not bfs_q.empty():
    op = bfs_q.get()
    # We look for all ops that produce variant tensors as output. This is a bit
    # of overkill but the other dataset _inputs() traversal strategies can't
    # cover the case of function inputs that capture dataset variants.
    # TODO(b/120873778): Make this more efficient.
    if op.outputs[0].dtype == dtypes.variant:
      all_variant_tensor_ops.append(op)
    for i in op.inputs:
      input_op = i.op
      if input_op not in visited:
        visited.add(input_op)
        bfs_q.put(input_op)
  return all_variant_tensor_ops