1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for learn.io.graph_io."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import base64
22import os
23import random
24import tempfile
25
26from six.moves import xrange  # pylint: disable=redefined-builtin
27
28from tensorflow.contrib.learn.python.learn.learn_io import graph_io
29from tensorflow.python.client import session as session_lib
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes as dtypes_lib
32from tensorflow.python.framework import errors
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import test_util
35from tensorflow.python.ops import io_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import parsing_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import gfile
40from tensorflow.python.platform import test
41from tensorflow.python.training import coordinator
42from tensorflow.python.training import queue_runner_impl
43from tensorflow.python.training import server_lib
44
# Fake glob patterns resolved by GraphIOTest._mock_glob, each followed by the
# file names it expands to.
_VALID_FILE_PATTERN = "VALID"
_FILE_NAMES = [b"abc", b"def", b"ghi", b"jkl"]

_VALID_FILE_PATTERN_2 = "VALID_2"
_FILE_NAMES_2 = [b"mno", b"pqr"]

# The only other pattern _mock_glob accepts; it expands to no files.
_INVALID_FILE_PATTERN = "INVALID"
50
51
class GraphIOTest(test.TestCase):
  """Tests for the queue-based input readers in `graph_io`.

  `setUp` replaces `gfile.Glob` with `_mock_glob`, so the fake patterns
  `_VALID_FILE_PATTERN`, `_VALID_FILE_PATTERN_2` and `_INVALID_FILE_PATTERN`
  expand to fixed file-name lists without touching the file system. Tests that
  read real temporary files restore the original `gfile.Glob` first.
  """

  def _mock_glob(self, pattern):
    """Stand-in for `gfile.Glob` that resolves the fake test patterns.

    Args:
      pattern: A file pattern; expected to be one of the module-level fake
        patterns.

    Returns:
      The file-name list associated with `pattern`, or `[]` for
      `_INVALID_FILE_PATTERN`. Any other pattern fails the test.
    """
    if _VALID_FILE_PATTERN == pattern:
      return _FILE_NAMES
    if _VALID_FILE_PATTERN_2 == pattern:
      return _FILE_NAMES_2
    self.assertEqual(_INVALID_FILE_PATTERN, pattern)
    return []

  def setUp(self):
    super(GraphIOTest, self).setUp()
    random.seed(42)
    # Patch the module-level glob; tearDown puts the original back.
    self._orig_glob = gfile.Glob
    gfile.Glob = self._mock_glob

  def tearDown(self):
    gfile.Glob = self._orig_glob
    super(GraphIOTest, self).tearDown()

  def test_dequeue_batch_value_errors(self):
    default_batch_size = 17
    queue_capacity = 1234
    num_threads = 3
    name = "my_batch"

    def assert_rejected(expected_regexp,
                        file_pattern=_VALID_FILE_PATTERN,
                        batch_size=default_batch_size,
                        **overrides):
      """Checks that read_batch_examples raises ValueError for these args.

      Every argument that is not overridden keeps a valid value, so each call
      below only names the argument it is exercising.
      """
      kwargs = dict(
          num_epochs=None,
          queue_capacity=queue_capacity,
          num_threads=num_threads,
          name=name)
      kwargs.update(overrides)
      self.assertRaisesRegexp(ValueError, expected_regexp,
                              graph_io.read_batch_examples, file_pattern,
                              batch_size, io_ops.TFRecordReader, False,
                              **kwargs)

    assert_rejected("No files match", file_pattern=_INVALID_FILE_PATTERN)
    assert_rejected("Invalid batch_size", batch_size=None)
    assert_rejected("Invalid batch_size", batch_size=-1)
    # batch_size may not reach queue_capacity (equal or larger is rejected).
    assert_rejected("Invalid batch_size", queue_capacity=default_batch_size)
    assert_rejected("Invalid queue_capacity", queue_capacity=None)
    assert_rejected("Invalid num_threads", num_threads=None)
    assert_rejected("Invalid num_threads", num_threads=-1)
    assert_rejected(
        "Invalid batch_size", batch_size=queue_capacity + 1, num_threads=1)
    assert_rejected("Invalid num_epochs", num_epochs=-1, num_threads=1)
    assert_rejected(
        "Invalid read_batch_size", num_threads=1, read_batch_size=0)

  def test_batch_record_features(self):
    batch_size = 17
    queue_capacity = 1234
    name = "my_batch"
    shape = (0,)
    feature_spec = {
        "feature":
            parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32)
    }

    with ops.Graph().as_default() as g, self.session(graph=g) as sess:
      features = graph_io.read_batch_record_features(
          _VALID_FILE_PATTERN,
          batch_size,
          feature_spec,
          randomize_input=False,
          queue_capacity=queue_capacity,
          reader_num_threads=2,
          name=name)
      self.assertIn("feature", features)
      feature = features["feature"]
      self.assertEqual("%s/fifo_queue_1_Dequeue:0" % name, feature.name)
      self.assertAllEqual((batch_size,) + shape, feature.get_shape().as_list())
      file_name_queue_name = "%s/file_name_queue" % name
      file_names_name = "%s/input" % file_name_queue_name
      example_queue_name = "%s/fifo_queue" % name
      # NOTE(review): the parse-example queue is not checked separately here.
      # Judging by the dequeue op name above it may be "%s/fifo_queue_1";
      # confirm before adding an assertion for it.
      op_nodes = test_util.assert_ops_in_graph({
          file_names_name: "Const",
          file_name_queue_name: "FIFOQueueV2",
          "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2",
          example_queue_name: "FIFOQueueV2",
          name: "QueueDequeueManyV2"
      }, g)
      self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
      self.assertEqual(queue_capacity,
                       op_nodes[example_queue_name].attr["capacity"].i)

  def test_one_epoch(self):
    batch_size = 17
    queue_capacity = 1234
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as sess:
      inputs = graph_io.read_batch_examples(
          _VALID_FILE_PATTERN,
          batch_size,
          reader=io_ops.TFRecordReader,
          randomize_input=True,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      self.assertEqual("%s:1" % name, inputs.name)
      file_name_queue_name = "%s/file_name_queue" % name
      file_name_queue_limit_name = (
          "%s/limit_epochs/epochs" % file_name_queue_name)
      file_names_name = "%s/input" % file_name_queue_name
      example_queue_name = "%s/random_shuffle_queue" % name
      op_nodes = test_util.assert_ops_in_graph({
          file_names_name: "Const",
          file_name_queue_name: "FIFOQueueV2",
          "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2",
          example_queue_name: "RandomShuffleQueueV2",
          name: "QueueDequeueUpToV2",
          file_name_queue_limit_name: "VariableV2"
      }, g)
      self.assertEqual(
          set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0]))
      self.assertEqual(queue_capacity,
                       op_nodes[example_queue_name].attr["capacity"].i)

  def test_batch_randomized_multiple_globs(self):
    batch_size = 17
    queue_capacity = 1234
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as sess:
      inputs = graph_io.read_batch_examples(
          [_VALID_FILE_PATTERN, _VALID_FILE_PATTERN_2],
          batch_size,
          reader=io_ops.TFRecordReader,
          randomize_input=True,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((batch_size,), inputs.get_shape().as_list())
      self.assertEqual("%s:1" % name, inputs.name)
      file_name_queue_name = "%s/file_name_queue" % name
      file_names_name = "%s/input" % file_name_queue_name
      example_queue_name = "%s/random_shuffle_queue" % name
      op_nodes = test_util.assert_ops_in_graph({
          file_names_name: "Const",
          file_name_queue_name: "FIFOQueueV2",
          "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2",
          example_queue_name: "RandomShuffleQueueV2",
          name: "QueueDequeueManyV2"
      }, g)
      self.assertEqual(
          set(_FILE_NAMES + _FILE_NAMES_2),
          set(sess.run(["%s:0" % file_names_name])[0]))
      self.assertEqual(queue_capacity,
                       op_nodes[example_queue_name].attr["capacity"].i)

  def _create_temp_file(self, lines):
    """Writes `lines` to a file named "temp_file" in a new temp directory.

    The directory is created under the test's own temp dir, and the file
    handle is closed before returning so the content is flushed to disk.

    Args:
      lines: The full text content to write.

    Returns:
      The path of the created file.
    """
    tempdir = tempfile.mkdtemp(dir=self.get_temp_dir())
    filename = os.path.join(tempdir, "temp_file")
    with gfile.Open(filename, "w") as f:
      f.write(lines)
    return filename

  def _create_sorted_temp_files(self, lines_list):
    """Writes each element of `lines_list` to its own file in a new temp dir.

    Files are named "temp_file00000", "temp_file00001", ... so that their
    lexicographic order matches the order of `lines_list`.

    Args:
      lines_list: Sequence of file contents, one per file.

    Returns:
      The list of created file paths, in the same order as `lines_list`.
    """
    tempdir = tempfile.mkdtemp(dir=self.get_temp_dir())
    filenames = []
    for i, lines in enumerate(lines_list):
      filename = os.path.join(tempdir, "temp_file%05d" % i)
      with gfile.Open(filename, "w") as f:
        f.write(lines)
      filenames.append(filename)
    return filenames

  def test_read_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)

  def _create_file_from_list_of_features(self, lines):
    """Writes one JSON-formatted example per element of `lines`.

    Each element is base64-encoded into the bytes_list of a feature named
    "sequence", one JSON object per line, so that the file can be read back
    with `parsing_ops.decode_json_example`.

    Args:
      lines: Iterable of bytes values.

    Returns:
      The path of the created file.
    """
    json_lines = [
        "".join([
            '{"features": { "feature": { "sequence": {',
            '"bytes_list": { "value": ["',
            base64.b64encode(l).decode("ascii"), '"]}}}}}\n'
        ]) for l in lines
    ]
    return self._create_temp_file("".join(json_lines))

  def test_read_text_lines_large(self):
    gfile.Glob = self._orig_glob
    sequence_prefix = "abcdefghijklmnopqrstuvwxyz123456789"
    num_records = 49999
    lines = [
        "".join([sequence_prefix, str(l)]).encode("ascii")
        for l in xrange(num_records)
    ]
    filename = self._create_file_from_list_of_features(lines)
    batch_size = 10000
    queue_capacity = 100000
    name = "my_large_batch"

    features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      keys, result = graph_io.read_keyed_batch_features(
          filename,
          batch_size,
          features,
          io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          num_enqueue_threads=2,
          parse_fn=parsing_ops.decode_json_example,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertEqual(1, len(result))
      self.assertAllEqual((None,), result["sequence"].get_shape().as_list())
      session.run(variables.local_variables_initializer())
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      data = []
      try:
        while not coord.should_stop():
          data.append(session.run(result))
      except errors.OutOfRangeError:
        pass
      finally:
        coord.request_stop()

      coord.join(threads)

    parsed_records = [
        item for sublist in [d["sequence"] for d in data] for item in sublist
    ]
    # Check that the number of records matches expected and all records
    # are present.
    self.assertEqual(len(parsed_records), num_records)
    self.assertEqual(set(parsed_records), set(lines))

  def test_read_batch_features_maintains_order(self):
    """Make sure that examples are read in the right order.

    When randomize_input=False, num_enqueue_threads=1 and reader_num_threads=1
    read_keyed_batch_features() should read the examples in the same order as
    they appear in the file.
    """
    gfile.Glob = self._orig_glob
    num_records = 1000
    lines = [str(l).encode("ascii") for l in xrange(num_records)]
    filename = self._create_file_from_list_of_features(lines)
    batch_size = 10
    queue_capacity = 1000
    name = "my_large_batch"

    features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      result = graph_io.read_batch_features(
          filename,
          batch_size,
          features,
          io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          reader_num_threads=1,
          num_enqueue_threads=1,
          parse_fn=parsing_ops.decode_json_example,
          name=name)
      self.assertEqual(1, len(result))
      self.assertAllEqual((None,), result["sequence"].get_shape().as_list())
      session.run(variables.local_variables_initializer())
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      data = []
      try:
        while not coord.should_stop():
          data.append(session.run(result))
      except errors.OutOfRangeError:
        pass
      finally:
        coord.request_stop()

      coord.join(threads)

    parsed_records = [
        item for sublist in [d["sequence"] for d in data] for item in sublist
    ]
    # Check that the number of records matches expected and all records
    # are present in the right order.
    self.assertEqual(len(parsed_records), num_records)
    self.assertEqual(parsed_records, lines)

  def test_read_text_lines_multifile(self):
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          filenames,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertEqual("%s:1" % name, inputs.name)
      file_name_queue_name = "%s/file_name_queue" % name
      file_names_name = "%s/input" % file_name_queue_name
      example_queue_name = "%s/fifo_queue" % name
      test_util.assert_ops_in_graph({
          file_names_name: "Const",
          file_name_queue_name: "FIFOQueueV2",
          "%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
          example_queue_name: "FIFOQueueV2",
          name: "QueueDequeueUpToV2"
      }, g)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)

  def test_read_text_lines_multifile_with_shared_queue(self):
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      keys, inputs = graph_io.read_keyed_batch_examples_shared_queue(
          filenames,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run([
          variables.local_variables_initializer(),
          variables.global_variables_initializer()
      ])

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertEqual("%s:1" % name, inputs.name)
      example_queue_name = "%s/fifo_queue" % name
      worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
      test_util.assert_ops_in_graph({
          "%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
          example_queue_name: "FIFOQueueV2",
          worker_file_name_queue_name: "FIFOQueueV2",
          name: "QueueDequeueUpToV2"
      }, g)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)

  def _get_qr(self, name):
    """Returns the queue runner called `name` from the default graph.

    Fails the test with a descriptive message if no such runner exists,
    instead of returning None and crashing later in the caller.
    """
    for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
      if qr.name == name:
        return qr
    self.fail("No queue runner named %r in the QUEUE_RUNNERS collection." %
              name)

  def _run_queue(self, name, session):
    """Runs every enqueue op of queue runner `name` once in `session`."""
    qr = self._get_qr(name)
    for op in qr.enqueue_ops:
      session.run(op)

  def test_multiple_workers_with_shared_queue(self):
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files([
        "ABC\n", "DEF\n", "GHI\n", "JKL\n", "MNO\n", "PQR\n", "STU\n", "VWX\n",
        "YZ\n"
    ])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"
    example_queue_name = "%s/fifo_queue" % name
    worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name

    server = server_lib.Server.create_local_server()

    with ops.Graph().as_default() as g1, session_lib.Session(
        server.target, graph=g1) as session:
      keys, inputs = graph_io.read_keyed_batch_examples_shared_queue(
          filenames,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run([
          variables.local_variables_initializer(),
          variables.global_variables_initializer()
      ])

      # Run the two queues once manually.
      self._run_queue(worker_file_name_queue_name, session)
      self._run_queue(example_queue_name, session)

      self.assertAllEqual(session.run(inputs), [b"ABC"])

      # Run the worker and the example queue.
      self._run_queue(worker_file_name_queue_name, session)
      self._run_queue(example_queue_name, session)

      self.assertAllEqual(session.run(inputs), [b"DEF"])

    with ops.Graph().as_default() as g2, session_lib.Session(
        server.target, graph=g2) as session:
      keys, inputs = graph_io.read_keyed_batch_examples_shared_queue(
          filenames,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())

      # Run the worker and the example queue.
      self._run_queue(worker_file_name_queue_name, session)
      self._run_queue(example_queue_name, session)

      self.assertAllEqual(session.run(inputs), [b"GHI"])

    self.assertTrue(g1 is not g2)

  def test_batch_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")

    batch_size = 3
    queue_capacity = 10
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          [filename],
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          read_batch_size=10,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
      self.assertAllEqual(session.run(inputs), [b"D", b"E"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)

  def test_keyed_read_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      keys, inputs = graph_io.read_keyed_batch_examples(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertAllEqual(
          session.run([keys, inputs]),
          [[filename.encode("utf-8") + b":1"], [b"ABC"]])
      self.assertAllEqual(
          session.run([keys, inputs]),
          [[filename.encode("utf-8") + b":2"], [b"DEF"]])
      self.assertAllEqual(
          session.run([keys, inputs]),
          [[filename.encode("utf-8") + b":3"], [b"GHK"]])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)

  def test_keyed_parse_json(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file(
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n')

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}
      parse_fn = lambda example: parsing_ops.parse_single_example(  # pylint: disable=g-long-lambda
          parsing_ops.decode_json_example(example), dtypes)
      keys, inputs = graph_io.read_keyed_batch_examples(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          parse_fn=parse_fn,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertEqual(1, len(inputs))
      self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[0]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":1"])
      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[1]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":2"])
      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[2]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":3"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)

  def test_keyed_features_filter(self):
    gfile.Glob = self._orig_glob
    lines = [
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}'
    ]
    filename = self._create_temp_file("\n".join(lines))

    batch_size = 2
    queue_capacity = 4
    name = "my_batch"
    features = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)}

    def filter_fn(keys, examples_json):
      del keys
      serialized = parsing_ops.decode_json_example(examples_json)
      examples = parsing_ops.parse_example(serialized, features)
      return math_ops.less(examples["age"], 2)

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      keys, inputs = graph_io._read_keyed_batch_examples_helper(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          read_batch_size=batch_size,
          queue_capacity=queue_capacity,
          filter_fn=filter_fn,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)
      # First batch of two filtered examples.
      out_keys, out_vals = session.run((keys, inputs))
      self.assertAllEqual(
          [filename.encode("utf-8") + b":2", filename.encode("utf-8") + b":3"],
          out_keys)
      self.assertAllEqual([lines[1].encode("utf-8"), lines[2].encode("utf-8")],
                          out_vals)

      # Second batch will only have one filtered example as that's the only
      # remaining example that satisfies the filtering criterion.
      out_keys, out_vals = session.run((keys, inputs))
      self.assertAllEqual([filename.encode("utf-8") + b":4"], out_keys)
      self.assertAllEqual([lines[3].encode("utf-8")], out_vals)

      # Exhausted input.
      with self.assertRaises(errors.OutOfRangeError):
        session.run((keys, inputs))

      coord.request_stop()
      coord.join(threads)

  def test_queue_parsed_features_single_tensor(self):
    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      features = {"test": constant_op.constant([1, 2, 3])}
      _, queued_features = graph_io.queue_parsed_features(features)
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)
      out_features = session.run(queued_features["test"])
      self.assertAllEqual([1, 2, 3], out_features)
      coord.request_stop()
      coord.join(threads)

  def test_read_keyed_batch_features_shared_queue(self):
    batch_size = 17
    shape = (0,)
    fixed_feature = parsing_ops.FixedLenFeature(
        shape=shape, dtype=dtypes_lib.float32)
    feature = {"feature": fixed_feature}
    reader = io_ops.TFRecordReader

    # Build in an explicit graph, like every other test here, rather than
    # adding ops to the implicit default graph.
    with ops.Graph().as_default():
      _, queued_feature = graph_io.read_keyed_batch_features_shared_queue(
          _VALID_FILE_PATTERN, batch_size, feature, reader)

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      features_result = graph_io.read_batch_features(
          _VALID_FILE_PATTERN, batch_size, feature, reader)
      session.run(variables.local_variables_initializer())

    self.assertAllEqual(
        queued_feature.get("feature").get_shape().as_list(),
        features_result.get("feature").get_shape().as_list())

  def test_get_file_names_errors(self):
    # Raise bad file_pattern.
    with self.assertRaises(ValueError):
      graph_io._get_file_names([], True)
849
850
if __name__ == "__main__":
  # Delegate to TensorFlow's test runner (tensorflow.python.platform.test).
  test.main()
853