1# -*- coding: utf-8 -*-
2# Copyright 2015 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Tests for yapf.file_resources."""
16
17import contextlib
18import os
19import shutil
20import tempfile
21import unittest
22
23from yapf.yapflib import errors
24from yapf.yapflib import file_resources
25from yapf.yapflib import py3compat
26
27from yapftests import utils
28
29
30@contextlib.contextmanager
31def _restore_working_dir():
32  curdir = os.getcwd()
33  try:
34    yield
35  finally:
36    os.chdir(curdir)
37
38
class GetDefaultStyleForDirTest(unittest.TestCase):
  """Checks the style that GetDefaultStyleForDir picks for a given path."""

  def setUp(self):
    # Every test works inside its own freshly created scratch directory.
    self.test_tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.test_tmpdir)

  def test_no_local_style(self):
    """With no .style.yapf in test_tmpdir, the expected result is 'pep8'."""
    candidate = os.path.join(self.test_tmpdir, 'file.py')
    self.assertEqual(
        file_resources.GetDefaultStyleForDir(candidate), 'pep8')

  def test_with_local_style(self):
    """A .style.yapf in test_tmpdir is returned for paths in or below it."""
    style_file = os.path.join(self.test_tmpdir, '.style.yapf')
    # An empty file is all that is needed here.
    with open(style_file, 'w'):
      pass

    # First a file directly in test_tmpdir, then one in a subdirectory.
    for parts in (('file.py',), ('dir1', 'file.py')):
      candidate = os.path.join(self.test_tmpdir, *parts)
      self.assertEqual(style_file,
                       file_resources.GetDefaultStyleForDir(candidate))
