1#!/usr/bin/env python2
2# Copyright 2013 Google Inc. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15import cPickle
16import errno
17import gzip
18import multiprocessing
19import optparse
20import os
21import signal
22import subprocess
23import sys
24import tempfile
25import thread
26import threading
27import time
28import zlib
29
30# An object that catches SIGINT sent to the Python process and notices
31# if processes passed to wait() die by SIGINT (we need to look for
32# both of those cases, because pressing Ctrl+C can result in either
33# the main process or one of the subprocesses getting the signal).
34#
35# Before a SIGINT is seen, wait(p) will simply call p.wait() and
36# return the result. Once a SIGINT has been seen (in the main process
37# or a subprocess, including the one the current call is waiting for),
38# wait(p) will call p.terminate() and raise ProcessWasInterrupted.
class SigintHandler(object):
  """Tracks SIGINT across the main process and the subprocesses it waits on.

  Constructing an instance installs a SIGINT handler for the Python
  process. wait(p) behaves like p.wait() until a SIGINT has been seen
  (either delivered to this process or detected from a subprocess return
  code); from then on it terminates processes and raises
  ProcessWasInterrupted.
  """

  class ProcessWasInterrupted(Exception):
    """Raised by wait() once a SIGINT has been observed."""

  # Return codes reported by a subprocess that died because of SIGINT.
  sigint_returncodes = {
      -signal.SIGINT,  # Unix
      -1073741510,     # Windows
  }

  def __init__(self):
    self.__lock = threading.Lock()
    self.__processes = set()  # Processes currently being waited on.
    self.__got_sigint = False
    signal.signal(signal.SIGINT, self.__sigint_handler)

  def __on_sigint(self):
    # Both callers hold self.__lock while calling this.
    self.__got_sigint = True
    while self.__processes:
      victim = self.__processes.pop()
      try:
        victim.terminate()
      except OSError:
        pass

  def __sigint_handler(self, signal_num, frame):
    with self.__lock:
      self.__on_sigint()

  def got_sigint(self):
    """Return True if a SIGINT has been seen so far."""
    with self.__lock:
      return self.__got_sigint

  def wait(self, p):
    """Wait for process p and return its return code.

    Raises ProcessWasInterrupted if a SIGINT was seen before or while
    waiting, or if p's return code shows that p itself died by SIGINT.
    """
    with self.__lock:
      if self.__got_sigint:
        p.terminate()
      self.__processes.add(p)
    returncode = p.wait()
    with self.__lock:
      self.__processes.discard(p)
      if returncode in self.sigint_returncodes:
        self.__on_sigint()
      if self.__got_sigint:
        raise self.ProcessWasInterrupted
    return returncode

sigint_handler = SigintHandler()
76
77# Return the width of the terminal, or None if it couldn't be
78# determined (e.g. because we're not being run interactively).
def term_width(out):
  """Return the terminal width in columns, or None if it is unknown.

  None is returned when `out` is not a tty, when "stty size" cannot be
  run or reports an error, or when its output cannot be parsed.
  """
  if out.isatty():
    try:
      stty = subprocess.Popen(["stty", "size"],
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE)
      (size_text, error_text) = stty.communicate()
      if stty.returncode == 0 and not error_text:
        # The second whitespace-separated field is used as the width.
        return int(size_text.split()[1])
    except (IndexError, OSError, ValueError):
      pass
  return None
91
92# Output transient and permanent lines of text. If several transient
93# lines are written in sequence, the new will overwrite the old. We
94# use this to ensure that lots of unimportant info (tests passing)
95# won't drown out important info (tests failing).
class Outputter(object):
  """Writes transient (overwritable) and permanent lines to a file object.

  On a terminal, consecutive transient lines overwrite each other; when
  the width is unknown (not a tty) every line is simply written on its
  own line.
  """

  def __init__(self, out_file):
    self.__out_file = out_file
    self.__previous_line_was_transient = False
    self.__width = term_width(out_file)  # Line width, or None if not a tty.

  def transient_line(self, msg):
    """Write msg so that the next line written may overwrite it."""
    if self.__width is not None:
      # Truncate/pad to exactly one terminal row so that leftovers of a
      # longer previous line are blanked out.
      padded = msg[:self.__width].ljust(self.__width)
      self.__out_file.write("\r" + padded)
      self.__previous_line_was_transient = True
      return
    self.__out_file.write(msg + "\n")

  def flush_transient_output(self):
    """Terminate a pending transient line so that it is not overwritten."""
    if not self.__previous_line_was_transient:
      return
    self.__out_file.write("\n")
    self.__previous_line_was_transient = False

  def permanent_line(self, msg):
    """Write msg on a line of its own that will never be overwritten."""
    self.flush_transient_output()
    self.__out_file.write(msg + "\n")
