1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Smoke test for reading records from GCS to TensorFlow.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import random 21import sys 22import time 23 24import numpy as np 25import tensorflow as tf 26from tensorflow.core.example import example_pb2 27from tensorflow.python.lib.io import file_io 28 29flags = tf.app.flags 30flags.DEFINE_string("gcs_bucket_url", "", 31 "The URL to the GCS bucket in which the temporary " 32 "tfrecord file is to be written and read, e.g., " 33 "gs://my-gcs-bucket/test-directory") 34flags.DEFINE_integer("num_examples", 10, "Number of examples to generate") 35 36FLAGS = flags.FLAGS 37 38 39def create_examples(num_examples, input_mean): 40 """Create ExampleProto's containing data.""" 41 ids = np.arange(num_examples).reshape([num_examples, 1]) 42 inputs = np.random.randn(num_examples, 1) + input_mean 43 target = inputs - input_mean 44 examples = [] 45 for row in range(num_examples): 46 ex = example_pb2.Example() 47 ex.features.feature["id"].bytes_list.value.append(str(ids[row, 0])) 48 ex.features.feature["target"].float_list.value.append(target[row, 0]) 49 ex.features.feature["inputs"].float_list.value.append(inputs[row, 0]) 50 examples.append(ex) 51 return examples 52 53 
def _now_ms():
  """Return the current wall-clock time as integer milliseconds.

  Extracted helper: the same rounding expression was previously repeated
  at every timing site in this file.
  """
  return int(round(time.time() * 1000))


def create_dir_test():
  """Verifies file_io directory handling methods against the GCS bucket.

  Creates, lists, renames and recursively deletes directories under
  FLAGS.gcs_bucket_url, asserting after each step that the filesystem
  state matches expectations, and printing elapsed wall-clock times.
  """

  # Test directory creation.
  starttime_ms = _now_ms()
  # The start timestamp doubles as a (roughly) unique directory suffix.
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
  print("Creating dir %s" % dir_name)
  file_io.create_dir(dir_name)
  elapsed_ms = _now_ms() - starttime_ms
  print("Created directory in: %d milliseconds" % elapsed_ms)

  # Check that the directory exists.
  dir_exists = file_io.is_directory(dir_name)
  assert dir_exists
  print("%s directory exists: %s" % (dir_name, dir_exists))

  # Test recursive directory creation.
  starttime_ms = _now_ms()
  recursive_dir_name = "%s/%s/%s" % (dir_name,
                                     "nested_dir1",
                                     "nested_dir2")
  print("Creating recursive dir %s" % recursive_dir_name)
  file_io.recursive_create_dir(recursive_dir_name)
  elapsed_ms = _now_ms() - starttime_ms
  print("Created directory recursively in: %d milliseconds" % elapsed_ms)

  # Check that the directory exists.
  recursive_dir_exists = file_io.is_directory(recursive_dir_name)
  assert recursive_dir_exists
  print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists))

  # Create some contents in the just created directory and list the contents.
  num_files = 10
  files_to_create = ["file_%d.txt" % n for n in range(num_files)]
  for file_num in files_to_create:
    file_name = "%s/%s" % (dir_name, file_num)
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file.")

  print("Listing directory %s." % dir_name)
  starttime_ms = _now_ms()
  directory_contents = file_io.list_directory(dir_name)
  print(directory_contents)
  elapsed_ms = _now_ms() - starttime_ms
  print("Listed directory %s in %s milliseconds" % (dir_name, elapsed_ms))
  # The listing must contain exactly the files created above plus the
  # nested directory (reported with a trailing slash).
  assert set(directory_contents) == set(files_to_create + ["nested_dir1/"])

  # Test directory renaming.
  dir_to_rename = "%s/old_dir" % dir_name
  new_dir_name = "%s/new_dir" % dir_name
  file_io.create_dir(dir_to_rename)
  assert file_io.is_directory(dir_to_rename)
  assert not file_io.is_directory(new_dir_name)

  starttime_ms = _now_ms()
  print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name))
  file_io.rename(dir_to_rename, new_dir_name)
  elapsed_ms = _now_ms() - starttime_ms
  print("Renamed directory %s to %s in %s milliseconds" % (
      dir_to_rename, new_dir_name, elapsed_ms))
  assert not file_io.is_directory(dir_to_rename)
  assert file_io.is_directory(new_dir_name)

  # Test Delete directory recursively.
  print("Deleting directory recursively %s." % dir_name)
  starttime_ms = _now_ms()
  file_io.delete_recursively(dir_name)
  elapsed_ms = _now_ms() - starttime_ms
  dir_exists = file_io.is_directory(dir_name)
  assert not dir_exists
  print("Deleted directory recursively %s in %s milliseconds" % (
      dir_name, elapsed_ms))


