1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrapper for input_pipeline_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import random
21
22from tensorflow.contrib.input_pipeline.ops import gen_input_pipeline_ops
23from tensorflow.contrib.util import loader
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import variable_scope
28from tensorflow.python.platform import resource_loader
29
30
# Load the custom-op shared library `_input_pipeline_ops.so` (path resolved via
# resource_loader.get_path_to_datafile) once, at import time. The handle is
# bound at module level although nothing below references it by name --
# presumably it is kept for the side effect of registering the op kernels.
_input_pipeline_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_input_pipeline_ops.so")
)
33
34
def obtain_next(string_list_tensor, counter):
  """Thin wrapper around the generated ObtainNext op.

  Args:
    string_list_tensor: Tensor holding the list of strings to cycle through.
    counter: An int64 ref tensor used to track which element gets returned.

  Returns:
    An op yielding the element at position counter + 1 of the list, cycling
    through the list round robin style.
  """
  next_element = gen_input_pipeline_ops.obtain_next(string_list_tensor,
                                                    counter)
  return next_element
47
48
49def _maybe_randomize_list(string_list, shuffle):
50  if shuffle:
51    random.shuffle(string_list)
52  return string_list
53
54
55def _create_list(string_list, shuffle, seed, num_epochs):
56  if shuffle and seed:
57    random.seed(seed)
58  expanded_list = _maybe_randomize_list(string_list, shuffle)[:]
59  if num_epochs:
60    for _ in range(num_epochs - 1):
61      expanded_list.extend(_maybe_randomize_list(string_list, shuffle))
62  return expanded_list
63
64
def seek_next(string_list, shuffle=False, seed=None, num_epochs=None):
  """Returns an op that seeks the next element in a list of strings.

  Seeking happens in a round robin fashion. This op creates a variable called
  obtain_next_counter that is initialized to -1 and is used to keep track of
  which element in the list was returned, and a variable
  obtain_next_expanded_list to hold the (possibly shuffled, epoch-repeated)
  list. If num_epochs is set, a third variable, obtain_next_filename_counter,
  counts how many times the returned op has run; once it reaches the length of
  the expanded list, OutOfRangeError is raised, which limits the number of
  passes over string_list.

  Args:
    string_list: A non-empty list of strings.
    shuffle: If true, we shuffle the string_list differently for each epoch.
    seed: Seed used for shuffling.
    num_epochs: Raises OutOfRangeError once string_list has been repeated
                num_epochs times. If None (or 0), keeps on looping.

  Returns:
    An op that produces the next element in the provided list.

  Raises:
    ValueError: If string_list is empty.
  """
  expanded_list = _create_list(string_list, shuffle, seed, num_epochs)
  # An empty list leaves nothing to cycle through; fail fast with a clear
  # message rather than letting graph construction or the kernel fail in a
  # less obvious way.
  if len(expanded_list) == 0:
    raise ValueError("seek_next requires a non-empty string_list.")

  with variable_scope.variable_scope("obtain_next"):
    counter = variable_scope.get_variable(
        name="obtain_next_counter",
        initializer=constant_op.constant(
            -1, dtype=dtypes.int64),
        dtype=dtypes.int64,
        trainable=False)
    # Keep the list on the same device as the counter the op reads/updates.
    with ops.colocate_with(counter):
      string_tensor = variable_scope.get_variable(
          name="obtain_next_expanded_list",
          initializer=constant_op.constant(expanded_list),
          dtype=dtypes.string,
          trainable=False)
    if num_epochs:
      filename_counter = variable_scope.get_variable(
          name="obtain_next_filename_counter",
          initializer=constant_op.constant(
              0, dtype=dtypes.int64),
          dtype=dtypes.int64,
          trainable=False)
      # count_up_to raises OutOfRangeError once the limit is reached; the
      # control dependency makes every fetch of the returned op advance it.
      c = filename_counter.count_up_to(len(expanded_list))
      with ops.control_dependencies([c]):
        return obtain_next(string_tensor, counter)
    return obtain_next(string_tensor, counter)
112