64
65
66def _touch_files(filenames):
67  for name in filenames:
68    open(name, 'a').close()
69
70
class GetCommandLineFilesTest(unittest.TestCase):
  """Tests for file_resources.GetCommandLineFiles."""

  def setUp(self):
    self.test_tmpdir = tempfile.mkdtemp()
    # Some tests chdir into test_tmpdir; remember where to return to.
    self.old_dir = os.getcwd()

  def tearDown(self):
    # Leave test_tmpdir before deleting it. Removing the process's current
    # working directory fails on Windows, and on POSIX it would leave later
    # tests running inside a deleted directory.
    os.chdir(self.old_dir)
    shutil.rmtree(self.test_tmpdir)

  def _make_test_dir(self, name):
    """Create test_tmpdir/<name> (with parents); return its normalized path."""
    fullpath = os.path.normpath(os.path.join(self.test_tmpdir, name))
    os.makedirs(fullpath)
    return fullpath

  def test_find_files_not_dirs(self):
    tdir1 = self._make_test_dir('test1')
    tdir2 = self._make_test_dir('test2')
    file1 = os.path.join(tdir1, 'testfile1.py')
    file2 = os.path.join(tdir2, 'testfile2.py')
    _touch_files([file1, file2])

    # Explicit file arguments come back unchanged, recursive or not.
    self.assertEqual(
        file_resources.GetCommandLineFiles(
            [file1, file2], recursive=False, exclude=None), [file1, file2])
    self.assertEqual(
        file_resources.GetCommandLineFiles(
            [file1, file2], recursive=True, exclude=None), [file1, file2])

  def test_nonrecursive_find_in_dir(self):
    tdir1 = self._make_test_dir('test1')
    tdir2 = self._make_test_dir('test1/foo')
    file1 = os.path.join(tdir1, 'testfile1.py')
    file2 = os.path.join(tdir2, 'testfile2.py')
    _touch_files([file1, file2])

    # A directory argument without recursive=True is an error.
    self.assertRaises(
        errors.YapfError,
        file_resources.GetCommandLineFiles,
        command_line_file_list=[tdir1],
        recursive=False,
        exclude=None)

  def test_recursive_find_in_dir(self):
    tdir1 = self._make_test_dir('test1')
    tdir2 = self._make_test_dir('test2/testinner/')
    tdir3 = self._make_test_dir('test3/foo/bar/bas/xxx')
    files = [
        os.path.join(tdir1, 'testfile1.py'),
        os.path.join(tdir2, 'testfile2.py'),
        os.path.join(tdir3, 'testfile3.py'),
    ]
    _touch_files(files)

    self.assertEqual(
        sorted(
            file_resources.GetCommandLineFiles(
                [self.test_tmpdir], recursive=True, exclude=None)),
        sorted(files))

  def test_recursive_find_in_dir_with_exclude(self):
    tdir1 = self._make_test_dir('test1')
    tdir2 = self._make_test_dir('test2/testinner/')
    tdir3 = self._make_test_dir('test3/foo/bar/bas/xxx')
    files = [
        os.path.join(tdir1, 'testfile1.py'),
        os.path.join(tdir2, 'testfile2.py'),
        os.path.join(tdir3, 'testfile3.py'),
    ]
    _touch_files(files)

    self.assertEqual(
        sorted(
            file_resources.GetCommandLineFiles(
                [self.test_tmpdir], recursive=True, exclude=['*test*3.py'])),
        sorted([
            os.path.join(tdir1, 'testfile1.py'),
            os.path.join(tdir2, 'testfile2.py'),
        ]))

  def test_find_with_excluded_hidden_dirs(self):
    tdir1 = self._make_test_dir('.test1')
    tdir2 = self._make_test_dir('test_2')
    tdir3 = self._make_test_dir('test.3')
    files = [
        os.path.join(tdir1, 'testfile1.py'),
        os.path.join(tdir2, 'testfile2.py'),
        os.path.join(tdir3, 'testfile3.py'),
    ]
    _touch_files(files)

    actual = file_resources.GetCommandLineFiles(
        [self.test_tmpdir], recursive=True, exclude=['*.test1*'])

    self.assertEqual(
        sorted(actual),
        sorted([
            os.path.join(tdir2, 'testfile2.py'),
            os.path.join(tdir3, 'testfile3.py'),
        ]))

  def test_find_with_excluded_hidden_dirs_relative(self):
    """Test find with excluded hidden dirs.

    A regression test against a specific case where a hidden directory (one
    beginning with a period) is being excluded, but it is also an immediate
    child of the current directory which has been specified in a relative
    manner.

    At its core, the bug has to do with overzealous stripping of "./foo" so
    that it removes too much from "./.foo" .
    """
    tdir1 = self._make_test_dir('.test1')
    tdir2 = self._make_test_dir('test_2')
    tdir3 = self._make_test_dir('test.3')
    files = [
        os.path.join(tdir1, 'testfile1.py'),
        os.path.join(tdir2, 'testfile2.py'),
        os.path.join(tdir3, 'testfile3.py'),
    ]
    _touch_files(files)

    # We must temporarily change the current directory, so that we test against
    # patterns like ./.test1/file instead of /tmp/foo/.test1/file
    with _restore_working_dir():

      os.chdir(self.test_tmpdir)
      actual = file_resources.GetCommandLineFiles(
          [os.path.relpath(self.test_tmpdir)],
          recursive=True,
          exclude=['*.test1*'])

      self.assertEqual(
          sorted(actual),
          sorted([
              os.path.join(
                  os.path.relpath(self.test_tmpdir), os.path.basename(tdir2),
                  'testfile2.py'),
              os.path.join(
                  os.path.relpath(self.test_tmpdir), os.path.basename(tdir3),
                  'testfile3.py'),
          ]))

  def test_find_with_excluded_dirs(self):
    tdir1 = self._make_test_dir('test1')
    tdir2 = self._make_test_dir('test2/testinner/')
    tdir3 = self._make_test_dir('test3/foo/bar/bas/xxx')
    files = [
        os.path.join(tdir1, 'testfile1.py'),
        os.path.join(tdir2, 'testfile2.py'),
        os.path.join(tdir3, 'testfile3.py'),
    ]
    _touch_files(files)

    # Relative arguments and patterns are resolved against test_tmpdir;
    # tearDown switches back to old_dir before deleting it.
    os.chdir(self.test_tmpdir)

    found = sorted(
        file_resources.GetCommandLineFiles(
            ['test1', 'test2', 'test3'],
            recursive=True,
            exclude=[
                'test1',
                'test2/testinner/',
            ]))

    self.assertEqual(found, ['test3/foo/bar/bas/xxx/testfile3.py'])

    found = sorted(
        file_resources.GetCommandLineFiles(
            ['.'], recursive=True, exclude=[
                'test1',
                'test3',
            ]))

    self.assertEqual(found, ['./test2/testinner/testfile2.py'])

  def test_find_with_excluded_current_dir(self):
    with self.assertRaises(errors.YapfError):
      file_resources.GetCommandLineFiles([], False, exclude=['./z'])
