1# -*- coding: utf-8 -*- 2# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for tensorflow.python.ops.io_ops.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os 23import shutil 24import tempfile 25 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import io_ops 28from tensorflow.python.platform import test 29from tensorflow.python.util import compat 30 31 32class IoOpsTest(test.TestCase): 33 34 @test_util.run_deprecated_v1 35 def testReadFile(self): 36 cases = ['', 'Some contents', 'Неки садржаји на српском'] 37 for contents in cases: 38 contents = compat.as_bytes(contents) 39 with tempfile.NamedTemporaryFile( 40 prefix='ReadFileTest', dir=self.get_temp_dir(), delete=False) as temp: 41 temp.write(contents) 42 with self.cached_session(): 43 read = io_ops.read_file(temp.name) 44 self.assertEqual([], read.get_shape()) 45 self.assertEqual(read.eval(), contents) 46 os.remove(temp.name) 47 48 def testWriteFile(self): 49 cases = ['', 'Some contents'] 50 for contents in cases: 51 contents = compat.as_bytes(contents) 52 with tempfile.NamedTemporaryFile( 53 prefix='WriteFileTest', dir=self.get_temp_dir(), 54 delete=False) as temp: 55 pass 56 with self.cached_session() as sess: 57 w = io_ops.write_file(temp.name, contents) 58 self.evaluate(w) 59 with open(temp.name, 'rb') as f: 60 file_contents = f.read() 61 self.assertEqual(file_contents, contents) 62 os.remove(temp.name) 63 64 def testWriteFileCreateDir(self): 65 cases = ['', 'Some contents'] 66 for contents in cases: 67 contents = compat.as_bytes(contents) 68 subdir = os.path.join(self.get_temp_dir(), 'subdir1') 69 filepath = os.path.join(subdir, 'subdir2', 'filename') 70 with self.cached_session() as sess: 71 w = io_ops.write_file(filepath, contents) 72 self.evaluate(w) 73 with open(filepath, 'rb') as f: 74 file_contents = f.read() 75 self.assertEqual(file_contents, contents) 76 shutil.rmtree(subdir) 77 78 def _subset(self, files, indices): 79 return set( 80 compat.as_bytes(files[i].name) for i in range(len(files)) 81 if i in indices) 82 83 @test_util.run_deprecated_v1 84 def testMatchingFiles(self): 85 cases = [ 86 'ABcDEF.GH', 'ABzDEF.GH', 'ABasdfjklDEF.GH', 'AB3DEF.GH', 'AB4DEF.GH', 87 'ABDEF.GH', 'XYZ' 88 ] 89 files = [ 90 tempfile.NamedTemporaryFile( 91 prefix=c, dir=self.get_temp_dir(), delete=True) for c in cases 92 ] 93 94 with self.cached_session(): 95 # Test exact match without wildcards. 96 for f in files: 97 self.assertEqual( 98 io_ops.matching_files(f.name).eval(), compat.as_bytes(f.name)) 99 100 # We will look for files matching "ABxDEF.GH*" where "x" is some wildcard. 101 directory_path = files[0].name[:files[0].name.find(cases[0])] 102 pattern = directory_path + 'AB%sDEF.GH*' 103 104 self.assertEqual( 105 set(io_ops.matching_files(pattern % 'z').eval()), 106 self._subset(files, [1])) 107 self.assertEqual( 108 set(io_ops.matching_files(pattern % '?').eval()), 109 self._subset(files, [0, 1, 3, 4])) 110 self.assertEqual( 111 set(io_ops.matching_files(pattern % '*').eval()), 112 self._subset(files, [0, 1, 2, 3, 4, 5])) 113 # NOTE(mrry): Windows uses PathMatchSpec to match file patterns, which 114 # does not support the following expressions. 115 if os.name != 'nt': 116 self.assertEqual( 117 set(io_ops.matching_files(pattern % '[cxz]').eval()), 118 self._subset(files, [0, 1])) 119 self.assertEqual( 120 set(io_ops.matching_files(pattern % '[0-9]').eval()), 121 self._subset(files, [3, 4])) 122 123 # Test an empty list input. 124 self.assertItemsEqual(io_ops.matching_files([]).eval(), []) 125 126 # Test multiple exact filenames. 127 self.assertItemsEqual( 128 io_ops.matching_files([ 129 files[0].name, files[1].name, files[2].name]).eval(), 130 self._subset(files, [0, 1, 2])) 131 132 # Test multiple globs. 133 self.assertItemsEqual( 134 io_ops.matching_files([ 135 pattern % '?', directory_path + 'X?Z*']).eval(), 136 self._subset(files, [0, 1, 3, 4, 6])) 137 138 for f in files: 139 f.close() 140 141 142if __name__ == '__main__': 143 test.main() 144