1#!/usr/bin/env python3
2"""A script to generate FileCheck statements for mlir unit tests.
3
4This script is a utility to add FileCheck patterns to an mlir file.
5
6NOTE: The input .mlir is expected to be the output from the parser, not a
7stripped down variant.
8
9Example usage:
10$ generate-test-checks.py foo.mlir
11$ mlir-opt foo.mlir -transformation | generate-test-checks.py
12$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
13$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
14$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
15
The script will heuristically generate CHECK/CHECK-LABEL/CHECK-SAME commands for
each line within the file. It also inserts string substitution blocks for all
SSA value names. If a --source file is specified, the script will attempt to
insert the generated CHECKs into the source file, placing each chunk of CHECKs
before a line matched by --source_delim_regex.
21
22The script is designed to make adding checks to a test case fast, it is *not*
23designed to be authoritative about what constitutes a good test!
24"""
25
26# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27# See https://llvm.org/LICENSE.txt for license information.
28# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
29
30import argparse
31import os  # Used to advertise this file's name ("autogenerated_note").
32import re
33import sys
34
# Leading text of the note line that heads every generated output.
ADVERT = '// NOTE: Assertions have been autogenerated by '

# Pattern for the SSA value name that follows a '%': either a run of decimal
# digits, or an identifier of letters, digits and '$._-' whose first character
# is not a digit.
SSA_RE_STR = r'[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)
40
41
42# Class used to generate and manage string substitution blocks for SSA value
43# names.
44class SSAVariableNamer:
45
46  def __init__(self):
47    self.scopes = []
48    self.name_counter = 0
49
50  # Generate a substitution name for the given ssa value name.
51  def generate_name(self, ssa_name):
52    variable = 'VAL_' + str(self.name_counter)
53    self.name_counter += 1
54    self.scopes[-1][ssa_name] = variable
55    return variable
56
57  # Push a new variable name scope.
58  def push_name_scope(self):
59    self.scopes.append({})
60
61  # Pop the last variable name scope.
62  def pop_name_scope(self):
63    self.scopes.pop()
64
65  # Return the level of nesting (number of pushed scopes).
66  def num_scopes(self):
67    return len(self.scopes)
68
69  # Reset the counter.
70  def clear_counter(self):
71    self.name_counter = 0
72
73
def process_line(line_chunks, variable_namer):
  """Rewrite the SSA value names in `line_chunks` as FileCheck variables.

  Args:
    line_chunks: the pieces of a line that followed each '%', i.e.
      `line.split('%')[1:]`.
    variable_namer: an SSAVariableNamer holding the names bound so far.

  Returns:
    The rewritten text with trailing whitespace removed and a single newline
    appended. A name already bound in some open scope becomes a use,
    '%[[VAL_n]]'. An unbound name becomes a definition, '%[[VAL_n:.*]]', and
    is bound in the innermost scope. A chunk that does not begin with an SSA
    name (e.g. a '%' inside a string attribute, or a trailing '%') is kept
    verbatim together with its '%'.
  """
  output_line = ''

  for chunk in line_chunks:
    m = SSA_RE.match(chunk)
    if m is None:
      # Not an SSA reference: keep the literal '%' and the text after it
      # instead of failing on `m.group`.
      output_line += '%' + chunk
      continue
    ssa_name = m.group(0)

    # Look the name up in the open scopes, outermost first.
    variable = None
    for scope in variable_namer.scopes:
      variable = scope.get(ssa_name)
      if variable is not None:
        break

    if variable is not None:
      # Already bound: emit a use of the existing variable.
      output_line += '%[[' + variable + ']]'
    else:
      # First occurrence: bind a new variable that captures the value name.
      variable = variable_namer.generate_name(ssa_name)
      output_line += '%[[' + variable + ':.*]]'

    # Append the text that follows the SSA name unchanged.
    output_line += chunk[len(ssa_name):]

  return output_line.rstrip() + '\n'