114
115stdout_lock = threading.Lock()
116
117class FilterFormat:
118  if sys.stdout.isatty():
119    # stdout needs to be unbuffered since the output is interactive.
120    sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)
121
122  out = Outputter(sys.stdout)
123  total_tests = 0
124  finished_tests = 0
125
126  tests = {}
127  outputs = {}
128  failures = []
129
130  def print_test_status(self, last_finished_test, time_ms):
131    self.out.transient_line("[%d/%d] %s (%d ms)"
132                            % (self.finished_tests, self.total_tests,
133                               last_finished_test, time_ms))
134
135  def handle_meta(self, job_id, args):
136    (command, arg) = args.split(' ', 1)
137    if command == "TEST":
138      (binary, test) = arg.split(' ', 1)
139      self.tests[job_id] = (binary, test.strip())
140    elif command == "EXIT":
141      (exit_code, time_ms) = [int(x) for x in arg.split(' ', 1)]
142      self.finished_tests += 1
143      (binary, test) = self.tests[job_id]
144      self.print_test_status(test, time_ms)
145      if exit_code != 0:
146        self.failures.append(self.tests[job_id])
147        with open(self.outputs[job_id]) as f:
148          for line in f.readlines():
149            self.out.permanent_line(line.rstrip())
150        self.out.permanent_line(
151          "[%d/%d] %s returned/aborted with exit code %d (%d ms)"
152          % (self.finished_tests, self.total_tests, test, exit_code, time_ms))
153    elif command == "TESTCNT":
154      self.total_tests = int(arg.split(' ', 1)[1])
155      self.out.transient_line("[0/%d] Running tests..." % self.total_tests)
156
157  def logfile(self, job_id, name):
158    self.outputs[job_id] = name
159
160  def log(self, line):
161    stdout_lock.acquire()
162    (prefix, output) = line.split(' ', 1)
163
164    assert prefix[-1] == ':'
165    self.handle_meta(int(prefix[:-1]), output)
166    stdout_lock.release()
167
168  def end(self):
169    if self.failures:
170      self.out.permanent_line("FAILED TESTS (%d/%d):"
171                              % (len(self.failures), self.total_tests))
172      for (binary, test) in self.failures:
173        self.out.permanent_line(" " + binary + ": " + test)
174    self.out.flush_transient_output()
175
class RawFormat:
  """Logger that prints every meta line and all test output unfiltered."""

  def log(self, line):
    """Write one line to stdout and flush it, serialized across threads."""
    # `with` guarantees the lock is released even if the write fails
    # (e.g. stdout is a closed pipe), so other workers cannot deadlock.
    with stdout_lock:
      sys.stdout.write(line + "\n")
      sys.stdout.flush()

  def logfile(self, job_id, name):
    """Print the contents of the file `name`, each line prefixed by the job id.

    The file to read is the `name` argument; this class keeps no
    job id -> file mapping of its own.
    """
    with open(name) as f:
      for line in f:
        self.log(str(job_id) + '> ' + line.rstrip())

  def end(self):
    """Nothing to summarize in raw mode."""
    pass
