1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import absolute_import
16
17import collections
18import multiprocessing
19import os
20import select
21import signal
22import sys
23import tempfile
24import threading
25import time
26import unittest
27import uuid
28
29import six
30from six import moves
31
32from tests import _loader
33from tests import _result
34
35
class CaptureFile(object):
    """A context-managed file to redirect output to a byte array.

  Use by invoking `start` (`__enter__`) and at some point invoking `stop`
  (`__exit__`). At any point after the initial call to `start` call `output` to
  get the current redirected output. Note that we don't currently use file
  locking, so calling `output` between calls to `start` and `stop` may muddle
  the result (you should only be doing this during a Python-handled interrupt as
  a last ditch effort to provide output to the user).

  Attributes:
    _redirected_fd (int): File descriptor of file to redirect writes from.
    _saved_fd (int or None): A copy of the original value of the redirected
      file descriptor; None once `close` has released it.
    _into_file (TemporaryFile or None): File to which writes are redirected.
      None until the first `start`. It is kept open after `stop` so that
      `output` still works, and is replaced by the next `start` or released
      by `close`.
  """

    def __init__(self, fd):
        self._redirected_fd = fd
        # Duplicate the original target so `stop` can restore it and
        # `write_bypass` can reach it while redirection is active.
        self._saved_fd = os.dup(self._redirected_fd)
        self._into_file = None

    def output(self):
        """Get all output from the redirected-to file if it exists.

    Returns:
      bytes: The contents of the capture file created by the most recent
        `start`, or empty bytes if there is no capture file.
    """
        if self._into_file is None:
            return bytes()
        self._into_file.seek(0)
        return bytes(self._into_file.read())

    def start(self):
        """Start redirection of writes to the file descriptor."""
        # Close the capture file from a previous start/stop cycle explicitly
        # instead of relying on garbage collection to release its descriptor.
        if self._into_file is not None:
            self._into_file.close()
            self._into_file = None
        self._into_file = tempfile.TemporaryFile()
        os.dup2(self._into_file.fileno(), self._redirected_fd)

    def stop(self):
        """Stop redirection of writes to the file descriptor."""
        # n.b. dup2 closes whatever self._redirected_fd currently refers to
        # (the capture file duplicate) before making it a copy of _saved_fd.
        os.dup2(self._saved_fd, self._redirected_fd)

    def write_bypass(self, value):
        """Bypass the redirection and write directly to the original file.

    Arguments:
      value (str or bytes): What to write to the original file. Text is
        encoded as UTF-8 before writing.
    """
        if not isinstance(value, bytes):
            value = value.encode('utf-8')
        # Prefer the saved duplicate of the original target; fall back to the
        # redirected descriptor itself only once `close` has released it.
        if self._saved_fd is None:
            target_fd = self._redirected_fd
        else:
            target_fd = self._saved_fd
        # os.write may write fewer bytes than requested; loop until done.
        while value:
            written = os.write(target_fd, value)
            value = value[written:]

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()

    def close(self):
        """Close any resources used by self not closed by stop().

    Safe to call more than once. After closing, `output` returns empty bytes.
    """
        if self._saved_fd is not None:
            os.close(self._saved_fd)
            self._saved_fd = None
        if self._into_file is not None:
            self._into_file.close()
            self._into_file = None
100
101
class AugmentedCase(collections.namedtuple('AugmentedCase', ['case', 'id'])):
    """A test case with a guaranteed unique externally specified identifier.

  Attributes:
    case (unittest.TestCase): TestCase we're decorating with an additional
      identifier.
    id (object): Any identifier that may be considered 'unique' for testing
      purposes. A fresh `uuid.uuid4()` is used when none is supplied.
  """

    def __new__(cls, case, id=None):
        if id is None:
            id = uuid.uuid4()
        # super() takes (this class, cls). Passing them the other way round
        # raises TypeError whenever cls is a subclass of AugmentedCase.
        return super(AugmentedCase, cls).__new__(cls, case, id)