102
103
def process_source_lines(source_lines, note, args):
  """Split the --source file into segments, dropping previous generated output.

  The source file does not have to be .mlir.

  Args:
    source_lines: lines of the source file, without trailing newlines.
    note: the autogenerated note line; lines equal to it are removed.
    args: parsed arguments; `check_prefix` and `source_delim_regex` are used.

  Returns:
    A list of segments, each a list of newline-terminated lines. Segment 0
    holds the lines before the first delimiter; every line matched by
    `args.source_delim_regex` starts a new segment.
  """
  source_split_re = re.compile(args.source_delim_regex)
  # Only '//' comment lines whose directive starts with the prefix (e.g.
  # '// CHECK:', '// CHECK-LABEL:') are previous checks. A plain substring
  # test would also drop RUN lines such as
  # '// RUN: ... | FileCheck %s --check-prefix=CHECK' and lose them on -i.
  check_line_re = re.compile(
      r'^\s*//\s*' + re.escape(args.check_prefix) + r'[:-]')

  source_segments = [[]]
  for line in source_lines:
    # Remove the previous note and previous CHECK lines.
    if line == note or check_line_re.match(line):
      continue
    # Segment the file based on --source_delim_regex.
    if source_split_re.search(line):
      source_segments.append([])

    source_segments[-1].append(line + '\n')
  return source_segments
122
123
def preprocess_line(line):
  """Escape character sequences in `line` that FileCheck would misread.

  '[[' opens a FileCheck variable, so it is escaped. A '[' directly before
  '%' would become '[[' once the SSA name is replaced by a '[[VAL_n]]'
  substitution, so it is escaped too. The replacements run in this order.
  """
  escapes = (
      ('[[', '{{\\[\\[}}'),
      ('[%', '{{\\[}}%'),
  )
  escaped = line
  for needle, replacement in escapes:
    escaped = escaped.replace(needle, replacement)
  return escaped
137
138
def _generate_check_segments(input_lines, args):
  """Translate IR lines into FileCheck directives.

  Args:
    input_lines: the IR to check, one line per entry, without newlines.
    args: parsed arguments; `check_prefix` and `starts_from_scope` are used.

  Returns:
    A list of segments, each a list of check strings (one string may hold
    several newline-terminated directives). A new segment starts whenever a
    line ending in '{' is seen at nesting level `args.starts_from_scope`.
    Lines nested less deeply than that level are omitted.
  """
  output_segments = [[]]
  # Tracks the substitution names bound to SSA values in each open region.
  variable_namer = SSAVariableNamer()
  for input_line in input_lines:
    if not input_line:
      continue
    lstripped_input_line = input_line.lstrip()

    # Lines with blocks begin with a ^. These lines have a trailing comment
    # that needs to be stripped.
    is_block = lstripped_input_line[0] == '^'
    if is_block:
      input_line = input_line.rsplit('//', 1)[0].rstrip()

    cur_level = variable_namer.num_scopes()

    # If the line starts with a '}', pop the last name scope.
    if lstripped_input_line[0] == '}':
      variable_namer.pop_name_scope()
      cur_level = variable_namer.num_scopes()

    # If the line ends with a '{', push a new name scope.
    if input_line[-1] == '{':
      variable_namer.push_name_scope()
      if cur_level == args.starts_from_scope:
        output_segments.append([])

    # Omit lines at the near top level e.g. "module {".
    if cur_level < args.starts_from_scope:
      continue

    # Restart VAL_ numbering at the start of each segment.
    if not output_segments[-1]:
      variable_namer.clear_counter()

    # Preprocess the input to remove any sequences that may be problematic
    # with FileCheck.
    input_line = preprocess_line(input_line)

    # Split the line at each SSA value name.
    ssa_split = input_line.split('%')

    # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
    if output_segments[-1] or not ssa_split[0]:
      output_line = '// ' + args.check_prefix + ': '
      # Pad to align with the 'LABEL' statements.
      output_line += ' ' * len('-LABEL')
      # Output the first line chunk that does not contain an SSA name.
      output_line += ssa_split[0]
      # Process the rest of the input line.
      output_line += process_line(ssa_split[1:], variable_namer)
    else:
      # The text before the first SSA name becomes the label.
      output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n'
      # Each SSA-name chunk goes on its own CHECK-SAME line, padded by the
      # label's length.
      for argument in ssa_split[1:]:
        output_line += '// ' + args.check_prefix + '-SAME:  '
        output_line += ' ' * len(ssa_split[0])
        output_line += process_line([argument], variable_namer)

    output_segments[-1].append(output_line)
  return output_segments