188
189# Record of test runtimes. Has built-in locking.
class TestTimes(object):
  """Maps (test binary, test name) to the last recorded runtime.

  Runtimes are integers in milliseconds, or None for tests that failed.
  The record can be loaded from and saved to a gzip-compressed pickle.
  """

  def __init__(self, save_file):
    "Create new object seeded with saved test times from the given file."
    self.__times = {}  # (test binary, test name) -> runtime in ms

    # Protects calls to record_test_time(); other calls are not
    # expected to be made concurrently.
    self.__lock = threading.Lock()

    # NOTE: unpickling can execute arbitrary code, so save_file must only
    # ever be a file this script wrote itself.
    try:
      with gzip.GzipFile(save_file, "rb") as f:
        times = cPickle.load(f)
    except (AttributeError, EOFError, ImportError, IndexError, IOError,
            KeyError, TypeError, ValueError, cPickle.UnpicklingError,
            zlib.error):
      # File doesn't exist, isn't readable, is malformed---whatever.
      # Just ignore it. (Corrupt pickles can raise many different
      # exception types, hence the long list.)
      return

    # Discard saved times if the format isn't right.
    if self.__has_expected_format(times):
      self.__times = times

  @staticmethod
  def __has_expected_format(times):
    "Return True if times is a dict of (str, str) -> int, long or None."
    if type(times) is not dict:
      return False
    for (key, runtime) in times.items():
      # Check the key's shape before unpacking it, so that a malformed
      # file is discarded instead of crashing the script.
      if type(key) is not tuple or len(key) != 2:
        return False
      (test_binary, test_name) = key
      if (type(test_binary) is not str or type(test_name) is not str
          or type(runtime) not in (int, long, type(None))):
        return False
    return True

  def get_test_time(self, binary, testname):
    """Return the last duration for the given test as an integer number of
    milliseconds, or None if the test failed or if there's no record for it."""
    return self.__times.get((binary, testname), None)

  def record_test_time(self, binary, testname, runtime_ms):
    """Record that the given test ran in the specified number of
    milliseconds. If the test failed, runtime_ms should be None."""
    with self.__lock:
      self.__times[(binary, testname)] = runtime_ms

  def write_to_file(self, save_file):
    "Write all the times to file."
    try:
      with open(save_file, "wb") as f:
        with gzip.GzipFile("", "wb", 9, f) as gzf:
          cPickle.dump(self.__times, gzf, cPickle.HIGHEST_PROTOCOL)
    except IOError:
      pass  # ignore errors---saving the times isn't that important
236
# Split the command line at the first "--": everything after it is kept
# in additional_args and everything from "--" onwards is removed from
# sys.argv so that the option parser never sees it.
additional_args = []

if '--' in sys.argv:
  separator = sys.argv.index('--')
  additional_args = sys.argv[separator + 1:]
  sys.argv = sys.argv[:separator]
245
# Command-line options. Everything that is not an option is taken to be
# a test binary.
parser = optparse.OptionParser(
    usage='usage: %prog [options] binary [binary ...] -- [additional args]')

default_output_dir = os.path.join(tempfile.gettempdir(), "gtest-parallel")
default_workers = multiprocessing.cpu_count()

parser.add_option('-d', '--output_dir',
                  type='string', default=default_output_dir,
                  help='output directory for test logs')
parser.add_option('-r', '--repeat',
                  type='int', default=1,
                  help='repeat tests')
parser.add_option('-w', '--workers',
                  type='int', default=default_workers,
                  help='number of workers to spawn')
parser.add_option('--gtest_color',
                  type='string', default='yes',
                  help='color output')
parser.add_option('--gtest_filter',
                  type='string', default='',
                  help='test filter')
parser.add_option('--gtest_also_run_disabled_tests',
                  action='store_true', default=False,
                  help='run disabled tests too')
parser.add_option('--format',
                  type='string', default='filter',
                  help='output format (raw,filter)')
parser.add_option('--print_test_times',
                  action='store_true', default=False,
                  help='When done, list the run time of each test')

(options, binaries) = parser.parse_args()

# At least one test binary is required.
if not binaries:
  parser.print_usage()
  sys.exit(1)

# Pick the logger that implements the requested output format.
if options.format == 'filter':
  logger = FilterFormat()
elif options.format == 'raw':
  logger = RawFormat()
else:
  sys.exit("Unknown output format: " + options.format)
281
# Find tests: ask every binary for its test list and collect one entry
# (last known runtime, binary, full test name, command) per test.
save_file = os.path.join(os.path.expanduser("~"), ".gtest-parallel-times")
times = TestTimes(save_file)
tests = []
for test_binary in binaries:
  command = [test_binary]
  if options.gtest_also_run_disabled_tests:
    command += ['--gtest_also_run_disabled_tests']

  # The user's filter is only needed for listing; each test is later run
  # with a --gtest_filter of its own.
  list_command = list(command)
  if options.gtest_filter != '':
    list_command += ['--gtest_filter=' + options.gtest_filter]

  try:
    test_list = subprocess.Popen(list_command + ['--gtest_list_tests'],
                                 stdout=subprocess.PIPE).communicate()[0]
  except OSError as e:
    sys.exit("%s: %s" % (test_binary, str(e)))

  command += additional_args

  # In the listing, group ("TestCase.") lines start in the first column
  # and the tests of the group follow on indented lines.
  test_group = ''
  for line in test_list.split('\n'):
    if not line.strip():
      continue
    if line[0] != " ":
      # Remove comments for typed tests ("Group/0.  # TypeParam = int")
      # and strip whitespace; otherwise the comment would end up in the
      # test name and in the filter used to run the test.
      test_group = line.split('#')[0].strip()
      continue
    # Remove comments for parameterized tests and strip whitespace.
    line = line.split('#')[0].strip()
    if not line:
      continue

    test = test_group + line
    if not options.gtest_also_run_disabled_tests and 'DISABLED_' in test:
      continue
    tests.append((times.get_test_time(test_binary, test),
                  test_binary, test, command))
