1# -*- coding: utf-8 -*-
2# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for tensorflow.python.ops.io_ops."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23import shutil
24import tempfile
25
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import io_ops
28from tensorflow.python.platform import test
29from tensorflow.python.util import compat
30
31
class IoOpsTest(test.TestCase):
  """Tests for the file I/O ops in `tensorflow.python.ops.io_ops`.

  Covers `read_file`, `write_file` (including a target whose parent
  directories do not yet exist) and `matching_files` (exact names, glob
  wildcards, and list inputs). Every test cleans up the files it creates in a
  `finally` clause so a failing assertion does not leak temp files.
  """

  @test_util.run_deprecated_v1
  def testReadFile(self):
    """`read_file` yields a scalar tensor holding the file's raw bytes."""
    # Empty, ASCII and non-ASCII (UTF-8 encoded by compat.as_bytes) contents.
    cases = ['', 'Some contents', 'Неки садржаји на српском']
    for contents in cases:
      contents = compat.as_bytes(contents)
      with tempfile.NamedTemporaryFile(
          prefix='ReadFileTest', dir=self.get_temp_dir(), delete=False) as temp:
        temp.write(contents)
      try:
        with self.cached_session():
          read = io_ops.read_file(temp.name)
          self.assertEqual([], read.get_shape())
          self.assertEqual(read.eval(), contents)
      finally:
        # Remove the file even when an assertion above fails.
        os.remove(temp.name)

  def testWriteFile(self):
    """`write_file` replaces the contents of an existing (empty) file."""
    cases = ['', 'Some contents']
    for contents in cases:
      contents = compat.as_bytes(contents)
      # Create and immediately close an empty file only to obtain a unique
      # path; delete=False keeps it on disk for write_file to overwrite.
      with tempfile.NamedTemporaryFile(
          prefix='WriteFileTest', dir=self.get_temp_dir(),
          delete=False) as temp:
        pass
      try:
        with self.cached_session():
          self.evaluate(io_ops.write_file(temp.name, contents))
          with open(temp.name, 'rb') as f:
            file_contents = f.read()
          self.assertEqual(file_contents, contents)
      finally:
        os.remove(temp.name)

  def testWriteFileCreateDir(self):
    """`write_file` succeeds when the target's parent directories are absent."""
    cases = ['', 'Some contents']
    for contents in cases:
      contents = compat.as_bytes(contents)
      subdir = os.path.join(self.get_temp_dir(), 'subdir1')
      filepath = os.path.join(subdir, 'subdir2', 'filename')
      try:
        with self.cached_session():
          self.evaluate(io_ops.write_file(filepath, contents))
          with open(filepath, 'rb') as f:
            file_contents = f.read()
          self.assertEqual(file_contents, contents)
      finally:
        # Remove the whole tree so the next case starts without the
        # directories. Guard on existence so a failed write does not turn
        # into a confusing rmtree error that masks the real failure.
        if os.path.isdir(subdir):
          shutil.rmtree(subdir)

  def _subset(self, files, indices):
    """Returns the byte-string names of the entries of `files` at `indices`.

    Args:
      files: sequence of objects with a `.name` attribute (here, temp files).
      indices: iterable of valid integer positions into `files`.

    Returns:
      A set of `bytes` file names, suitable for order-insensitive comparison
      with the output of `matching_files`.
    """
    return {compat.as_bytes(files[i].name) for i in indices}

  @test_util.run_deprecated_v1
  def testMatchingFiles(self):
    """`matching_files` expands exact names, glob patterns and lists of both."""
    cases = [
        'ABcDEF.GH', 'ABzDEF.GH', 'ABasdfjklDEF.GH', 'AB3DEF.GH', 'AB4DEF.GH',
        'ABDEF.GH', 'XYZ'
    ]
    # Each case is used as a filename *prefix*; tempfile appends a random
    # suffix, which is why the patterns below end in '*'.
    files = [
        tempfile.NamedTemporaryFile(
            prefix=c, dir=self.get_temp_dir(), delete=True) for c in cases
    ]

    try:
      with self.cached_session():
        # Test exact match without wildcards.
        for f in files:
          self.assertEqual(
              io_ops.matching_files(f.name).eval(), compat.as_bytes(f.name))

        # We will look for files matching "ABxDEF.GH*" where "x" is some
        # wildcard.
        directory_path = files[0].name[:files[0].name.find(cases[0])]
        pattern = directory_path + 'AB%sDEF.GH*'

        self.assertEqual(
            set(io_ops.matching_files(pattern % 'z').eval()),
            self._subset(files, [1]))
        self.assertEqual(
            set(io_ops.matching_files(pattern % '?').eval()),
            self._subset(files, [0, 1, 3, 4]))
        self.assertEqual(
            set(io_ops.matching_files(pattern % '*').eval()),
            self._subset(files, [0, 1, 2, 3, 4, 5]))
        # NOTE(mrry): Windows uses PathMatchSpec to match file patterns, which
        # does not support the following expressions.
        if os.name != 'nt':
          self.assertEqual(
              set(io_ops.matching_files(pattern % '[cxz]').eval()),
              self._subset(files, [0, 1]))
          self.assertEqual(
              set(io_ops.matching_files(pattern % '[0-9]').eval()),
              self._subset(files, [3, 4]))

        # Test an empty list input.
        self.assertItemsEqual(io_ops.matching_files([]).eval(), [])

        # Test multiple exact filenames.
        self.assertItemsEqual(
            io_ops.matching_files([
                files[0].name, files[1].name, files[2].name]).eval(),
            self._subset(files, [0, 1, 2]))

        # Test multiple globs.
        self.assertItemsEqual(
            io_ops.matching_files([
                pattern % '?', directory_path + 'X?Z*']).eval(),
            self._subset(files, [0, 1, 3, 4, 6]))
    finally:
      # Closing a delete=True NamedTemporaryFile also removes it from disk;
      # do it even when an assertion above fails.
      for f in files:
        f.close()
140
141
if __name__ == '__main__':
  # Executed directly as a script: hand control to the TensorFlow test
  # runner imported above as `tensorflow.python.platform.test`.
  test.main()
144