def _write_output(output, note, output_segments, source_segments):
  """Write `note` and then the check segments to the stream `output`.

  With `source_segments`, each check segment is written directly before the
  source segment at the same index. Otherwise every check segment is
  preceded by an empty line, and a final empty line is written.
  """
  output.write(note + '\n')
  if source_segments:
    for check_segment, source_segment in zip(output_segments, source_segments):
      output.writelines(check_segment)
      output.writelines(source_segment)
  else:
    for segment in output_segments:
      output.write('\n')
      output.writelines(segment)
    output.write('\n')


def main():
  """Parse the command line, generate the checks and write them out."""
  parser = argparse.ArgumentParser(
      description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
  parser.add_argument(
      '--check-prefix', default='CHECK', help='Prefix to use from check file.')
  parser.add_argument(
      '-o',
      '--output',
      nargs='?',
      type=argparse.FileType('w'),
      default=None)
  parser.add_argument(
      'input',
      nargs='?',
      type=argparse.FileType('r'),
      default=sys.stdin)
  parser.add_argument(
      '--source', type=str,
      help='Print each CHECK chunk before each delimiter line in the source '
           'file, respectively. The delimiter lines are identified by '
           '--source_delim_regex.')
  parser.add_argument(
      '--source_delim_regex', type=str, default='func @',
      help='Regex matching the delimiter lines of the --source file.')
  parser.add_argument(
      '--starts_from_scope', type=int, default=1,
      help='Omit the top specified level of content. For example, by default '
           'it omits "module {"')
  parser.add_argument(
      '-i', '--inplace', action='store_true', default=False,
      help='Write the result back into the --source file.')

  args = parser.parse_args()

  # Validate flag combinations before any file is rewritten.
  if args.inplace and not args.source:
    parser.error('-i/--inplace requires --source')
  if args.inplace and args.output is not None:
    parser.error('-i/--inplace cannot be combined with -o/--output')

  # Read the IR to generate checks for.
  input_lines = [l.rstrip() for l in args.input]
  args.input.close()

  # Generate a note used for the generated check file.
  script_name = os.path.basename(__file__)
  autogenerated_note = (ADVERT + 'utils/' + script_name)

  source_segments = None
  if args.source:
    with open(args.source, 'r') as source_file:
      source_lines = [l.rstrip() for l in source_file]
    source_segments = process_source_lines(
        source_lines, autogenerated_note, args)

  output_segments = _generate_check_segments(input_lines, args)

  # Check this before opening the output. With --inplace, opening the source
  # file for writing truncates it, so failing afterwards would lose it.
  if source_segments and len(output_segments) != len(source_segments):
    sys.exit(
        'error: generated %d CHECK segments but the source file has %d '
        'segments; check --source_delim_regex and --starts_from_scope' %
        (len(output_segments), len(source_segments)))

  if args.inplace:
    output = open(args.source, 'w')
  elif args.output is None:
    output = sys.stdout
  else:
    output = args.output

  try:
    _write_output(output, autogenerated_note, output_segments, source_segments)
  finally:
    # Close the output file, but leave sys.stdout open.
    if output is not sys.stdout:
      output.close()
281
282
# Run main() only when executed as a script, not when imported as a module.
if __name__ == '__main__':
  main()
285