1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrapper for input_pipeline_ops.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import random 21 22from tensorflow.contrib.input_pipeline.ops import gen_input_pipeline_ops 23from tensorflow.contrib.util import loader 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import variable_scope 28from tensorflow.python.platform import resource_loader 29 30 31_input_pipeline_ops = loader.load_op_library( 32 resource_loader.get_path_to_datafile("_input_pipeline_ops.so")) 33 34 35def obtain_next(string_list_tensor, counter): 36 """Basic wrapper for the ObtainNextOp. 37 38 Args: 39 string_list_tensor: A tensor that is a list of strings 40 counter: an int64 ref tensor to keep track of which element is returned. 41 42 Returns: 43 An op that produces the element at counter + 1 in the list, round 44 robin style. 45 """ 46 return gen_input_pipeline_ops.obtain_next(string_list_tensor, counter) 47 48 49def _maybe_randomize_list(string_list, shuffle): 50 if shuffle: 51 random.shuffle(string_list) 52 return string_list 53 54 55def _create_list(string_list, shuffle, seed, num_epochs): 56 if shuffle and seed: 57 random.seed(seed) 58 expanded_list = _maybe_randomize_list(string_list, shuffle)[:] 59 if num_epochs: 60 for _ in range(num_epochs - 1): 61 expanded_list.extend(_maybe_randomize_list(string_list, shuffle)) 62 return expanded_list 63 64 65def seek_next(string_list, shuffle=False, seed=None, num_epochs=None): 66 """Returns an op that seeks the next element in a list of strings. 67 68 Seeking happens in a round robin fashion. This op creates a variable called 69 obtain_next_counter that is initialized to -1 and is used to keep track of 70 which element in the list was returned, and a variable 71 obtain_next_expanded_list to hold the list. If num_epochs is not None, then we 72 limit the number of times we go around the string_list before OutOfRangeError 73 is thrown. It creates a variable to keep track of this. 74 75 Args: 76 string_list: A list of strings. 77 shuffle: If true, we shuffle the string_list differently for each epoch. 78 seed: Seed used for shuffling. 79 num_epochs: Returns OutOfRangeError once string_list has been repeated 80 num_epoch times. If unspecified then keeps on looping. 81 82 Returns: 83 An op that produces the next element in the provided list. 84 """ 85 expanded_list = _create_list(string_list, shuffle, seed, num_epochs) 86 87 with variable_scope.variable_scope("obtain_next"): 88 counter = variable_scope.get_variable( 89 name="obtain_next_counter", 90 initializer=constant_op.constant( 91 -1, dtype=dtypes.int64), 92 dtype=dtypes.int64, 93 trainable=False) 94 with ops.colocate_with(counter): 95 string_tensor = variable_scope.get_variable( 96 name="obtain_next_expanded_list", 97 initializer=constant_op.constant(expanded_list), 98 dtype=dtypes.string, 99 trainable=False) 100 if num_epochs: 101 filename_counter = variable_scope.get_variable( 102 name="obtain_next_filename_counter", 103 initializer=constant_op.constant( 104 0, dtype=dtypes.int64), 105 dtype=dtypes.int64, 106 trainable=False) 107 c = filename_counter.count_up_to(len(expanded_list)) 108 with ops.control_dependencies([c]): 109 return obtain_next(string_tensor, counter) 110 else: 111 return obtain_next(string_tensor, counter) 112