250
251
class IsPythonFileTest(unittest.TestCase):
  """Tests for file_resources.IsPythonFile."""

  def setUp(self):
    self.test_tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.test_tmpdir)

  def test_with_py_extension(self):
    # The file is never created; the '.py' name alone is expected to suffice.
    file1 = os.path.join(self.test_tmpdir, 'testfile1.py')
    self.assertTrue(file_resources.IsPythonFile(file1))

  def test_empty_without_py_extension(self):
    # Neither file is created, so there is no shebang line to inspect.
    file1 = os.path.join(self.test_tmpdir, 'testfile1')
    self.assertFalse(file_resources.IsPythonFile(file1))
    file2 = os.path.join(self.test_tmpdir, 'testfile1.rb')
    self.assertFalse(file_resources.IsPythonFile(file2))

  def test_python_shebang(self):
    file1 = os.path.join(self.test_tmpdir, 'testfile1')
    with open(file1, 'w') as f:
      f.write(u'#!/usr/bin/python\n')
    self.assertTrue(file_resources.IsPythonFile(file1))

    file2 = os.path.join(self.test_tmpdir, 'testfile2.run')
    with open(file2, 'w') as f:
      f.write(u'#! /bin/python2\n')
    # Check file2 itself: a shebang with a space after '#!' on a file whose
    # extension is not '.py'.
    self.assertTrue(file_resources.IsPythonFile(file2))

  def test_with_latin_encoding(self):
    file1 = os.path.join(self.test_tmpdir, 'testfile1')
    with py3compat.open_with_encoding(file1, mode='w', encoding='latin-1') as f:
      f.write(u'#! /bin/python2\n')
    self.assertTrue(file_resources.IsPythonFile(file1))

  def test_with_invalid_encoding(self):
    # A python shebang followed by an unknown coding cookie is rejected.
    file1 = os.path.join(self.test_tmpdir, 'testfile1')
    with open(file1, 'w') as f:
      f.write(u'#! /bin/python2\n')
      f.write(u'# -*- coding: iso-3-14159 -*-\n')
    self.assertFalse(file_resources.IsPythonFile(file1))
293
294
class IsIgnoredTest(unittest.TestCase):
  """Checks file_resources.IsIgnored against glob-style exclude patterns."""

  def test_root_path(self):
    # A bare name matches itself, but 'media/*' does not match 'media'.
    self.assertTrue(file_resources.IsIgnored('media', ['media']))
    self.assertFalse(file_resources.IsIgnored('media', ['media/*']))

  def test_sub_path(self):
    # (path, pattern) pairs that are all expected to be ignored.
    matching = (
        ('media/a', '*/a'),
        ('media/b', 'media/*'),
        ('media/b/c', '*/*/c'),
    )
    for path, pattern in matching:
      self.assertTrue(file_resources.IsIgnored(path, [pattern]))

  def test_trailing_slash(self):
    # A trailing slash on the pattern still matches the bare name.
    for pattern in ('z', 'z/'):
      self.assertTrue(file_resources.IsIgnored('z', [pattern]))
309
310
class BufferedByteStream(object):
  """In-memory stream that exposes a `buffer` attribute, as a text stream does.

  The tests below pass it to utils.stdout_redirector. getvalue() decodes
  whatever bytes were written through `buffer` as UTF-8.
  """

  def __init__(self):
    self.stream = py3compat.BytesIO()

  def getvalue(self):  # pylint: disable=invalid-name
    raw_bytes = self.stream.getvalue()
    return raw_bytes.decode('utf-8')

  @property
  def buffer(self):
    # Presumably the code under test writes encoded bytes here; confirm
    # against WriteReformattedCode if this changes.
    return self.stream
322
323
class WriteReformattedCodeTest(unittest.TestCase):
  """Checks where and how WriteReformattedCode emits its output."""

  @classmethod
  def setUpClass(cls):
    # One scratch directory is shared by every test in this class.
    cls.test_tmpdir = tempfile.mkdtemp()

  @classmethod
  def tearDownClass(cls):
    shutil.rmtree(cls.test_tmpdir)

  @staticmethod
  def _make_capture_stream():
    """Return an in-memory stream to redirect stdout into (PY3 vs PY2)."""
    if py3compat.PY3:
      return BufferedByteStream()
    return py3compat.StringIO()

  def _assert_written_to_stdout(self, text):
    """Write `text` to stdout via WriteReformattedCode and compare it."""
    capture = self._make_capture_stream()
    with utils.stdout_redirector(capture):
      file_resources.WriteReformattedCode(
          None, text, in_place=False, encoding='utf-8')
    self.assertEqual(capture.getvalue(), text)

  def test_write_to_file(self):
    content = u'foobar\n'
    with utils.NamedTempFile(dirname=self.test_tmpdir) as (f, fname):
      file_resources.WriteReformattedCode(
          fname, content, in_place=True, encoding='utf-8')
      f.flush()

      with open(fname) as written:
        self.assertEqual(written.read(), content)

  def test_write_to_stdout(self):
    self._assert_written_to_stdout(u'foobar')

  def test_write_encoded_to_stdout(self):
    text = '\ufeff# -*- coding: utf-8 -*-\nresult = "passed"\n'  # pylint: disable=anomalous-unicode-escape-in-string
    self._assert_written_to_stdout(text)
359
360
if __name__ == '__main__':
  # Run every TestCase defined in this module when executed as a script.
  unittest.main(module='__main__')
363