# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Smoke test for reading records from GCS to TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import sys
import time

import numpy as np
import tensorflow as tf
from tensorflow.core.example import example_pb2
from tensorflow.python.lib.io import file_io

flags = tf.app.flags
flags.DEFINE_string("gcs_bucket_url", "",
                    "The URL to the GCS bucket in which the temporary "
                    "tfrecord file is to be written and read, e.g., "
                    "gs://my-gcs-bucket/test-directory")
flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")

FLAGS = flags.FLAGS
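
# A hypothetical invocation of this smoke test (the script and bucket names
# below are illustrative only):
#   python gcs_smoke.py --gcs_bucket_url=gs://my-gcs-bucket/test-directory \
#       --num_examples=10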


def create_examples(num_examples, input_mean):
  """Creates Example protos containing synthetic data."""
  ids = np.arange(num_examples).reshape([num_examples, 1])
  inputs = np.random.randn(num_examples, 1) + input_mean
  target = inputs - input_mean
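  # The inputs are drawn from a normal distribution centered at input_mean,
  # so each target is just the zero-mean noise component of its input.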
  examples = []
  for row in range(num_examples):
    ex = example_pb2.Example()
    # bytes_list holds bytes values, so encode the string id (works under
    # both Python 2 and Python 3).
    ex.features.feature["id"].bytes_list.value.append(
        str(ids[row, 0]).encode("utf-8"))
    ex.features.feature["target"].float_list.value.append(target[row, 0])
    ex.features.feature["inputs"].float_list.value.append(inputs[row, 0])
    examples.append(ex)
  return examples


def create_dir_test():
  """Verifies file_io directory handling methods."""

  # Test directory creation.
  starttime_ms = int(round(time.time() * 1000))
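  # The millisecond timestamp doubles as a (nearly) unique suffix for the
  # test directory, so repeated runs are unlikely to collide.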
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
  print("Creating dir %s" % dir_name)
  file_io.create_dir(dir_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Created directory in: %d milliseconds" % elapsed_ms)

  # Check that the directory exists.
  dir_exists = file_io.is_directory(dir_name)
  assert dir_exists
  print("%s directory exists: %s" % (dir_name, dir_exists))

  # Test recursive directory creation.
  starttime_ms = int(round(time.time() * 1000))
  recursive_dir_name = "%s/%s/%s" % (dir_name,
                                     "nested_dir1",
                                     "nested_dir2")
  print("Creating recursive dir %s" % recursive_dir_name)
  file_io.recursive_create_dir(recursive_dir_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Created directory recursively in: %d milliseconds" % elapsed_ms)

  # Check that the directory exists.
  recursive_dir_exists = file_io.is_directory(recursive_dir_name)
  assert recursive_dir_exists
  print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists))

  # Create some files in the newly created directory and list its contents.
  num_files = 10
  files_to_create = ["file_%d.txt" % n for n in range(num_files)]
  for file_base_name in files_to_create:
    file_name = "%s/%s" % (dir_name, file_base_name)
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file.")

  print("Listing directory %s." % dir_name)
  starttime_ms = int(round(time.time() * 1000))
  directory_contents = file_io.list_directory(dir_name)
  print(directory_contents)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Listed directory %s in %s milliseconds" % (dir_name, elapsed_ms))
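  # list_directory returns only immediate children, so the recursively
  # created tree shows up here as the single entry "nested_dir1/".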
  assert set(directory_contents) == set(files_to_create + ["nested_dir1/"])

  # Test directory renaming.
  dir_to_rename = "%s/old_dir" % dir_name
  new_dir_name = "%s/new_dir" % dir_name
  file_io.create_dir(dir_to_rename)
  assert file_io.is_directory(dir_to_rename)
  assert not file_io.is_directory(new_dir_name)

  starttime_ms = int(round(time.time() * 1000))
  print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name))
  file_io.rename(dir_to_rename, new_dir_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Renamed directory %s to %s in %s milliseconds" % (
      dir_to_rename, new_dir_name, elapsed_ms))
  assert not file_io.is_directory(dir_to_rename)
  assert file_io.is_directory(new_dir_name)

  # Test recursive directory deletion.
  print("Deleting directory recursively %s." % dir_name)
  starttime_ms = int(round(time.time() * 1000))
  file_io.delete_recursively(dir_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  dir_exists = file_io.is_directory(dir_name)
  assert not dir_exists
  print("Deleted directory recursively %s in %s milliseconds" % (
      dir_name, elapsed_ms))


def create_object_test():
  """Verifies file_io's object manipulation methods."""
  starttime_ms = int(round(time.time() * 1000))
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
  print("Creating dir %s." % dir_name)
  file_io.create_dir(dir_name)

  num_files = 5
  # Create files matching two different patterns in this directory.
  files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n)
                     for n in range(num_files)]
  files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n)
                     for n in range(num_files)]
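  # The two name patterns differ only in their underscores, so the glob
  # queries below can verify that "test_file*" and "testfile*" match
  # disjoint sets of files.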

  starttime_ms = int(round(time.time() * 1000))
  files_to_create = files_pattern_1 + files_pattern_2
  for file_name in files_to_create:
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file creation.")
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Created %d files in %s milliseconds" % (
      len(files_to_create), elapsed_ms))

  # List files matching pattern 1.
  list_files_pattern = "%s/test_file*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  starttime_ms = int(round(time.time() * 1000))
  files_list = file_io.get_matching_files(list_files_pattern)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Listed files in %s milliseconds" % elapsed_ms)
  print(files_list)
  assert set(files_list) == set(files_pattern_1)

  # List files matching pattern 2.
  list_files_pattern = "%s/testfile*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  starttime_ms = int(round(time.time() * 1000))
  files_list = file_io.get_matching_files(list_files_pattern)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Listed files in %s milliseconds" % elapsed_ms)
  print(files_list)
  assert set(files_list) == set(files_pattern_2)

  # Test file renaming.
  file_to_rename = "%s/oldname.txt" % dir_name
  file_new_name = "%s/newname.txt" % dir_name
  file_io.write_string_to_file(file_to_rename, "test file.")
  assert file_io.file_exists(file_to_rename)
  assert not file_io.file_exists(file_new_name)

  print("Will try renaming file %s to %s" % (file_to_rename, file_new_name))
  starttime_ms = int(round(time.time() * 1000))
  file_io.rename(file_to_rename, file_new_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("File %s renamed to %s in %s milliseconds" % (
      file_to_rename, file_new_name, elapsed_ms))
  assert not file_io.file_exists(file_to_rename)
  assert file_io.file_exists(file_new_name)

  # Delete the directory.
  print("Deleting directory %s." % dir_name)
  file_io.delete_recursively(dir_name)


def main(argv):
  del argv  # Unused.

  # Sanity check on the GCS bucket URL.
  if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
    print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
    sys.exit(1)

  # Generate a random tfrecord path name.
  input_path = FLAGS.gcs_bucket_url + "/"
  input_path += "".join(random.choice("0123456789ABCDEF") for _ in range(8))
  input_path += ".tfrecord"
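  # The random 8-character hex suffix makes the tfrecord path unlikely to
  # collide with files left over from earlier or concurrent runs.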
  print("Using input path: %s" % input_path)

  # Verify that writing to the records file in GCS works.
  print("\n=== Testing writing and reading of GCS record file... ===")
  example_data = create_examples(FLAGS.num_examples, 5)
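  # TFRecordWriter accepts the gs:// path directly; TensorFlow's file I/O
  # layer routes the writes to GCS, which is exactly what this test checks.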
  with tf.python_io.TFRecordWriter(input_path) as hf:
    for e in example_data:
      hf.write(e.SerializeToString())
  print("Data written to: %s" % input_path)

  # Verify that the tfrecord file can be read back with tf_record_iterator.
  record_iter = tf.python_io.tf_record_iterator(input_path)
  read_count = 0
  for _ in record_iter:
    read_count += 1
  print("Read %d records using tf_record_iterator" % read_count)

  if read_count != FLAGS.num_examples:
    print("FAIL: The number of records read from tf_record_iterator (%d) "
          "differs from the expected number (%d)" % (read_count,
                                                     FLAGS.num_examples))
    sys.exit(1)

  # Verify that running the read op in a session works.
  print("\n=== Testing TFRecordReader.read op in a session... ===")
  with tf.Graph().as_default():
    filename_queue = tf.train.string_input_producer([input_path], num_epochs=1)
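    # With num_epochs=1 the filename queue yields the path exactly once, so
    # the reader returns each of the FLAGS.num_examples records once before
    # raising OutOfRangeError. num_epochs also creates a local counter
    # variable, which is why local_variables_initializer is run below.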
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.train.start_queue_runners()
      index = 0
      for _ in range(FLAGS.num_examples):
        print("Read record: %d" % index)
        sess.run(serialized_example)
        index += 1

      # Reading one more record should trigger an OutOfRangeError.
      try:
        sess.run(serialized_example)
        print("FAIL: Failed to catch the expected OutOfRangeError while "
              "reading one more record than is available")
        sys.exit(1)
      except tf.errors.OutOfRangeError:
        print("Successfully caught the expected OutOfRangeError while "
              "reading one more record than is available")

  create_dir_test()
  create_object_test()


if __name__ == "__main__":
  tf.app.run(main)