1# Copyright 2014 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import fnmatch
16import shlex
17import unittest
18
19
def convert_newlines(msg):
    """Return `msg` with every '\\r\\n' and lone '\\r' turned into '\\n'.

    This imitates the line-ending translation Python performs when
    universal_newlines is enabled.
    """
    # Collapse the two-character CRLF sequences first; otherwise the '\r'
    # pass below would turn each of them into two newlines.
    without_crlf = '\n'.join(msg.split('\r\n'))
    return '\n'.join(without_crlf.split('\r'))
23
24
class TestCase(unittest.TestCase):
    """unittest.TestCase variant carrying extra class-level attributes.

    The `child` and `context` attributes default to None here;
    presumably the test runner assigns them before the tests run --
    TODO confirm against the runner.
    """
    # Raises unittest's cap on the length of diffs shown in assertion
    # failure messages (80 * 66 characters).
    maxDiff = 80 * 66

    # MainTestCase.make_host() reads `child.host`, and check() reads
    # `child.debugger`; nothing in this file assigns it.
    child = None

    # Not read anywhere in this file; kept for whoever sets/uses it.
    context = None
29
30
class MainTestCase(TestCase):
    """Base class for tests that run a command inside a scratch directory.

    check() is the main entry point: it creates a temporary directory via
    a host object, optionally populates it with files, runs a command,
    collects the exit code / stdout / stderr / resulting files, cleans up,
    and then compares the results with whatever expectations were given.

    All filesystem and process access goes through the `host` object
    (by default `self.child.host`).
    """

    # Command prefix (a list of argv entries) prepended to every command
    # run by check(), unless check() is given its own `prog`.
    prog = None

    # fnmatch-style patterns; files whose path matches any of them are
    # left out of the mapping returned by _read_files().
    files_to_ignore = []

    def _write_files(self, host, files):
        """Write every `path -> contents` entry of `files` through `host`.

        A directory component in `path` is created first (via
        host.maybe_mkdir) when host.dirname() reports one.
        """
        for path, contents in files.items():
            dirname = host.dirname(path)
            if dirname:
                host.maybe_mkdir(dirname)
            host.write_text_file(path, contents)

    def _read_files(self, host, tmpdir):
        """Return a dict of the text files found under `tmpdir`.

        Keys are the paths reported by host.files_under() with host.sep
        replaced by '/', so expectations can be written with forward
        slashes. Paths matching any pattern in self.files_to_ignore are
        skipped.
        """
        out_files = {}
        for f in host.files_under(tmpdir):
            if any(fnmatch.fnmatch(f, pat) for pat in self.files_to_ignore):
                continue
            key = f.replace(host.sep, '/')
            out_files[key] = host.read_text_file(tmpdir, f)
        return out_files

    def assert_files(self, expected_files, actual_files, files_to_ignore=None):
        """Assert that `actual_files` matches `expected_files`.

        Args:
            expected_files: dict mapping path -> expected text contents.
            actual_files: dict mapping path -> actual text contents.
            files_to_ignore: optional iterable of exact keys of
                `actual_files` that are disregarded when comparing the
                two sets of file names.

        Raises:
            AssertionError: if an expected file is missing, its contents
                differ, or an unexpected (non-ignored) file is present.
        """
        files_to_ignore = files_to_ignore or []
        for path, expected_contents in expected_files.items():
            # Compare against the *actual* contents; checking the key first
            # turns a missing file into a readable assertion failure rather
            # than a KeyError.
            self.assertIn(path, actual_files)
            self.assertMultiLineEqual(expected_contents, actual_files[path])
        interesting_files = set(actual_files.keys()).difference(
            files_to_ignore)
        self.assertEqual(interesting_files, set(expected_files.keys()))

    def _assert_regex(self, text, regex):
        """Assert that `regex` matches somewhere in `text`.

        unittest's assertRegexpMatches alias was removed in Python 3.12,
        while Python 2.7 only provides that old name, so use whichever
        spelling this interpreter has.
        """
        assert_regex = getattr(self, 'assertRegex', None)
        if assert_regex is None:  # pragma: no cover
            assert_regex = self.assertRegexpMatches
        assert_regex(text, regex)

    def make_host(self):
        """Return the host object of the current child (`self.child.host`)."""
        # If we are ever called by unittest directly, and not through typ,
        # this will probably fail.
        assert self.child
        return self.child.host

    def call(self, host, argv, stdin, env):
        """Run `argv` through host.call() and return its result.

        check() unpacks the result as a (returncode, stdout, stderr)
        triple. Kept as a separate method so it can be overridden.
        """
        return host.call(argv, stdin=stdin, env=env)

    def check(self, cmd=None, stdin=None, env=None, aenv=None, files=None,
              prog=None, cwd=None, host=None,
              ret=None, out=None, rout=None, err=None, rerr=None,
              exp_files=None,
              files_to_ignore=None, universal_newlines=True):
        """Run a command in a fresh temporary directory and verify results.

        Args:
            cmd: the command arguments, either a string (split with
                shlex.split) or a list; None means no extra arguments.
            stdin: forwarded to call() as the command's stdin.
            env: forwarded to call() as the command's environment.
            aenv: optional dict of additional variables; when given, the
                environment used is a copy of host.env updated with these
                entries (this replaces any `env` argument).
            files: optional dict of path -> contents written into the
                temporary directory before the command runs.
            prog: argv prefix list; defaults to self.prog, then to [].
            cwd: optional directory to chdir into after the files have
                been written (the current directory at that point is the
                temporary directory).
            host: host object to use; defaults to self.make_host().
            ret: expected return code (checked when not None).
            out: expected exact stdout (checked when not None).
            rout: regex that stdout must match (checked when not None).
            err: expected exact stderr (checked when not None).
            rerr: regex that stderr must match (checked when not None).
            exp_files: expected files, path -> contents. Only checked
                when truthy, so an empty dict skips the file comparison.
            files_to_ignore: forwarded to assert_files().
            universal_newlines: when true, '\\r\\n' and '\\r' in the
                captured stdout and stderr are converted to '\\n' before
                any comparison.

        Returns:
            A tuple (actual_ret, actual_out, actual_err, actual_files).

        Raises:
            AssertionError: if any supplied expectation is not met.
        """
        # Too many arguments pylint: disable=R0913
        prog = prog or self.prog or []
        host = host or self.make_host()
        argv = shlex.split(cmd) if isinstance(cmd, str) else cmd or []

        tmpdir = None
        orig_wd = host.getcwd()
        try:
            tmpdir = host.mkdtemp()
            host.chdir(tmpdir)
            if files:
                self._write_files(host, files)
            if cwd:
                host.chdir(cwd)
            if aenv:
                env = host.env.copy()
                env.update(aenv)

            if self.child.debugger:  # pragma: no cover
                # Show how to reproduce the invocation by hand, then drop
                # into pdb just before the command is run.
                host.print_('')
                host.print_('cd %s' % tmpdir, stream=host.stdout.stream)
                host.print_(' '.join(prog + argv), stream=host.stdout.stream)
                host.print_('')
                import pdb
                dbg = pdb.Pdb(stdout=host.stdout.stream)
                dbg.set_trace()

            result = self.call(host, prog + argv, stdin=stdin, env=env)

            actual_ret, actual_out, actual_err = result
            # The files must be read before the finally clause removes
            # the temporary directory.
            actual_files = self._read_files(host, tmpdir)
        finally:
            # Always restore the working directory and clean up, even if
            # the command or the setup raised.
            host.chdir(orig_wd)
            if tmpdir:
                host.rmtree(tmpdir)

        if universal_newlines:
            actual_out = convert_newlines(actual_out)
            actual_err = convert_newlines(actual_err)

        if ret is not None:
            self.assertEqual(ret, actual_ret)
        if out is not None:
            self.assertMultiLineEqual(out, actual_out)
        if rout is not None:
            self._assert_regex(actual_out, rout)
        if err is not None:
            self.assertMultiLineEqual(err, actual_err)
        if rerr is not None:
            self._assert_regex(actual_err, rerr)
        if exp_files:
            self.assert_files(exp_files, actual_files, files_to_ignore)

        return actual_ret, actual_out, actual_err, actual_files
128