116
117
class Runner(object):
    """Runs a unittest suite with per-test output capture and crash reporting.

  Each selected test case runs on its own thread while the process-level
  stdout/stderr file descriptors are redirected into temporary files, and the
  captured output is attached to that test's result. Progress is printed as
  the run goes, and a JUnit-style report is written to `report.xml` at the end.
  """

    def __init__(self):
        # Substrings of test ids that must not be run; see `skip_tests`.
        self._skipped_tests = []

    def skip_tests(self, tests):
        """Skip every test whose id contains any of the given substrings.

    Arguments:
      tests (list of str): Substrings matched against each test case's
        `id()`. Replaces any previously configured list. It is iterated once
        per test case, so it must be re-iterable (not a one-shot generator).
    """
        self._skipped_tests = tests

    def run(self, suite):
        """See setuptools' test_runner setup argument for information."""
        # only run test cases with id starting with given prefix
        testcase_filter = os.getenv('GRPC_PYTHON_TESTRUNNER_FILTER')
        filtered_cases = [
            case for case in _loader.iterate_suite_cases(suite)
            if not testcase_filter or case.id().startswith(testcase_filter)
        ]

        # Ensure that every test case has no collision with any other test case in
        # the augmented results.
        augmented_cases = [
            AugmentedCase(case, uuid.uuid4()) for case in filtered_cases
        ]
        case_id_by_case = {
            augmented_case.case: augmented_case.id
            for augmented_case in augmented_cases
        }
        result_out = moves.cStringIO()
        result = _result.TerminalResult(
            result_out, id_map=lambda case: case_id_by_case[case])
        stdout_pipe = CaptureFile(sys.stdout.fileno())
        stderr_pipe = CaptureFile(sys.stderr.fileno())
        # A one-element list so nested handlers can mutate it (Python 2.7 has
        # no `nonlocal`).
        kill_flag = [False]

        def as_text(data):
            # Captured output is bytes. Decode it (replacing undecodable bytes)
            # and escape non-ASCII characters so it can be interpolated into a
            # str message without raising and without showing a b'...' repr.
            return data.decode('utf-8', 'replace').encode(
                'ascii', 'backslashreplace').decode('ascii')

        def sigint_handler(signal_number, frame):
            if signal_number == signal.SIGINT:
                kill_flag[0] = True
            # A second delivery of the same signal uses the default action.
            signal.signal(signal_number, signal.SIG_DFL)

        def fault_handler(signal_number, frame):
            stdout_pipe.write_bypass(
                'Received fault signal {}\nstdout:\n{}\n\nstderr:{}\n'.format(
                    signal_number, as_text(stdout_pipe.output()),
                    as_text(stderr_pipe.output())))
            os._exit(1)

        def check_kill_self():
            # Dump whatever has been collected so far and exit immediately if
            # sigint_handler has flagged an interrupt.
            if kill_flag[0]:
                stdout_pipe.write_bypass('Stopping tests short...')
                result.stopTestRun()
                stdout_pipe.write_bypass(result_out.getvalue())
                stdout_pipe.write_bypass('\ninterrupted stdout:\n{}\n'.format(
                    as_text(stdout_pipe.output())))
                stderr_pipe.write_bypass('\ninterrupted stderr:\n{}\n'.format(
                    as_text(stderr_pipe.output())))
                os._exit(1)

        def try_set_handler(name, handler):
            # Not every signal name exists on every platform; skip missing ones.
            try:
                signal.signal(getattr(signal, name), handler)
            except AttributeError:
                pass

        try_set_handler('SIGINT', sigint_handler)
        try_set_handler('SIGSEGV', fault_handler)
        try_set_handler('SIGBUS', fault_handler)
        try_set_handler('SIGABRT', fault_handler)
        try_set_handler('SIGFPE', fault_handler)
        try_set_handler('SIGILL', fault_handler)
        # Sometimes output will lag after a test has successfully finished; we
        # ignore such writes to our pipes.
        try_set_handler('SIGPIPE', signal.SIG_IGN)

        # Run the tests
        result.startTestRun()
        try:
            for augmented_case in augmented_cases:
                case_id = augmented_case.case.id()
                if any(skipped_test in case_id
                       for skipped_test in self._skipped_tests):
                    continue
                sys.stdout.write('Running       {}\n'.format(case_id))
                sys.stdout.flush()
                case_thread = threading.Thread(
                    target=augmented_case.case.run, args=(result,))
                with stdout_pipe, stderr_pipe:
                    case_thread.start()
                    # Poll the kill flag while the test thread is running so
                    # an interrupt is acted on before the test finishes.
                    while case_thread.is_alive():
                        check_kill_self()
                        time.sleep(0)
                    case_thread.join()
                result.set_output(augmented_case.case, stdout_pipe.output(),
                                  stderr_pipe.output())
                sys.stdout.write(result_out.getvalue())
                sys.stdout.flush()
                # truncate() does not move the stream position; rewind first so
                # later writes start at offset 0 instead of being NUL-padded.
                result_out.seek(0)
                result_out.truncate(0)
                check_kill_self()
            result.stopTestRun()
        finally:
            # Release the duplicated descriptors even if running a test raises.
            stdout_pipe.close()
            stderr_pipe.close()

        # Report results
        sys.stdout.write(result_out.getvalue())
        sys.stdout.flush()
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        with open('report.xml', 'wb') as report_xml_file:
            _result.jenkins_junit_xml(result).write(report_xml_file)
        return result
226