1#!/usr/bin/python3
2
3# Copyright 2018, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""CTS testcase generator
18
Implements the CTS test backend. Invoked by ml/nn/runtime/test/specs/generate_tests.sh;
see that script for details on how this script is used.
21
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27import argparse
28import math
29import os
30import re
31import sys
32import traceback
33
34# Stuff from test generator
35import test_generator as tg
36from test_generator import ActivationConverter
37from test_generator import BoolScalar
38from test_generator import Configuration
39from test_generator import DataTypeConverter
40from test_generator import DataLayoutConverter
41from test_generator import Example
42from test_generator import Float16Scalar
43from test_generator import Float32Scalar
44from test_generator import Float32Vector
45from test_generator import GetJointStr
46from test_generator import IgnoredOutput
47from test_generator import Input
48from test_generator import Int32Scalar
49from test_generator import Int32Vector
50from test_generator import Internal
51from test_generator import Model
52from test_generator import Operand
53from test_generator import Output
54from test_generator import Parameter
55from test_generator import ParameterAsInputConverter
56from test_generator import RelaxedModeConverter
57from test_generator import SmartOpen
58from test_generator import SymmPerChannelQuantParams
59
def IndentedPrint(s, indent=2, *args, **kwargs):
    """Print every line of `s` prefixed with `indent` spaces.

    Empty lines receive the prefix too. Any extra positional or keyword
    arguments (typically `file=`) are forwarded unchanged to `print`.
    """
    pad = " " * indent
    print('\n'.join(pad + line for line in s.split('\n')), *args, **kwargs)
62
# Parse the command line and initialize the generator's input/output file lists
def ParseCmdLine():
    """Parse command-line arguments and configure the test generator.

    The positional `spec` names the spec file/directory. The model, example,
    test and CTS outputs default to "-" (presumably stdout, as interpreted by
    tg.FileNames / SmartOpen -- not visible here). The resulting paths are
    handed to tg.FileNames.InitializeFileLists, and `--force` is stored in
    Configuration.force_regenerate.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file/directory")
    # (short flag, long flag, help) for each output that defaults to "-".
    outputOptions = (
        ("-m", "--model", "the output model file/directory"),
        ("-e", "--example", "the output example file/directory"),
        ("-t", "--test", "the output test file/directory"),
        ("-c", "--cts", "the CTS TestGeneratedOneFile.cpp"),
    )
    for shortFlag, longFlag, helpText in outputOptions:
        parser.add_argument(shortFlag, longFlag, help=helpText, default="-")
    parser.add_argument(
        "-f", "--force", help="force to regenerate all spec files", action="store_true")
    # Optional log file, used by the slicing tool.
    parser.add_argument("-l", "--log", help="the optional log file", default="")

    parsed = parser.parse_args()
    tg.FileNames.InitializeFileLists(
        parsed.spec, parsed.model, parsed.example, parsed.test, parsed.cts, parsed.log)
    Configuration.force_regenerate = parsed.force
84
def NeedRegenerate():
    """Decide whether the current spec's generated files must be rebuilt.

    Returns True if the model, example or test output file is missing, or if
    any of them is not strictly newer than the spec file's mtime plus 10
    seconds. Returns False only when all three outputs pass that check.
    """
    outputs = (tg.FileNames.modelFile, tg.FileNames.exampleFile, tg.FileNames.testFile)
    for path in outputs:
        if not os.path.exists(path):
            return True
    # NOTE(review): the reason for the 10 s margin is not documented here --
    # presumably slack for coarse filesystem timestamps; confirm.
    threshold = os.path.getmtime(tg.FileNames.specFile) + 10
    outputTimes = [os.path.getmtime(path) for path in outputs]
    return any(t <= threshold for t in outputTimes)
96
# Write the headers of the generated files: boilerplate code that depends only on the file names
def InitializeFiles(model_fd, example_fd, test_fd):
    """Write the boilerplate header of each generated output.

    All three outputs receive a "clang-format off" / "Generated file" banner
    naming the spec file. The test output additionally gets an include of
    TestGenerated.h and a namespace (named after the spec) that #includes the
    generated example and model files.
    """
    specFileBase = os.path.basename(tg.FileNames.specFile)
    banner = "// clang-format off\n// Generated file (from: {spec_file}). Do not edit".format(
        spec_file=specFileBase)
    for fd in (model_fd, example_fd, test_fd):
        print(banner, file=fd)

    # Removes the source-tree prefix so the #include paths below are relative.
    pathRegex = r".*((frameworks/ml/nn/(runtime/test/)?)|(vendor/google/[a-z]*/test/))"

    def relativeInclude(path):
        return re.sub(pathRegex, "", path)

    testFileHeader = """\
