1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for learn.io.graph_io.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import base64 22import os 23import random 24import tempfile 25 26from six.moves import xrange # pylint: disable=redefined-builtin 27 28from tensorflow.contrib.learn.python.learn.learn_io import graph_io 29from tensorflow.python.client import session as session_lib 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import dtypes as dtypes_lib 32from tensorflow.python.framework import errors 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import test_util 35from tensorflow.python.ops import io_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import parsing_ops 38from tensorflow.python.ops import variables 39from tensorflow.python.platform import gfile 40from tensorflow.python.platform import test 41from tensorflow.python.training import coordinator 42from tensorflow.python.training import queue_runner_impl 43from tensorflow.python.training import server_lib 44 45_VALID_FILE_PATTERN = "VALID" 46_VALID_FILE_PATTERN_2 = "VALID_2" 47_FILE_NAMES = [b"abc", b"def", b"ghi", b"jkl"] 48_FILE_NAMES_2 = [b"mno", b"pqr"] 49_INVALID_FILE_PATTERN = "INVALID" 50 51 52class GraphIOTest(test.TestCase): 53 54 def _mock_glob(self, pattern): 55 if _VALID_FILE_PATTERN == pattern: 56 return _FILE_NAMES 57 if _VALID_FILE_PATTERN_2 == pattern: 58 return _FILE_NAMES_2 59 self.assertEqual(_INVALID_FILE_PATTERN, pattern) 60 return [] 61 62 def setUp(self): 63 super(GraphIOTest, self).setUp() 64 random.seed(42) 65 self._orig_glob = gfile.Glob 66 gfile.Glob = self._mock_glob 67 68 def tearDown(self): 69 gfile.Glob = self._orig_glob 70 super(GraphIOTest, self).tearDown() 71 72 def test_dequeue_batch_value_errors(self): 73 default_batch_size = 17 74 queue_capacity = 1234 75 num_threads = 3 76 name = "my_batch" 77 78 self.assertRaisesRegexp( 79 ValueError, 80 "No files match", 81 graph_io.read_batch_examples, 82 _INVALID_FILE_PATTERN, 83 default_batch_size, 84 io_ops.TFRecordReader, 85 False, 86 num_epochs=None, 87 queue_capacity=queue_capacity, 88 num_threads=num_threads, 89 name=name) 90 self.assertRaisesRegexp( 91 ValueError, 92 "Invalid batch_size", 93 graph_io.read_batch_examples, 94 _VALID_FILE_PATTERN, 95 None, 96 io_ops.TFRecordReader, 97 False, 98 num_epochs=None, 99 queue_capacity=queue_capacity, 100 num_threads=num_threads, 101 name=name) 102 self.assertRaisesRegexp( 103 ValueError, 104 "Invalid batch_size", 105 graph_io.read_batch_examples, 106 _VALID_FILE_PATTERN, 107 -1, 108 io_ops.TFRecordReader, 109 False, 110 num_epochs=None, 111 queue_capacity=queue_capacity, 112 num_threads=num_threads, 113 name=name) 114 self.assertRaisesRegexp( 115 ValueError, 116 "Invalid batch_size", 117 graph_io.read_batch_examples, 118 _VALID_FILE_PATTERN, 119 default_batch_size, 120 io_ops.TFRecordReader, 121 False, 122 num_epochs=None, 123 queue_capacity=default_batch_size, 124 num_threads=num_threads, 125 name=name) 126 self.assertRaisesRegexp( 127 ValueError, 128 "Invalid queue_capacity", 129 graph_io.read_batch_examples, 130 _VALID_FILE_PATTERN, 131 default_batch_size, 132 io_ops.TFRecordReader, 133 False, 134 num_epochs=None, 135 queue_capacity=None, 136 num_threads=num_threads, 137 name=name) 138 self.assertRaisesRegexp( 139 ValueError, 140 "Invalid num_threads", 141 graph_io.read_batch_examples, 142 _VALID_FILE_PATTERN, 143 default_batch_size, 144 io_ops.TFRecordReader, 145 False, 146 num_epochs=None, 147 queue_capacity=queue_capacity, 148 num_threads=None, 149 name=name) 150 self.assertRaisesRegexp( 151 ValueError, 152 "Invalid num_threads", 153 graph_io.read_batch_examples, 154 _VALID_FILE_PATTERN, 155 default_batch_size, 156 io_ops.TFRecordReader, 157 False, 158 num_epochs=None, 159 queue_capacity=queue_capacity, 160 num_threads=-1, 161 name=name) 162 self.assertRaisesRegexp( 163 ValueError, 164 "Invalid batch_size", 165 graph_io.read_batch_examples, 166 _VALID_FILE_PATTERN, 167 queue_capacity + 1, 168 io_ops.TFRecordReader, 169 False, 170 num_epochs=None, 171 queue_capacity=queue_capacity, 172 num_threads=1, 173 name=name) 174 self.assertRaisesRegexp( 175 ValueError, 176 "Invalid num_epochs", 177 graph_io.read_batch_examples, 178 _VALID_FILE_PATTERN, 179 default_batch_size, 180 io_ops.TFRecordReader, 181 False, 182 num_epochs=-1, 183 queue_capacity=queue_capacity, 184 num_threads=1, 185 name=name) 186 self.assertRaisesRegexp( 187 ValueError, 188 "Invalid read_batch_size", 189 graph_io.read_batch_examples, 190 _VALID_FILE_PATTERN, 191 default_batch_size, 192 io_ops.TFRecordReader, 193 False, 194 num_epochs=None, 195 queue_capacity=queue_capacity, 196 num_threads=1, 197 read_batch_size=0, 198 name=name) 199 200 def test_batch_record_features(self): 201 batch_size = 17 202 queue_capacity = 1234 203 name = "my_batch" 204 shape = (0,) 205 features = { 206 "feature": 207 parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32) 208 } 209 210 with ops.Graph().as_default() as g, self.session(graph=g) as sess: 211 features = graph_io.read_batch_record_features( 212 _VALID_FILE_PATTERN, 213 batch_size, 214 features, 215 randomize_input=False, 216 queue_capacity=queue_capacity, 217 reader_num_threads=2, 218 name=name) 219 self.assertTrue("feature" in features, 220 "'feature' missing from %s." % features.keys()) 221 feature = features["feature"] 222 self.assertEqual("%s/fifo_queue_1_Dequeue:0" % name, feature.name) 223 self.assertAllEqual((batch_size,) + shape, feature.get_shape().as_list()) 224 file_name_queue_name = "%s/file_name_queue" % name 225 file_names_name = "%s/input" % file_name_queue_name 226 example_queue_name = "%s/fifo_queue" % name 227 parse_example_queue_name = "%s/fifo_queue" % name 228 op_nodes = test_util.assert_ops_in_graph({ 229 file_names_name: "Const", 230 file_name_queue_name: "FIFOQueueV2", 231 "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2", 232 example_queue_name: "FIFOQueueV2", 233 parse_example_queue_name: "FIFOQueueV2", 234 name: "QueueDequeueManyV2" 235 }, g) 236 self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0]) 237 self.assertEqual(queue_capacity, 238 op_nodes[example_queue_name].attr["capacity"].i) 239 240 def test_one_epoch(self): 241 batch_size = 17 242 queue_capacity = 1234 243 name = "my_batch" 244 245 with ops.Graph().as_default() as g, self.session(graph=g) as sess: 246 inputs = graph_io.read_batch_examples( 247 _VALID_FILE_PATTERN, 248 batch_size, 249 reader=io_ops.TFRecordReader, 250 randomize_input=True, 251 num_epochs=1, 252 queue_capacity=queue_capacity, 253 name=name) 254 self.assertAllEqual((None,), inputs.get_shape().as_list()) 255 self.assertEqual("%s:1" % name, inputs.name) 256 file_name_queue_name = "%s/file_name_queue" % name 257 file_name_queue_limit_name = ( 258 "%s/limit_epochs/epochs" % file_name_queue_name) 259 file_names_name = "%s/input" % file_name_queue_name 260 example_queue_name = "%s/random_shuffle_queue" % name 261 op_nodes = test_util.assert_ops_in_graph({ 262 file_names_name: "Const", 263 file_name_queue_name: "FIFOQueueV2", 264 "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2", 265 example_queue_name: "RandomShuffleQueueV2", 266 name: "QueueDequeueUpToV2", 267 file_name_queue_limit_name: "VariableV2" 268 }, g) 269 self.assertEqual( 270 set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0])) 271 self.assertEqual(queue_capacity, 272 op_nodes[example_queue_name].attr["capacity"].i) 273 274 def test_batch_randomized_multiple_globs(self): 275 batch_size = 17 276 queue_capacity = 1234 277 name = "my_batch" 278 279 with ops.Graph().as_default() as g, self.session(graph=g) as sess: 280 inputs = graph_io.read_batch_examples( 281 [_VALID_FILE_PATTERN, _VALID_FILE_PATTERN_2], 282 batch_size, 283 reader=io_ops.TFRecordReader, 284 randomize_input=True, 285 queue_capacity=queue_capacity, 286 name=name) 287 self.assertAllEqual((batch_size,), inputs.get_shape().as_list()) 288 self.assertEqual("%s:1" % name, inputs.name) 289 file_name_queue_name = "%s/file_name_queue" % name 290 file_names_name = "%s/input" % file_name_queue_name 291 example_queue_name = "%s/random_shuffle_queue" % name 292 op_nodes = test_util.assert_ops_in_graph({ 293 file_names_name: "Const", 294 file_name_queue_name: "FIFOQueueV2", 295 "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2", 296 example_queue_name: "RandomShuffleQueueV2", 297 name: "QueueDequeueManyV2" 298 }, g) 299 self.assertEqual( 300 set(_FILE_NAMES + _FILE_NAMES_2), 301 set(sess.run(["%s:0" % file_names_name])[0])) 302 self.assertEqual(queue_capacity, 303 op_nodes[example_queue_name].attr["capacity"].i) 304 305 def _create_temp_file(self, lines): 306 tempdir = tempfile.mkdtemp() 307 filename = os.path.join(tempdir, "temp_file") 308 gfile.Open(filename, "w").write(lines) 309 return filename 310 311 def _create_sorted_temp_files(self, lines_list): 312 tempdir = tempfile.mkdtemp() 313 filenames = [] 314 for i, lines in enumerate(lines_list): 315 filename = os.path.join(tempdir, "temp_file%05d" % i) 316 gfile.Open(filename, "w").write(lines) 317 filenames.append(filename) 318 return filenames 319 320 def test_read_text_lines(self): 321 gfile.Glob = self._orig_glob 322 filename = self._create_temp_file("ABC\nDEF\nGHK\n") 323 324 batch_size = 1 325 queue_capacity = 5 326 name = "my_batch" 327 328 with ops.Graph().as_default() as g, self.session(graph=g) as session: 329 inputs = graph_io.read_batch_examples( 330 filename, 331 batch_size, 332 reader=io_ops.TextLineReader, 333 randomize_input=False, 334 num_epochs=1, 335 queue_capacity=queue_capacity, 336 name=name) 337 self.assertAllEqual((None,), inputs.get_shape().as_list()) 338 session.run(variables.local_variables_initializer()) 339 340 coord = coordinator.Coordinator() 341 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 342 343 self.assertAllEqual(session.run(inputs), [b"ABC"]) 344 self.assertAllEqual(session.run(inputs), [b"DEF"]) 345 self.assertAllEqual(session.run(inputs), [b"GHK"]) 346 with self.assertRaises(errors.OutOfRangeError): 347 session.run(inputs) 348 349 coord.request_stop() 350 coord.join(threads) 351 352 def _create_file_from_list_of_features(self, lines): 353 json_lines = [ 354 "".join([ 355 '{"features": { "feature": { "sequence": {', 356 '"bytes_list": { "value": ["', 357 base64.b64encode(l).decode("ascii"), '"]}}}}}\n' 358 ]) for l in lines 359 ] 360 return self._create_temp_file("".join(json_lines)) 361 362 def test_read_text_lines_large(self): 363 gfile.Glob = self._orig_glob 364 sequence_prefix = "abcdefghijklmnopqrstuvwxyz123456789" 365 num_records = 49999 366 lines = [ 367 "".join([sequence_prefix, str(l)]).encode("ascii") 368 for l in xrange(num_records) 369 ] 370 filename = self._create_file_from_list_of_features(lines) 371 batch_size = 10000 372 queue_capacity = 100000 373 name = "my_large_batch" 374 375 features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)} 376 377 with ops.Graph().as_default() as g, self.session(graph=g) as session: 378 keys, result = graph_io.read_keyed_batch_features( 379 filename, 380 batch_size, 381 features, 382 io_ops.TextLineReader, 383 randomize_input=False, 384 num_epochs=1, 385 queue_capacity=queue_capacity, 386 num_enqueue_threads=2, 387 parse_fn=parsing_ops.decode_json_example, 388 name=name) 389 self.assertAllEqual((None,), keys.get_shape().as_list()) 390 self.assertEqual(1, len(result)) 391 self.assertAllEqual((None,), result["sequence"].get_shape().as_list()) 392 session.run(variables.local_variables_initializer()) 393 coord = coordinator.Coordinator() 394 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 395 396 data = [] 397 try: 398 while not coord.should_stop(): 399 data.append(session.run(result)) 400 except errors.OutOfRangeError: 401 pass 402 finally: 403 coord.request_stop() 404 405 coord.join(threads) 406 407 parsed_records = [ 408 item for sublist in [d["sequence"] for d in data] for item in sublist 409 ] 410 # Check that the number of records matches expected and all records 411 # are present. 412 self.assertEqual(len(parsed_records), num_records) 413 self.assertEqual(set(parsed_records), set(lines)) 414 415 def test_read_batch_features_maintains_order(self): 416 """Make sure that examples are read in the right order. 417 418 When randomize_input=False, num_enqueue_threads=1 and reader_num_threads=1 419 read_keyed_batch_features() should read the examples in the same order as 420 they appear in the file. 421 """ 422 gfile.Glob = self._orig_glob 423 num_records = 1000 424 lines = ["".join(str(l)).encode("ascii") for l in xrange(num_records)] 425 filename = self._create_file_from_list_of_features(lines) 426 batch_size = 10 427 queue_capacity = 1000 428 name = "my_large_batch" 429 430 features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)} 431 432 with ops.Graph().as_default() as g, self.session(graph=g) as session: 433 result = graph_io.read_batch_features( 434 filename, 435 batch_size, 436 features, 437 io_ops.TextLineReader, 438 randomize_input=False, 439 num_epochs=1, 440 queue_capacity=queue_capacity, 441 reader_num_threads=1, 442 num_enqueue_threads=1, 443 parse_fn=parsing_ops.decode_json_example, 444 name=name) 445 self.assertEqual(1, len(result)) 446 self.assertAllEqual((None,), result["sequence"].get_shape().as_list()) 447 session.run(variables.local_variables_initializer()) 448 coord = coordinator.Coordinator() 449 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 450 451 data = [] 452 try: 453 while not coord.should_stop(): 454 data.append(session.run(result)) 455 except errors.OutOfRangeError: 456 pass 457 finally: 458 coord.request_stop() 459 460 coord.join(threads) 461 462 parsed_records = [ 463 item for sublist in [d["sequence"] for d in data] for item in sublist 464 ] 465 # Check that the number of records matches expected and all records 466 # are present in the right order. 467 self.assertEqual(len(parsed_records), num_records) 468 self.assertEqual(parsed_records, lines) 469 470 def test_read_text_lines_multifile(self): 471 gfile.Glob = self._orig_glob 472 filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"]) 473 474 batch_size = 1 475 queue_capacity = 5 476 name = "my_batch" 477 478 with ops.Graph().as_default() as g, self.session(graph=g) as session: 479 inputs = graph_io.read_batch_examples( 480 filenames, 481 batch_size, 482 reader=io_ops.TextLineReader, 483 randomize_input=False, 484 num_epochs=1, 485 queue_capacity=queue_capacity, 486 name=name) 487 self.assertAllEqual((None,), inputs.get_shape().as_list()) 488 session.run(variables.local_variables_initializer()) 489 490 coord = coordinator.Coordinator() 491 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 492 493 self.assertEqual("%s:1" % name, inputs.name) 494 file_name_queue_name = "%s/file_name_queue" % name 495 file_names_name = "%s/input" % file_name_queue_name 496 example_queue_name = "%s/fifo_queue" % name 497 test_util.assert_ops_in_graph({ 498 file_names_name: "Const", 499 file_name_queue_name: "FIFOQueueV2", 500 "%s/read/TextLineReaderV2" % name: "TextLineReaderV2", 501 example_queue_name: "FIFOQueueV2", 502 name: "QueueDequeueUpToV2" 503 }, g) 504 505 self.assertAllEqual(session.run(inputs), [b"ABC"]) 506 self.assertAllEqual(session.run(inputs), [b"DEF"]) 507 self.assertAllEqual(session.run(inputs), [b"GHK"]) 508 with self.assertRaises(errors.OutOfRangeError): 509 session.run(inputs) 510 511 coord.request_stop() 512 coord.join(threads) 513 514 def test_read_text_lines_multifile_with_shared_queue(self): 515 gfile.Glob = self._orig_glob 516 filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"]) 517 518 batch_size = 1 519 queue_capacity = 5 520 name = "my_batch" 521 522 with ops.Graph().as_default() as g, self.session(graph=g) as session: 523 keys, inputs = graph_io.read_keyed_batch_examples_shared_queue( 524 filenames, 525 batch_size, 526 reader=io_ops.TextLineReader, 527 randomize_input=False, 528 num_epochs=1, 529 queue_capacity=queue_capacity, 530 name=name) 531 self.assertAllEqual((None,), keys.get_shape().as_list()) 532 self.assertAllEqual((None,), inputs.get_shape().as_list()) 533 session.run([ 534 variables.local_variables_initializer(), 535 variables.global_variables_initializer() 536 ]) 537 538 coord = coordinator.Coordinator() 539 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 540 541 self.assertEqual("%s:1" % name, inputs.name) 542 example_queue_name = "%s/fifo_queue" % name 543 worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name 544 test_util.assert_ops_in_graph({ 545 "%s/read/TextLineReaderV2" % name: "TextLineReaderV2", 546 example_queue_name: "FIFOQueueV2", 547 worker_file_name_queue_name: "FIFOQueueV2", 548 name: "QueueDequeueUpToV2" 549 }, g) 550 551 self.assertAllEqual(session.run(inputs), [b"ABC"]) 552 self.assertAllEqual(session.run(inputs), [b"DEF"]) 553 self.assertAllEqual(session.run(inputs), [b"GHK"]) 554 with self.assertRaises(errors.OutOfRangeError): 555 session.run(inputs) 556 557 coord.request_stop() 558 coord.join(threads) 559 560 def _get_qr(self, name): 561 for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS): 562 if qr.name == name: 563 return qr 564 565 def _run_queue(self, name, session): 566 qr = self._get_qr(name) 567 for op in qr.enqueue_ops: 568 session.run(op) 569 570 def test_multiple_workers_with_shared_queue(self): 571 gfile.Glob = self._orig_glob 572 filenames = self._create_sorted_temp_files([ 573 "ABC\n", "DEF\n", "GHI\n", "JKL\n", "MNO\n", "PQR\n", "STU\n", "VWX\n", 574 "YZ\n" 575 ]) 576 577 batch_size = 1 578 queue_capacity = 5 579 name = "my_batch" 580 example_queue_name = "%s/fifo_queue" % name 581 worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name 582 583 server = server_lib.Server.create_local_server() 584 585 with ops.Graph().as_default() as g1, session_lib.Session( 586 server.target, graph=g1) as session: 587 keys, inputs = graph_io.read_keyed_batch_examples_shared_queue( 588 filenames, 589 batch_size, 590 reader=io_ops.TextLineReader, 591 randomize_input=False, 592 num_epochs=1, 593 queue_capacity=queue_capacity, 594 name=name) 595 self.assertAllEqual((None,), keys.get_shape().as_list()) 596 self.assertAllEqual((None,), inputs.get_shape().as_list()) 597 session.run([ 598 variables.local_variables_initializer(), 599 variables.global_variables_initializer() 600 ]) 601 602 # Run the two queues once manually. 603 self._run_queue(worker_file_name_queue_name, session) 604 self._run_queue(example_queue_name, session) 605 606 self.assertAllEqual(session.run(inputs), [b"ABC"]) 607 608 # Run the worker and the example queue. 609 self._run_queue(worker_file_name_queue_name, session) 610 self._run_queue(example_queue_name, session) 611 612 self.assertAllEqual(session.run(inputs), [b"DEF"]) 613 614 with ops.Graph().as_default() as g2, session_lib.Session( 615 server.target, graph=g2) as session: 616 keys, inputs = graph_io.read_keyed_batch_examples_shared_queue( 617 filenames, 618 batch_size, 619 reader=io_ops.TextLineReader, 620 randomize_input=False, 621 num_epochs=1, 622 queue_capacity=queue_capacity, 623 name=name) 624 self.assertAllEqual((None,), keys.get_shape().as_list()) 625 self.assertAllEqual((None,), inputs.get_shape().as_list()) 626 627 # Run the worker and the example queue. 628 self._run_queue(worker_file_name_queue_name, session) 629 self._run_queue(example_queue_name, session) 630 631 self.assertAllEqual(session.run(inputs), [b"GHI"]) 632 633 self.assertTrue(g1 is not g2) 634 635 def test_batch_text_lines(self): 636 gfile.Glob = self._orig_glob 637 filename = self._create_temp_file("A\nB\nC\nD\nE\n") 638 639 batch_size = 3 640 queue_capacity = 10 641 name = "my_batch" 642 643 with ops.Graph().as_default() as g, self.session(graph=g) as session: 644 inputs = graph_io.read_batch_examples( 645 [filename], 646 batch_size, 647 reader=io_ops.TextLineReader, 648 randomize_input=False, 649 num_epochs=1, 650 queue_capacity=queue_capacity, 651 read_batch_size=10, 652 name=name) 653 self.assertAllEqual((None,), inputs.get_shape().as_list()) 654 session.run(variables.local_variables_initializer()) 655 656 coord = coordinator.Coordinator() 657 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 658 659 self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"]) 660 self.assertAllEqual(session.run(inputs), [b"D", b"E"]) 661 with self.assertRaises(errors.OutOfRangeError): 662 session.run(inputs) 663 664 coord.request_stop() 665 coord.join(threads) 666 667 def test_keyed_read_text_lines(self): 668 gfile.Glob = self._orig_glob 669 filename = self._create_temp_file("ABC\nDEF\nGHK\n") 670 671 batch_size = 1 672 queue_capacity = 5 673 name = "my_batch" 674 675 with ops.Graph().as_default() as g, self.session(graph=g) as session: 676 keys, inputs = graph_io.read_keyed_batch_examples( 677 filename, 678 batch_size, 679 reader=io_ops.TextLineReader, 680 randomize_input=False, 681 num_epochs=1, 682 queue_capacity=queue_capacity, 683 name=name) 684 self.assertAllEqual((None,), keys.get_shape().as_list()) 685 self.assertAllEqual((None,), inputs.get_shape().as_list()) 686 session.run(variables.local_variables_initializer()) 687 688 coord = coordinator.Coordinator() 689 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 690 691 self.assertAllEqual( 692 session.run([keys, inputs]), 693 [[filename.encode("utf-8") + b":1"], [b"ABC"]]) 694 self.assertAllEqual( 695 session.run([keys, inputs]), 696 [[filename.encode("utf-8") + b":2"], [b"DEF"]]) 697 self.assertAllEqual( 698 session.run([keys, inputs]), 699 [[filename.encode("utf-8") + b":3"], [b"GHK"]]) 700 with self.assertRaises(errors.OutOfRangeError): 701 session.run(inputs) 702 703 coord.request_stop() 704 coord.join(threads) 705 706 def test_keyed_parse_json(self): 707 gfile.Glob = self._orig_glob 708 filename = self._create_temp_file( 709 '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n' 710 '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n' 711 '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n') 712 713 batch_size = 1 714 queue_capacity = 5 715 name = "my_batch" 716 717 with ops.Graph().as_default() as g, self.session(graph=g) as session: 718 dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)} 719 parse_fn = lambda example: parsing_ops.parse_single_example( # pylint: disable=g-long-lambda 720 parsing_ops.decode_json_example(example), dtypes) 721 keys, inputs = graph_io.read_keyed_batch_examples( 722 filename, 723 batch_size, 724 reader=io_ops.TextLineReader, 725 randomize_input=False, 726 num_epochs=1, 727 queue_capacity=queue_capacity, 728 parse_fn=parse_fn, 729 name=name) 730 self.assertAllEqual((None,), keys.get_shape().as_list()) 731 self.assertEqual(1, len(inputs)) 732 self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list()) 733 session.run(variables.local_variables_initializer()) 734 735 coord = coordinator.Coordinator() 736 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 737 738 key, age = session.run([keys, inputs["age"]]) 739 self.assertAllEqual(age, [[0]]) 740 self.assertAllEqual(key, [filename.encode("utf-8") + b":1"]) 741 key, age = session.run([keys, inputs["age"]]) 742 self.assertAllEqual(age, [[1]]) 743 self.assertAllEqual(key, [filename.encode("utf-8") + b":2"]) 744 key, age = session.run([keys, inputs["age"]]) 745 self.assertAllEqual(age, [[2]]) 746 self.assertAllEqual(key, [filename.encode("utf-8") + b":3"]) 747 with self.assertRaises(errors.OutOfRangeError): 748 session.run(inputs) 749 750 coord.request_stop() 751 coord.join(threads) 752 753 def test_keyed_features_filter(self): 754 gfile.Glob = self._orig_glob 755 lines = [ 756 '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}', 757 '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}', 758 '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}', 759 '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}', 760 '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}', 761 '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}' 762 ] 763 filename = self._create_temp_file("\n".join(lines)) 764 765 batch_size = 2 766 queue_capacity = 4 767 name = "my_batch" 768 features = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)} 769 770 def filter_fn(keys, examples_json): 771 del keys 772 serialized = parsing_ops.decode_json_example(examples_json) 773 examples = parsing_ops.parse_example(serialized, features) 774 return math_ops.less(examples["age"], 2) 775 776 with ops.Graph().as_default() as g, self.session(graph=g) as session: 777 keys, inputs = graph_io._read_keyed_batch_examples_helper( 778 filename, 779 batch_size, 780 reader=io_ops.TextLineReader, 781 randomize_input=False, 782 num_epochs=1, 783 read_batch_size=batch_size, 784 queue_capacity=queue_capacity, 785 filter_fn=filter_fn, 786 name=name) 787 self.assertAllEqual((None,), keys.get_shape().as_list()) 788 self.assertAllEqual((None,), inputs.get_shape().as_list()) 789 session.run(variables.local_variables_initializer()) 790 791 coord = coordinator.Coordinator() 792 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 793 # First batch of two filtered examples. 794 out_keys, out_vals = session.run((keys, inputs)) 795 self.assertAllEqual( 796 [filename.encode("utf-8") + b":2", filename.encode("utf-8") + b":3"], 797 out_keys) 798 self.assertAllEqual([lines[1].encode("utf-8"), lines[2].encode("utf-8")], 799 out_vals) 800 801 # Second batch will only have one filtered example as that's the only 802 # remaining example that satisfies the filtering criterion. 803 out_keys, out_vals = session.run((keys, inputs)) 804 self.assertAllEqual([filename.encode("utf-8") + b":4"], out_keys) 805 self.assertAllEqual([lines[3].encode("utf-8")], out_vals) 806 807 # Exhausted input. 808 with self.assertRaises(errors.OutOfRangeError): 809 session.run((keys, inputs)) 810 811 coord.request_stop() 812 coord.join(threads) 813 814 def test_queue_parsed_features_single_tensor(self): 815 with ops.Graph().as_default() as g, self.session(graph=g) as session: 816 features = {"test": constant_op.constant([1, 2, 3])} 817 _, queued_features = graph_io.queue_parsed_features(features) 818 coord = coordinator.Coordinator() 819 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 820 out_features = session.run(queued_features["test"]) 821 self.assertAllEqual([1, 2, 3], out_features) 822 coord.request_stop() 823 coord.join(threads) 824 825 def test_read_keyed_batch_features_shared_queue(self): 826 batch_size = 17 827 shape = (0,) 828 fixed_feature = parsing_ops.FixedLenFeature( 829 shape=shape, dtype=dtypes_lib.float32) 830 feature = {"feature": fixed_feature} 831 reader = io_ops.TFRecordReader 832 833 _, queued_feature = graph_io.read_keyed_batch_features_shared_queue( 834 _VALID_FILE_PATTERN, batch_size, feature, reader) 835 836 with ops.Graph().as_default() as g, self.session(graph=g) as session: 837 features_result = graph_io.read_batch_features( 838 _VALID_FILE_PATTERN, batch_size, feature, reader) 839 session.run(variables.local_variables_initializer()) 840 841 self.assertAllEqual( 842 queued_feature.get("feature").get_shape().as_list(), 843 features_result.get("feature").get_shape().as_list()) 844 845 def test_get_file_names_errors(self): 846 # Raise bad file_pattern. 847 with self.assertRaises(ValueError): 848 graph_io._get_file_names([], True) 849 850 851if __name__ == "__main__": 852 test.main() 853