320
# Order the tests so that the slowest ones come first. Tests without a
# recorded runtime (None: new or previously failing tests) sort before
# everything else, i.e. they count as larger than any real runtime.
tests.sort(key=lambda entry: ((0 if entry[0] is not None else 1), entry),
           reverse=True)

# Repeat tests (-r flag).
tests *= options.repeat

# Shared state of the workers: test_lock guards job_id, the index of the
# next test to hand out.
test_lock = threading.Lock()
job_id = 0

# Announce the total number of tests to the logger.
logger.log("%d: TESTCNT  %d" % (-1, len(tests)))

exit_code = 0
333
# Create directory for test log output.
try:
  os.makedirs(options.output_dir)
except OSError as e:
  # Ignore errors if this directory already exists. A bare raise keeps
  # the original traceback.
  if e.errno != errno.EEXIST or not os.path.isdir(options.output_dir):
    raise
# Remove files from old test runs. Real subdirectories are left alone:
# os.remove() cannot delete them and would abort the whole run.
for logfile in os.listdir(options.output_dir):
  stale_path = os.path.join(options.output_dir, logfile)
  if os.path.isdir(stale_path) and not os.path.islink(stale_path):
    continue
  os.remove(stale_path)
344
345# Run the specified job. Return the elapsed time in milliseconds if
346# the job succeeds, or None if the job fails. (This ensures that
347# failing tests will run first the next time.)
348def run_job((command, job_id, test)):
349  begin = time.time()
350
351  with tempfile.NamedTemporaryFile(dir=options.output_dir, delete=False) as log:
352    sub = subprocess.Popen(command + ['--gtest_filter=' + test] +
353                             ['--gtest_color=' + options.gtest_color],
354                           stdout=log.file,
355                           stderr=log.file)
356    try:
357      code = sigint_handler.wait(sub)
358    except sigint_handler.ProcessWasInterrupted:
359      thread.exit()
360    runtime_ms = int(1000 * (time.time() - begin))
361    logger.logfile(job_id, log.name)
362
363  logger.log("%s: EXIT %s %d" % (job_id, code, runtime_ms))
364  if code == 0:
365    return runtime_ms
366  global exit_code
367  exit_code = code
368  return None
369
def worker():
  """Thread body: run tests from the shared list until none are left.

  The index of the next test is the global job_id, guarded by test_lock.
  Each finished test's runtime (None on failure) is recorded in `times`.
  """
  global job_id
  while True:
    job = None
    # `with` releases the lock even if logging the TEST line fails;
    # otherwise every other worker would block forever.
    with test_lock:
      if job_id < len(tests):
        (_, test_binary, test, command) = tests[job_id]
        logger.log(str(job_id) + ': TEST ' + test_binary + ' ' + test)
        job = (command, job_id, test)
      job_id += 1
    if job is None:
      return
    times.record_test_time(test_binary, test, run_job(job))
384
def start_daemon(func):
  """Run func in a newly started daemon thread and return that thread."""
  daemon_thread = threading.Thread(target=func)
  # Daemon threads do not keep the interpreter alive at exit.
  daemon_thread.daemon = True
  daemon_thread.start()
  return daemon_thread
390
391workers = [start_daemon(worker) for i in range(options.workers)]
392
393[t.join() for t in workers]
394logger.end()
395times.write_to_file(save_file)
396if options.print_test_times:
397  ts = sorted((times.get_test_time(test_binary, test), test_binary, test)
398              for (_, test_binary, test, _) in tests
399              if times.get_test_time(test_binary, test) is not None)
400  for (time_ms, test_binary, test) in ts:
401    print "%8s %s" % ("%dms" % time_ms, test)
402sys.exit(-signal.SIGINT if sigint_handler.got_sigint() else exit_code)
403