#include "../../TestGenerated.h"\n
namespace {spec_name} {{
// Generated {spec_name} test
#include "{example_file}"
// Generated model constructor
#include "{model_file}"
}} // namespace {spec_name}\n"""
    print(testFileHeader.format(
        spec_name=tg.FileNames.specName,
        example_file=relativeInclude(tg.FileNames.exampleFile),
        model_file=relativeInclude(tg.FileNames.modelFile)), file=test_fd)
118
119# Dump is_ignored function for IgnoredOutput
def DumpCtsIsIgnored(model, model_fd):
    """Emit the C++ predicate that reports which output indices are ignored.

    The generated inline function is named `model.isIgnoredFunctionName`. It
    returns true for each index in `model.GetIgnoredOutputs()`, i.e. the
    outputs declared through IgnoredOutput.
    """
    ignoredIndices = tg.GetJointStr(model.GetIgnoredOutputs(), method=lambda x: str(x.index))
    functionName = str(model.isIgnoredFunctionName)
    template = """\
inline bool {is_ignored_name}(int i) {{
  static std::set<int> ignore = {{{ignored_index}}};
  return ignore.find(i) != ignore.end();\n}}\n"""
    rendered = template.format(ignored_index=ignoredIndices, is_ignored_name=functionName)
    print(rendered, file=model_fd)
128
129# Dump Model file for Cts tests
def DumpCtsModel(model, model_fd):
    """Emit the C++ function that constructs `model` for the CTS tests.

    The generated `void <createFunctionName>(Model *model)` does the following,
    in order:
    - declares every operand type;
    - adds the operands;
    - sets constant parameter values;
    - adds the operations;
    - identifies the inputs and outputs;
    - enables relaxed float32->float16 computation, when `model.isRelaxed`;
    - asserts that the model is valid.

    The is-ignored predicate is emitted immediately after the function.

    The model must already be compiled (asserted). Each model is dumped at
    most once: `model.dumped` is checked on entry and set at the end.
    """
    assert model.compiled
    if model.dumped:
        return

    def emit(code):
        # Body line(s) of the generated function, indented by 2 spaces.
        IndentedPrint(code, file=model_fd)

    def operandTypeDecl(t):
        # Plain types take (type, dims). Quantized types add scale and
        # zeroPoint. Types with visible extraParams also get its constructor.
        if t.scale == 0.0 and t.zeroPoint == 0 and t.extraParams is None:
            return "OperandType %s(Type::%s, %s);" % (t, t.type, t.GetDimensionsString())
        if t.extraParams is not None and not t.extraParams.hide:
            return "OperandType %s(Type::%s, %s, %s, %d, %s);" % (
                t, t.type, t.GetDimensionsString(), tg.PrettyPrintAsFloat(t.scale),
                t.zeroPoint, t.extraParams.GetConstructor())
        return "OperandType %s(Type::%s, %s, %s, %d);" % (
            t, t.type, t.GetDimensionsString(), tg.PrettyPrintAsFloat(t.scale), t.zeroPoint)

    print("void %s(Model *model) {" % (model.createFunctionName), file=model_fd)

    # Phase 0: operand type declarations.
    for operandType in model.GetTypes():
        emit(operandTypeDecl(operandType))

    # Phase 1: one addOperand call per operand.
    emit("// Phase 1, operands")
    for operand in model.operands:
        emit("auto %s = model->addOperand(&%s);" % (operand, operand.type))

    # Phase 2: constant parameter values first, then the operations.
    emit("// Phase 2, operations")
    for param in model.GetParameters():
        cppType = param.type.GetCppTypeString()
        emit("static %s %s[] = %s;\nmodel->setOperandValue(%s, %s, sizeof(%s) * %d);" % (
            cppType, param.initializer, param.GetListInitialization(), param,
            param.initializer, cppType, param.type.GetNumberOfElements()))
    for operation in model.operations:
        emit("model->addOperation(ANEURALNETWORKS_%s, {%s}, {%s});" % (
            operation.optype, tg.GetJointStr(operation.ins), tg.GetJointStr(operation.outs)))

    # Phase 3: declare the model's inputs and outputs.
    emit("// Phase 3, inputs and outputs")
    emit("model->identifyInputsAndOutputs(\n  {%s},\n  {%s});" % (
        tg.GetJointStr(model.GetInputs()), tg.GetJointStr(model.GetOutputs())))

    # Phase 4 (optional): allow float32 computation at float16 precision.
    if model.isRelaxed:
        emit("// Phase 4: set relaxed execution")
        emit("model->relaxComputationFloat32toFloat16(true);")

    emit("assert(model->isValid());")
    print("}\n", file=model_fd)
    DumpCtsIsIgnored(model, model_fd)
    model.dumped = True
