1#!/usr/bin/env python3
2
3"""A test case update script.
4
5This script is a utility to update LLVM 'llvm-mca' based test cases with new
6FileCheck patterns.
7"""
8
9import argparse
10from collections import defaultdict
11import glob
12import os
13import sys
14import warnings
15
16from UpdateTestChecks import common
17
18
19COMMENT_CHAR = '#'
20ADVERT_PREFIX = '{} NOTE: Assertions have been autogenerated by '.format(
21    COMMENT_CHAR)
22ADVERT = '{}utils/{}'.format(ADVERT_PREFIX, os.path.basename(__file__))
23
24
25class Error(Exception):
26  """ Generic Error that can be raised without printing a traceback.
27  """
28  pass
29
30
31def _warn(msg):
32  """ Log a user warning to stderr.
33  """
34  warnings.warn(msg, Warning, stacklevel=2)
35
36
37def _configure_warnings(args):
38  warnings.resetwarnings()
39  if args.w:
40    warnings.simplefilter('ignore')
41  if args.Werror:
42    warnings.simplefilter('error')
43
44
45def _showwarning(message, category, filename, lineno, file=None, line=None):
46  """ Version of warnings.showwarning that won't attempt to print out the
47      line at the location of the warning if the line text is not explicitly
48      specified.
49  """
50  if file is None:
51    file = sys.stderr
52  if line is None:
53    line = ''
54  file.write(warnings.formatwarning(message, category, filename, lineno, line))
55
56
57def _parse_args():
58  parser = argparse.ArgumentParser(description=__doc__)
59  parser.add_argument('-w',
60                      action='store_true',
61                      help='suppress warnings')
62  parser.add_argument('-Werror',
63                      action='store_true',
64                      help='promote warnings to errors')
65  parser.add_argument('--llvm-mca-binary',
66                      metavar='<path>',
67                      default='llvm-mca',
68                      help='the binary to use to generate the test case '
69                           '(default: llvm-mca)')
70  parser.add_argument('tests',
71                      metavar='<test-path>',
72                      nargs='+')
73  args = common.parse_commandline_args(parser)
74
75  _configure_warnings(args)
76
77  if not args.llvm_mca_binary:
78    raise Error('--llvm-mca-binary value cannot be empty string')
79
80  if 'llvm-mca' not in os.path.basename(args.llvm_mca_binary):
81    _warn('unexpected binary name: {}'.format(args.llvm_mca_binary))
82
83  return args
84
85
86def _get_run_infos(run_lines, args):
87  run_infos = []
88  for run_line in run_lines:
89    try:
90      (tool_cmd, filecheck_cmd) = tuple([cmd.strip()
91                                        for cmd in run_line.split('|', 1)])
92    except ValueError:
93      _warn('could not split tool and filecheck commands: {}'.format(run_line))
94      continue
95
96    common.verify_filecheck_prefixes(filecheck_cmd)
97    tool_basename = os.path.splitext(os.path.basename(args.llvm_mca_binary))[0]
98
99    if not tool_cmd.startswith(tool_basename + ' '):
100      _warn('skipping non-{} RUN line: {}'.format(tool_basename, run_line))
101      continue
102
103    if not filecheck_cmd.startswith('FileCheck '):
104      _warn('skipping non-FileCheck RUN line: {}'.format(run_line))
105      continue
106
107    tool_cmd_args = tool_cmd[len(tool_basename):].strip()
108    tool_cmd_args = tool_cmd_args.replace('< %s', '').replace('%s', '').strip()
109
110    check_prefixes = [item
111                      for m in common.CHECK_PREFIX_RE.finditer(filecheck_cmd)
112                      for item in m.group(1).split(',')]
113    if not check_prefixes:
114      check_prefixes = ['CHECK']
115
116    run_infos.append((check_prefixes, tool_cmd_args))
117
118  return run_infos
119
120
121def _break_down_block(block_info, common_prefix):
122  """ Given a block_info, see if we can analyze it further to let us break it
123      down by prefix per-line rather than per-block.
124  """
125  texts = block_info.keys()
126  prefixes = list(block_info.values())
127  # Split the lines from each of the incoming block_texts and zip them so that
128  # each element contains the corresponding lines from each text.  E.g.
129  #
130  # block_text_1: A   # line 1
131  #               B   # line 2
132  #
133  # block_text_2: A   # line 1
134  #               C   # line 2
135  #
136  # would become:
137  #
138  # [(A, A),   # line 1
139  #  (B, C)]   # line 2
140  #
141  line_tuples = list(zip(*list((text.splitlines() for text in texts))))
142
143  # To simplify output, we'll only proceed if the very first line of the block
144  # texts is common to each of them.
145  if len(set(line_tuples[0])) != 1:
146    return []
147
148  result = []
149  lresult = defaultdict(list)
150  for i, line in enumerate(line_tuples):
151    if len(set(line)) == 1:
152      # We're about to output a line with the common prefix.  This is a sync
153      # point so flush any batched-up lines one prefix at a time to the output
154      # first.
155      for prefix in sorted(lresult):
156        result.extend(lresult[prefix])
157      lresult = defaultdict(list)
158
159      # The line is common to each block so output with the common prefix.
160      result.append((common_prefix, line[0]))
161    else:
162      # The line is not common to each block, or we don't have a common prefix.
163      # If there are no prefixes available, warn and bail out.
164      if not prefixes[0]:
165        _warn('multiple lines not disambiguated by prefixes:\n{}\n'
166              'Some blocks may be skipped entirely as a result.'.format(
167                  '\n'.join('  - {}'.format(l) for l in line)))
168        return []
169
170      # Iterate through the line from each of the blocks and add the line with
171      # the corresponding prefix to the current batch of results so that we can
172      # later output them per-prefix.
173      for i, l in enumerate(line):
174        for prefix in prefixes[i]:
175          lresult[prefix].append((prefix, l))
176
177  # Flush any remaining batched-up lines one prefix at a time to the output.
178  for prefix in sorted(lresult):
179    result.extend(lresult[prefix])
180  return result
181
182
183def _get_useful_prefix_info(run_infos):
184  """ Given the run_infos, calculate any prefixes that are common to every one,
185      and the length of the longest prefix string.
186  """
187  try:
188    all_sets = [set(s) for s in list(zip(*run_infos))[0]]
189    common_to_all = set.intersection(*all_sets)
190    longest_prefix_len = max(len(p) for p in set.union(*all_sets))
191  except IndexError:
192    common_to_all = []
193    longest_prefix_len = 0
194  else:
195    if len(common_to_all) > 1:
196      _warn('Multiple prefixes common to all RUN lines: {}'.format(
197          common_to_all))
198    if common_to_all:
199      common_to_all = sorted(common_to_all)[0]
200  return common_to_all, longest_prefix_len
201
202
203def _align_matching_blocks(all_blocks, farthest_indexes):
204  """ Some sub-sequences of blocks may be common to multiple lists of blocks,
205      but at different indexes in each one.
206
207      For example, in the following case, A,B,E,F, and H are common to both
208      sets, but only A and B would be identified as such due to the indexes
209      matching:
210
211      index | 0 1 2 3 4 5 6
212      ------+--------------
213      setA  | A B C D E F H
214      setB  | A B E F G H
215
216      This function attempts to align the indexes of matching blocks by
217      inserting empty blocks into the block list. With this approach, A, B, E,
218      F, and H would now be able to be identified as matching blocks:
219
220      index | 0 1 2 3 4 5 6 7
221      ------+----------------
222      setA  | A B C D E F   H
223      setB  | A B     E F G H
224  """
225
226  # "Farthest block analysis": essentially, iterate over all blocks and find
227  # the highest index into a block list for the first instance of each block.
228  # This is relatively expensive, but we're dealing with small numbers of
229  # blocks so it doesn't make a perceivable difference to user time.
230  for blocks in all_blocks.values():
231    for block in blocks:
232      if not block:
233        continue
234
235      index = blocks.index(block)
236
237      if index > farthest_indexes[block]:
238        farthest_indexes[block] = index
239
240  # Use the results of the above analysis to identify any blocks that can be
241  # shunted along to match the farthest index value.
242  for blocks in all_blocks.values():
243    for index, block in enumerate(blocks):
244      if not block:
245        continue
246
247      changed = False
248      # If the block has not already been subject to alignment (i.e. if the
249      # previous block is not empty) then insert empty blocks until the index
250      # matches the farthest index identified for that block.
251      if (index > 0) and blocks[index - 1]:
252        while(index < farthest_indexes[block]):
253          blocks.insert(index, '')
254          index += 1
255          changed = True
256
257      if changed:
258        # Bail out.  We'll need to re-do the farthest block analysis now that
259        # we've inserted some blocks.
260        return True
261
262  return False
263
264
265def _get_block_infos(run_infos, test_path, args, common_prefix):  # noqa
266  """ For each run line, run the tool with the specified args and collect the
267      output. We use the concept of 'blocks' for uniquing, where a block is
268      a series of lines of text with no more than one newline character between
269      each one.  For example:
270
271      This
272      is
273      one
274      block
275
276      This is
277      another block
278
279      This is yet another block
280
281      We then build up a 'block_infos' structure containing a dict where the
282      text of each block is the key and a list of the sets of prefixes that may
283      generate that particular block.  This then goes through a series of
284      transformations to minimise the amount of CHECK lines that need to be
285      written by taking advantage of common prefixes.
286  """
287
288  def _block_key(tool_args, prefixes):
289    """ Get a hashable key based on the current tool_args and prefixes.
290    """
291    return ' '.join([tool_args] + prefixes)
292
293  all_blocks = {}
294  max_block_len = 0
295
296  # A cache of the furthest-back position in any block list of the first
297  # instance of each block, indexed by the block itself.
298  farthest_indexes = defaultdict(int)
299
300  # Run the tool for each run line to generate all of the blocks.
301  for prefixes, tool_args in run_infos:
302    key = _block_key(tool_args, prefixes)
303    raw_tool_output = common.invoke_tool(args.llvm_mca_binary,
304                                         tool_args,
305                                         test_path)
306
307    # Replace any lines consisting of purely whitespace with empty lines.
308    raw_tool_output = '\n'.join(line if line.strip() else ''
309                                for line in raw_tool_output.splitlines())
310
311    # Split blocks, stripping all trailing whitespace, but keeping preceding
312    # whitespace except for newlines so that columns will line up visually.
313    all_blocks[key] = [b.lstrip('\n').rstrip()
314                       for b in raw_tool_output.split('\n\n')]
315    max_block_len = max(max_block_len, len(all_blocks[key]))
316
317    # Attempt to align matching blocks until no more changes can be made.
318    made_changes = True
319    while made_changes:
320      made_changes = _align_matching_blocks(all_blocks, farthest_indexes)
321
322  # If necessary, pad the lists of blocks with empty blocks so that they are
323  # all the same length.
324  for key in all_blocks:
325    len_to_pad = max_block_len - len(all_blocks[key])
326    all_blocks[key] += [''] * len_to_pad
327
328  # Create the block_infos structure where it is a nested dict in the form of:
329  # block number -> block text -> list of prefix sets
330  block_infos = defaultdict(lambda: defaultdict(list))
331  for prefixes, tool_args in run_infos:
332    key = _block_key(tool_args, prefixes)
333    for block_num, block_text in enumerate(all_blocks[key]):
334      block_infos[block_num][block_text].append(set(prefixes))
335
336  # Now go through the block_infos structure and attempt to smartly prune the
337  # number of prefixes per block to the minimal set possible to output.
338  for block_num in range(len(block_infos)):
339    # When there are multiple block texts for a block num, remove any
340    # prefixes that are common to more than one of them.
341    # E.g. [ [{ALL,FOO}] , [{ALL,BAR}] ] -> [ [{FOO}] , [{BAR}] ]
342    all_sets = [s for s in block_infos[block_num].values()]
343    pruned_sets = []
344
345    for i, setlist in enumerate(all_sets):
346      other_set_values = set([elem for j, setlist2 in enumerate(all_sets)
347                              for set_ in setlist2 for elem in set_
348                              if i != j])
349      pruned_sets.append([s - other_set_values for s in setlist])
350
351    for i, block_text in enumerate(block_infos[block_num]):
352
353      # When a block text matches multiple sets of prefixes, try removing any
354      # prefixes that aren't common to all of them.
355      # E.g. [ {ALL,FOO} , {ALL,BAR} ] -> [{ALL}]
356      common_values = set.intersection(*pruned_sets[i])
357      if common_values:
358        pruned_sets[i] = [common_values]
359
360      # Everything should be uniqued as much as possible by now.  Apply the
361      # newly pruned sets to the block_infos structure.
362      # If there are any blocks of text that still match multiple prefixes,
363      # output a warning.
364      current_set = set()
365      for s in pruned_sets[i]:
366        s = sorted(list(s))
367        if s:
368          current_set.add(s[0])
369          if len(s) > 1:
370            _warn('Multiple prefixes generating same output: {} '
371                  '(discarding {})'.format(','.join(s), ','.join(s[1:])))
372
373      if block_text and not current_set:
374        raise Error(
375          'block not captured by existing prefixes:\n\n{}'.format(block_text))
376      block_infos[block_num][block_text] = sorted(list(current_set))
377
378    # If we have multiple block_texts, try to break them down further to avoid
379    # the case where we have very similar block_texts repeated after each
380    # other.
381    if common_prefix and len(block_infos[block_num]) > 1:
382      # We'll only attempt this if each of the block_texts have the same number
383      # of lines as each other.
384      same_num_Lines = (len(set(len(k.splitlines())
385                                for k in block_infos[block_num].keys())) == 1)
386      if same_num_Lines:
387        breakdown = _break_down_block(block_infos[block_num], common_prefix)
388        if breakdown:
389          block_infos[block_num] = breakdown
390
391  return block_infos
392
393
394def _write_block(output, block, not_prefix_set, common_prefix, prefix_pad):
395  """ Write an individual block, with correct padding on the prefixes.
396      Returns a set of all of the prefixes that it has written.
397  """
398  end_prefix = ':     '
399  previous_prefix = None
400  num_lines_of_prefix = 0
401  written_prefixes = set()
402
403  for prefix, line in block:
404    if prefix in not_prefix_set:
405      _warn('not writing for prefix {0} due to presence of "{0}-NOT:" '
406            'in input file.'.format(prefix))
407      continue
408
409    # If the previous line isn't already blank and we're writing more than one
410    # line for the current prefix output a blank line first, unless either the
411    # current of previous prefix is common to all.
412    num_lines_of_prefix += 1
413    if prefix != previous_prefix:
414      if output and output[-1]:
415        if num_lines_of_prefix > 1 or any(p == common_prefix
416                                          for p in (prefix, previous_prefix)):
417          output.append('')
418      num_lines_of_prefix = 0
419      previous_prefix = prefix
420
421    written_prefixes.add(prefix)
422    output.append(
423        '{} {}{}{} {}'.format(COMMENT_CHAR,
424                              prefix,
425                              end_prefix,
426                              ' ' * (prefix_pad - len(prefix)),
427                              line).rstrip())
428    end_prefix = '-NEXT:'
429
430  output.append('')
431  return written_prefixes
432
433
434def _write_output(test_path, input_lines, prefix_list, block_infos,  # noqa
435                  args, common_prefix, prefix_pad):
436  prefix_set = set([prefix for prefixes, _ in prefix_list
437                    for prefix in prefixes])
438  not_prefix_set = set()
439
440  output_lines = []
441  for input_line in input_lines:
442    if input_line.startswith(ADVERT_PREFIX):
443      continue
444
445    if input_line.startswith(COMMENT_CHAR):
446      m = common.CHECK_RE.match(input_line)
447      try:
448        prefix = m.group(1)
449      except AttributeError:
450        prefix = None
451
452      if '{}-NOT:'.format(prefix) in input_line:
453        not_prefix_set.add(prefix)
454
455      if prefix not in prefix_set or prefix in not_prefix_set:
456        output_lines.append(input_line)
457        continue
458
459    if common.should_add_line_to_output(input_line, prefix_set):
460      # This input line of the function body will go as-is into the output.
461      # Except make leading whitespace uniform: 2 spaces.
462      input_line = common.SCRUB_LEADING_WHITESPACE_RE.sub(r'  ', input_line)
463
464      # Skip empty lines if the previous output line is also empty.
465      if input_line or output_lines[-1]:
466        output_lines.append(input_line)
467    else:
468      continue
469
470  # Add a blank line before the new checks if required.
471  if len(output_lines) > 0 and output_lines[-1]:
472    output_lines.append('')
473
474  output_check_lines = []
475  used_prefixes = set()
476  for block_num in range(len(block_infos)):
477    if type(block_infos[block_num]) is list:
478      # The block is of the type output from _break_down_block().
479      used_prefixes |= _write_block(output_check_lines,
480                                    block_infos[block_num],
481                                    not_prefix_set,
482                                    common_prefix,
483                                    prefix_pad)
484    else:
485      # _break_down_block() was unable to do do anything so output the block
486      # as-is.
487
488      # Rather than writing out each block as soon we encounter it, save it
489      # indexed by prefix so that we can write all of the blocks out sorted by
490      # prefix at the end.
491      output_blocks = defaultdict(list)
492
493      for block_text in sorted(block_infos[block_num]):
494
495        if not block_text:
496          continue
497
498        lines = block_text.split('\n')
499        for prefix in block_infos[block_num][block_text]:
500          assert prefix not in output_blocks
501          used_prefixes |= _write_block(output_blocks[prefix],
502                                        [(prefix, line) for line in lines],
503                                        not_prefix_set,
504                                        common_prefix,
505                                        prefix_pad)
506
507      for prefix in sorted(output_blocks):
508        output_check_lines.extend(output_blocks[prefix])
509
510  unused_prefixes = (prefix_set - not_prefix_set) - used_prefixes
511  if unused_prefixes:
512    raise Error('unused prefixes: {}'.format(sorted(unused_prefixes)))
513
514  if output_check_lines:
515    output_lines.insert(0, ADVERT)
516    output_lines.extend(output_check_lines)
517
518  # The file should not end with two newlines. It creates unnecessary churn.
519  while len(output_lines) > 0 and output_lines[-1] == '':
520    output_lines.pop()
521
522  if input_lines == output_lines:
523    sys.stderr.write('            [unchanged]\n')
524    return
525  sys.stderr.write('      [{} lines total]\n'.format(len(output_lines)))
526
527  common.debug('Writing', len(output_lines), 'lines to', test_path, '..\n\n')
528
529  with open(test_path, 'wb') as f:
530    f.writelines(['{}\n'.format(l).encode('utf-8') for l in output_lines])
531
532def main():
533  args = _parse_args()
534  test_paths = [test for pattern in args.tests for test in glob.glob(pattern)]
535  for test_path in test_paths:
536    sys.stderr.write('Test: {}\n'.format(test_path))
537
538    # Call this per test. By default each warning will only be written once
539    # per source location. Reset the warning filter so that now each warning
540    # will be written once per source location per test.
541    _configure_warnings(args)
542
543    if not os.path.isfile(test_path):
544      raise Error('could not find test file: {}'.format(test_path))
545
546    with open(test_path) as f:
547      input_lines = [l.rstrip() for l in f]
548
549    run_lines = common.find_run_lines(test_path, input_lines)
550    run_infos = _get_run_infos(run_lines, args)
551    common_prefix, prefix_pad = _get_useful_prefix_info(run_infos)
552    block_infos = _get_block_infos(run_infos, test_path, args, common_prefix)
553    _write_output(test_path,
554                  input_lines,
555                  run_infos,
556                  block_infos,
557                  args,
558                  common_prefix,
559                  prefix_pad)
560
561  return 0
562
563
564if __name__ == '__main__':
565  try:
566    warnings.showwarning = _showwarning
567    sys.exit(main())
568  except Error as e:
569    sys.stdout.write('error: {}\n'.format(e))
570    sys.exit(1)
571