def create_object_test():
  """Verifies file_io's object manipulation methods against the GCS bucket.

  Creates files of two name patterns, checks glob matching per pattern,
  renames a file, and finally deletes the scratch directory.
  """
  starttime_ms = _now_ms()
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
  print("Creating dir %s." % dir_name)
  file_io.create_dir(dir_name)

  num_files = 5
  # Create files of 2 different patterns in this directory.
  files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n)
                     for n in range(num_files)]
  files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n)
                     for n in range(num_files)]

  starttime_ms = _now_ms()
  files_to_create = files_pattern_1 + files_pattern_2
  for file_name in files_to_create:
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file creation.")
  elapsed_ms = _now_ms() - starttime_ms
  print("Created %d files in %s milliseconds" % (
      len(files_to_create), elapsed_ms))

  # Listing files of pattern1.
  list_files_pattern = "%s/test_file*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  starttime_ms = _now_ms()
  files_list = file_io.get_matching_files(list_files_pattern)
  elapsed_ms = _now_ms() - starttime_ms
  print("Listed files in %s milliseconds" % elapsed_ms)
  print(files_list)
  assert set(files_list) == set(files_pattern_1)

  # Listing files of pattern2.
  list_files_pattern = "%s/testfile*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  starttime_ms = _now_ms()
  files_list = file_io.get_matching_files(list_files_pattern)
  elapsed_ms = _now_ms() - starttime_ms
  print("Listed files in %s milliseconds" % elapsed_ms)
  print(files_list)
  assert set(files_list) == set(files_pattern_2)

  # Test renaming file.
  file_to_rename = "%s/oldname.txt" % dir_name
  file_new_name = "%s/newname.txt" % dir_name
  file_io.write_string_to_file(file_to_rename, "test file.")
  assert file_io.file_exists(file_to_rename)
  assert not file_io.file_exists(file_new_name)

  print("Will try renaming file %s to %s" % (file_to_rename, file_new_name))
  starttime_ms = _now_ms()
  file_io.rename(file_to_rename, file_new_name)
  elapsed_ms = _now_ms() - starttime_ms
  print("File %s renamed to %s in %s milliseconds" % (
      file_to_rename, file_new_name, elapsed_ms))
  assert not file_io.file_exists(file_to_rename)
  assert file_io.file_exists(file_new_name)

  # Delete directory.
  print("Deleting directory %s." % dir_name)
  file_io.delete_recursively(dir_name)


def main(argv):
  """Runs the GCS smoke test: record write/read, then file_io checks."""
  del argv  # Unused.

  # Sanity check on the GCS bucket URL.
  if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
    print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
    sys.exit(1)

  # Generate random tfrecord path name.
  input_path = FLAGS.gcs_bucket_url + "/"
  input_path += "".join(random.choice("0123456789ABCDEF") for _ in range(8))
  input_path += ".tfrecord"
  print("Using input path: %s" % input_path)

  # Verify that writing to the records file in GCS works.
  print("\n=== Testing writing and reading of GCS record file... ===")
  example_data = create_examples(FLAGS.num_examples, 5)
  with tf.python_io.TFRecordWriter(input_path) as hf:
    for e in example_data:
      hf.write(e.SerializeToString())

  print("Data written to: %s" % input_path)

  # Verify that reading from the tfrecord file works and that
  # tf_record_iterator works.
  record_iter = tf.python_io.tf_record_iterator(input_path)
  read_count = 0
  for _ in record_iter:
    read_count += 1
  print("Read %d records using tf_record_iterator" % read_count)

  if read_count != FLAGS.num_examples:
    print("FAIL: The number of records read from tf_record_iterator (%d) "
          "differs from the expected number (%d)" % (read_count,
                                                     FLAGS.num_examples))
    sys.exit(1)

  # Verify that running the read op in a session works.
  print("\n=== Testing TFRecordReader.read op in a session... ===")
  with tf.Graph().as_default():
    # num_epochs=1 makes the queue raise OutOfRangeError once all records
    # have been consumed, which the final read below relies on.
    filename_queue = tf.train.string_input_producer([input_path], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.train.start_queue_runners()
      index = 0
      for _ in range(FLAGS.num_examples):
        print("Read record: %d" % index)
        sess.run(serialized_example)
        index += 1

      # Reading one more record should trigger an exception.
      try:
        sess.run(serialized_example)
        print("FAIL: Failed to catch the expected OutOfRangeError while "
              "reading one more record than is available")
        sys.exit(1)
      except tf.errors.OutOfRangeError:
        print("Successfully caught the expected OutOfRangeError while "
              "reading one more record than is available")

  create_dir_test()
  create_object_test()


if __name__ == "__main__":
  tf.app.run(main)