181
def DumpMixedType(operands, feedDict):
    """Render operand values as a C++ `MixedTyped` aggregate initializer.

    Each operand is fed from `feedDict` and rendered through its
    GetListInitialization(). The result is grouped under the MixedTyped field
    matching the operand's tensor type. Every operand also adds an
    `{index, {dims}}` entry to `operandDimensions`. TENSOR_OEM_BYTE values
    share the `quant8AsymmOperands` field with TENSOR_QUANT8_ASYMM.

    Args:
      operands: operands to dump (e.g. a model's inputs or outputs).
      feedDict: values passed to each operand's Feed().

    Returns:
      The initializer text. Field names follow
      tools/test_generator/include/TestHarness.h:MixedTyped.

    Exits the process via sys.exit if an operand's tensor type is not
    supported.
    """
    supportedTensors = [
        "DIMENSIONS",
        "TENSOR_FLOAT32",
        "TENSOR_INT32",
        "TENSOR_QUANT8_ASYMM",
        "TENSOR_OEM_BYTE",
        "TENSOR_QUANT16_SYMM",
        "TENSOR_FLOAT16",
        "TENSOR_BOOL8",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL",
        "TENSOR_QUANT16_ASYMM",
        "TENSOR_QUANT8_SYMM",
    ]
    typedMap = {t: [] for t in supportedTensors}
    # Group the rendered operand values by tensor type.
    for operand in operands:
        tensorType = operand.type.type
        # Only the type lookup is guarded. A KeyError raised while feeding or
        # rendering the operand must surface with its own traceback, not be
        # misreported as an unsupported tensor type.
        try:
            bucket = typedMap[tensorType]
        except KeyError:
            traceback.print_exc()
            sys.exit("Cannot dump tensor of type {}".format(tensorType))
        bucket.append(operand.Feed(feedDict).GetListInitialization())
        typedMap["DIMENSIONS"].append("{%d, {%s}}" % (
            operand.index, tg.GetJointStr(operand.dimensions)))
    mixedTypeTemplate = """\
{{ // See tools/test_generator/include/TestHarness.h:MixedTyped
  // int -> Dimensions map
  .operandDimensions = {{{dimensions_map}}},
  // int -> FLOAT32 map
  .float32Operands = {{{float32_map}}},
  // int -> INT32 map
  .int32Operands = {{{int32_map}}},
  // int -> QUANT8_ASYMM map
  .quant8AsymmOperands = {{{uint8_map}}},
  // int -> QUANT16_SYMM map
  .quant16SymmOperands = {{{int16_map}}},
  // int -> FLOAT16 map
  .float16Operands = {{{float16_map}}},
  // int -> BOOL8 map
  .bool8Operands = {{{bool8_map}}},
  // int -> QUANT8_SYMM_PER_CHANNEL map
  .quant8ChannelOperands = {{{int8_map}}},
  // int -> QUANT16_ASYMM map
  .quant16AsymmOperands = {{{uint16_map}}},
  // int -> QUANT8_SYMM map
  .quant8SymmOperands = {{{quant8_symm_map}}},
}}"""
    # Every key was pre-populated above, so plain indexing is safe.
    return mixedTypeTemplate.format(
        dimensions_map=tg.GetJointStr(typedMap["DIMENSIONS"]),
        float32_map=tg.GetJointStr(typedMap["TENSOR_FLOAT32"]),
        int32_map=tg.GetJointStr(typedMap["TENSOR_INT32"]),
        # OEM_BYTE data is stored alongside QUANT8_ASYMM.
        uint8_map=tg.GetJointStr(typedMap["TENSOR_QUANT8_ASYMM"] +
                                 typedMap["TENSOR_OEM_BYTE"]),
        int16_map=tg.GetJointStr(typedMap["TENSOR_QUANT16_SYMM"]),
        float16_map=tg.GetJointStr(typedMap["TENSOR_FLOAT16"]),
        int8_map=tg.GetJointStr(typedMap["TENSOR_QUANT8_SYMM_PER_CHANNEL"]),
        bool8_map=tg.GetJointStr(typedMap["TENSOR_BOOL8"]),
        uint16_map=tg.GetJointStr(typedMap["TENSOR_QUANT16_ASYMM"]),
        quant8_symm_map=tg.GetJointStr(typedMap["TENSOR_QUANT8_SYMM"])
    )
