1#!/usr/bin/env python
2# Copyright 2020 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16from primes import next_prime
17import xngen
18import xnncommon
19
20
# Command-line interface: both the YAML spec to read and the C++ file to
# write are mandatory.
parser = argparse.ArgumentParser(
    description='GAvgPool microkernel test generator')
parser.add_argument(
    "-s", "--spec", metavar="FILE", required=True,
    help="Specification (YAML) file")
parser.add_argument(
    "-o", "--output", metavar="FILE", required=True,
    help='Output (C++ source) file')
# No flag above populates `defines`; it only exists as an empty default.
parser.set_defaults(defines=[])
27
28
def split_ukernel_name(name):
  """Decodes tiling and target information from a GAvgPool kernel name.

  Args:
    name: C name of the micro-kernel, shaped like
      xnn_<datatype>_[p]gavgpool[_minmax]_ukernel_[<P>p]<I>x__<target>_c<C>[_acc<N>]

  Returns:
    Tuple (primary_tile, incremental_tile, channel_tile, arch, isa). When the
    name has no "<P>p" prefix the kernel is single-pass: the "<I>" number is
    the primary tile and incremental_tile is 0.

  Raises:
    ValueError: if `name` does not match the expected naming scheme.
  """
  m = re.match(r"^xnn_(qs8|qu8|f16|f32)_[p]?gavgpool(_(minmax))?_ukernel_((\d+)p)?(\d+)x__(.+)_c(\d+)(_acc(\d+))?$", name)
  if not m:
    raise ValueError("Unexpected microkernel name: " + name)

  prefix, prefix_rows, rows, target, channels = m.group(4, 5, 6, 7, 8)
  if prefix:
    # "<P>p<I>x": P rows in the first pass, I rows in each later pass.
    primary_tile, incremental_tile = int(prefix_rows), int(rows)
  else:
    # "<I>x" only: a single pass over up to I rows.
    primary_tile, incremental_tile = int(rows), 0

  channel_tile = int(channels)
  arch, isa = xnncommon.parse_target_name(target_name=target)
  return primary_tile, incremental_tile, channel_tile, arch, isa
44
45
# xngen template for the generated gtest cases. generate_test_cases() renders
# it with TEST_NAME, TEST_ARGS, DATATYPE, PRIMARY_TILE, INCREMENTAL_TILE,
# CHANNEL_TILE, ISA_CHECK and next_prime in scope.
# INCREMENTAL_TILE == 0 selects the single-pass test set. Any other value
# selects the 2-pass / multipass test set. Every multipass row loop walks
# rows = PRIMARY_TILE+INCREMENTAL_TILE .. INCREMENTAL_TILE*5 (inclusive) in
# steps of INCREMENTAL_TILE, so each channel class (eq/div/lt/gt) exercises
# the same set of pass counts.
AVGPOOL_TEST_TEMPLATE = """\
$if INCREMENTAL_TILE == 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .input_stride(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile_with_qmax) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .qmax(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile_with_qmin) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .qmin(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
$else:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .input_stride(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_subtile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .input_stride(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .input_stride(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .input_stride(${next_prime(CHANNEL_TILE*16+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .qmax(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .qmin(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .input_stride(${next_prime(CHANNEL_TILE+1)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .input_stride(${next_prime(CHANNEL_TILE*2+11)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

"""
507
508
def generate_test_cases(ukernel, primary_tile, incremental_tile, channel_tile,
                        isa):
  """Generates all test cases for a GAVGPOOL micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    primary_tile: Number of rows (pixels) processed per one iteration of the
                  primary outer loop of the micro-kernel.
    incremental_tile: Number of rows (pixels) processed per one iteration of
                      the incremental outer loop of the micro-kernel. A value
                      of 0 selects the single-pass branch of
                      AVGPOOL_TEST_TEMPLATE.
    channel_tile: Number of channels processed per one iteration of the inner
                  loops of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
         will skip execution if the host processor doesn't support this ISA.

  Returns:
    C++ source code of the generated test cases for this micro-kernel.
  """
  # Drop the first "_"-separated field (the "xnn" prefix); the remainder
  # becomes the gtest suite name.
  _, test_name = ukernel.split("_", 1)
  # The second "_"-separated field is the datatype (e.g. "f32").
  _, datatype, _, _ = ukernel.split("_", 3)
  test_args = [ukernel]
  # Without an ISA requirement, pass the Scalar tester variant as an extra
  # argument to Test().
  if not isa:
    test_args.append("GAvgPoolMicrokernelTester::Variant::Scalar")
  return xngen.preprocess(AVGPOOL_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "DATATYPE": datatype,
      "PRIMARY_TILE": primary_tile,
      "INCREMENTAL_TILE": incremental_tile,
      "CHANNEL_TILE": channel_tile,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
    })
542
543
def main(args):
  """Renders gtest sources for every micro-kernel listed in a YAML spec.

  Args:
    args: command-line arguments, excluding the program name.

  Raises:
    ValueError: if the spec's top level is not a list, or a listed kernel
      name is not recognized by split_ukernel_name().
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gavgpool.h>
#include "gavgpool-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # The header comes first, then each kernel's tests; neighbouring chunks
  # are separated by a blank line.
  chunks = [header]
  for entry in spec_yaml:
    kernel_name = entry["name"]
    primary_tile, incremental_tile, channel_tile, arch, isa = (
        split_ukernel_name(kernel_name))
    # An explicit "arch" in the spec entry wins over the one parsed from the
    # kernel name.
    arch = entry.get("arch", arch)
    rendered = generate_test_cases(kernel_name, primary_tile,
                                   incremental_tile, channel_tile, isa)
    chunks.append(xnncommon.postprocess_test_case(rendered, arch, isa))

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write("\n\n".join(chunks))
589
590
if __name__ == "__main__":
  # Hand every CLI argument except the program name to main().
  main(args=sys.argv[1:])
593