243
244# Dump Example file for Cts tests
def DumpCtsExample(example, example_fd):
    """Emit the C++ accessor that returns all examples of `example`.

    Generates `get_<examplesName>()`, which returns a reference to a static
    `std::vector<MixedTypedExample>`. The vector holds one entry per
    (input feed dict, output feed dict) pair in `example.feedDicts`. Each
    entry includes `expectedMultinomialDistributionTolerance` when the
    example defines one.
    """
    name = example.examplesName

    def emit(text):
        print(text, file=example_fd)

    emit("std::vector<MixedTypedExample>& get_%s() {" % (name))
    emit("static std::vector<MixedTypedExample> %s = {" % (name))
    for inputFeeds, outputFeeds in example.feedDicts:
        emit('// Begin of an example')
        emit('{\n.operands = {')
        renderedInputs = DumpMixedType(example.model.GetInputs(), inputFeeds)
        renderedOutputs = DumpMixedType(example.model.GetOutputs(), outputFeeds)
        emit('//Input(s)\n%s,' % renderedInputs)
        emit('//Output(s)\n%s' % renderedOutputs)
        emit('},')
        tolerance = example.expectedMultinomialDistributionTolerance
        if tolerance is not None:
            emit('.expectedMultinomialDistributionTolerance = %f' % tolerance)
        emit('}, // End of an example')
    emit("};")
    emit("return %s;" % (name))
    emit("};\n")
263
264# Dump Test file for Cts tests
def DumpCtsTest(example, test_fd):
    """Emit the gtest case that runs `example` against its generated model.

    The TEST_F fixture is DynamicOutputShapeTest when the model has a dynamic
    output shape, and GeneratedTests otherwise. The test calls execute() with
    the model constructor, the is-ignored predicate and the example accessor
    from the spec's namespace. When the model has a version, a
    TEST_AVAILABLE_SINCE line is appended.
    """
    model = example.model
    if model.hasDynamicOutputShape:
        fixture = "DynamicOutputShapeTest"
    else:
        fixture = "GeneratedTests"

    template = """\
TEST_F({test_case_name}, {test_name}) {{
    execute({namespace}::{create_model_name},
            {namespace}::{is_ignored_name},
            {namespace}::get_{examples_name}(){log_file});\n}}\n"""
    if model.version is not None:
        template += """\
TEST_AVAILABLE_SINCE({version}, {test_name}, {namespace}::{create_model_name})\n"""

    fields = dict(
        test_case_name=fixture,
        test_name=str(example.testName),
        namespace=tg.FileNames.specName,
        create_model_name=str(model.createFunctionName),
        is_ignored_name=str(model.isIgnoredFunctionName),
        examples_name=str(example.examplesName),
        version=model.version,
        log_file=tg.FileNames.logFile,
    )
    print(template.format(**fields), file=test_fd)
284
if __name__ == '__main__':
    # This driver must stay at module level rather than inside a main()
    # function. exec() below runs each spec in this module's global namespace,
    # presumably so specs can use the test_generator names imported at the top
    # of the file (Model, Input, Output, ...).
    ParseCmdLine()
    while tg.FileNames.NextFile():
        if Configuration.force_regenerate or NeedRegenerate():
            print("Generating test(s) from spec: %s" % tg.FileNames.specFile, file=sys.stderr)
            # Spec files are trusted, in-tree Python; never point this at
            # untrusted input. Reading in a `with` block closes the spec file
            # before it runs. Compiling with the spec's path makes tracebacks
            # from spec errors name the spec file instead of "<string>".
            with open(tg.FileNames.specFile, "r") as spec_fd:
                specSource = spec_fd.read()
            exec(compile(specSource, tg.FileNames.specFile, "exec"))
            print("Output CTS model: %s" % tg.FileNames.modelFile, file=sys.stderr)
            print("Output example:%s" % tg.FileNames.exampleFile, file=sys.stderr)
            print("Output CTS test: %s" % tg.FileNames.testFile, file=sys.stderr)
            with SmartOpen(tg.FileNames.modelFile) as model_fd, \
                 SmartOpen(tg.FileNames.exampleFile) as example_fd, \
                 SmartOpen(tg.FileNames.testFile) as test_fd:
                InitializeFiles(model_fd, example_fd, test_fd)
                Example.DumpAllExamples(
                    DumpModel=DumpCtsModel, model_fd=model_fd,
                    DumpExample=DumpCtsExample, example_fd=example_fd,
                    DumpTest=DumpCtsTest, test_fd=test_fd)
        else:
            print("Skip file: %s" % tg.FileNames.specFile, file=sys.stderr)
        # The CTS aggregate file includes every spec's test, whether or not it
        # was regenerated in this run.
        with SmartOpen(tg.FileNames.ctsFile, mode="a") as cts_fd:
            print("#include \"../generated/tests/%s.cpp\""%os.path.basename(tg.FileNames.specFile),
                file=cts_fd)
307