1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <iomanip>
18 #include <iostream>
19 #include <cmath>
20 #include <sstream>
21 
22 #include "Generator.h"
23 #include "Specification.h"
24 #include "Utilities.h"
25 
26 using namespace std;
27 
28 // Converts float2 to FLOAT_32 and 2, etc.
convertToRsType(const string & name,string * dataType,char * vectorSize)29 static void convertToRsType(const string& name, string* dataType, char* vectorSize) {
30     string s = name;
31     int last = s.size() - 1;
32     char lastChar = s[last];
33     if (lastChar >= '1' && lastChar <= '4') {
34         s.erase(last);
35         *vectorSize = lastChar;
36     } else {
37         *vectorSize = '1';
38     }
39     dataType->clear();
40     for (int i = 0; i < NUM_TYPES; i++) {
41         if (s == TYPES[i].cType) {
42             *dataType = TYPES[i].rsDataType;
43             break;
44         }
45     }
46 }
47 
48 // Returns true if any permutation of the function have tests to b
needTestFiles(const Function & function,unsigned int versionOfTestFiles)49 static bool needTestFiles(const Function& function, unsigned int versionOfTestFiles) {
50     for (auto spec : function.getSpecifications()) {
51         if (spec->hasTests(versionOfTestFiles)) {
52             return true;
53         }
54     }
55     return false;
56 }
57 
58 /* One instance of this class is generated for each permutation of a function for which
59  * we are generating test code.  This instance will generate both the script and the Java
60  * section of the test files for this permutation.  The class is mostly used to keep track
61  * of the various names shared between script and Java files.
62  * WARNING: Because the constructor keeps a reference to the FunctionPermutation, PermutationWriter
63  * should not exceed the lifetime of FunctionPermutation.
64  */
65 class PermutationWriter {
66 private:
67     FunctionPermutation& mPermutation;
68 
69     string mRsKernelName;
70     string mJavaArgumentsClassName;
71     string mJavaArgumentsNClassName;
72     string mJavaVerifierComputeMethodName;
73     string mJavaVerifierVerifyMethodName;
74     string mJavaCheckMethodName;
75     string mJavaVerifyMethodName;
76 
77     // Pointer to the files we are generating.  Handy to avoid always passing them in the calls.
78     GeneratedFile* mRs;
79     GeneratedFile* mJava;
80 
81     /* Shortcuts to the return parameter and the first input parameter of the function
82      * specification.
83      */
84     const ParameterDefinition* mReturnParam;      // Can be nullptr.  NOT OWNED.
85     const ParameterDefinition* mFirstInputParam;  // Can be nullptr.  NOT OWNED.
86 
87     /* All the parameters plus the return param, if present.  Collecting them together
88      * simplifies code generation.  NOT OWNED.
89      */
90     vector<const ParameterDefinition*> mAllInputsAndOutputs;
91 
92     /* We use a class to pass the arguments between the generated code and the CoreVerifier.  This
93      * method generates this class.  The set keeps track if we've generated this class already
94      * for this test file, as more than one permutation may use the same argument class.
95      */
96     void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const;
97 
98     // Generate the Check* method that invokes the script and calls the verifier.
99     void writeJavaCheckMethod(bool generateCallToVerifier) const;
100 
101     // Generate code to define and randomly initialize the input allocation.
102     void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const;
103 
104     /* Generate code that instantiate an allocation of floats or integers and fills it with
105      * random data. This random data must be compatible with the specified type.  This is
106      * used for the convert_* tests, as converting values that don't fit yield undefined results.
107      */
108     void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed,
109                                                   char vectorSize,
110                                                   const NumericalType& compatibleType,
111                                                   const NumericalType& generatedType) const;
112     void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed,
113                                                     char vectorSize,
114                                                     const NumericalType& compatibleType,
115                                                     const NumericalType& generatedType) const;
116 
117     // Generate code that defines an output allocation.
118     void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const;
119 
120     /* Generate the code that verifies the results for RenderScript functions where each entry
121      * of a vector is evaluated independently.  If verifierValidates is true, CoreMathVerifier
122      * does the actual validation instead of more commonly returning the range of acceptable values.
123      */
124     void writeJavaVerifyScalarMethod(bool verifierValidates) const;
125 
126     /* Generate the code that verify the results for a RenderScript function where a vector
127      * is a point in n-dimensional space.
128      */
129     void writeJavaVerifyVectorMethod() const;
130 
131     // Generate the line that creates the Target.
132     void writeJavaCreateTarget() const;
133 
134     // Generate the method header of the verify function.
135     void writeJavaVerifyMethodHeader() const;
136 
137     // Generate codes that copies the content of an allocation to an array.
138     void writeJavaArrayInitialization(const ParameterDefinition& p) const;
139 
140     // Generate code that tests one value returned from the script.
141     void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex,
142                                   const string& actualIndex) const;
143     void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
144                                const string& actualIndex) const;
145     // For test:vector cases, generate code that compares returned vector vs. expected value.
146     void writeJavaVectorComparison(const ParameterDefinition& p) const;
147 
148     // Muliple functions that generates code to build the error message if an error is found.
149     void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex,
150                                         const string& actualIndex, bool verifierValidates) const;
151     void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const;
152     void writeJavaAppendNewLineToMessage() const;
153     void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const;
154     void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const;
155 
156     // Generate the set of instructions to call the script.
157     void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const;
158 
159     // Write an allocation definition if not already emitted in the .rs file.
160     void writeRsAllocationDefinition(const ParameterDefinition& param,
161                                      set<string>* rsAllocationsGenerated) const;
162 
163 public:
164     /* NOTE: We keep pointers to the permutation and the files.  This object should not
165      * outlive the arguments.
166      */
167     PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
168                       GeneratedFile* javaFile);
getJavaCheckMethodName() const169     string getJavaCheckMethodName() const { return mJavaCheckMethodName; }
170 
171     // Write the script test function for this permutation.
172     void writeRsSection(set<string>* rsAllocationsGenerated) const;
173     // Write the section of the Java code that calls the script and validates the results
174     void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const;
175 };
176 
PermutationWriter(FunctionPermutation & permutation,GeneratedFile * rsFile,GeneratedFile * javaFile)177 PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
178                                      GeneratedFile* javaFile)
179     : mPermutation(permutation),
180       mRs(rsFile),
181       mJava(javaFile),
182       mReturnParam(nullptr),
183       mFirstInputParam(nullptr) {
184     mRsKernelName = "test" + capitalize(permutation.getName());
185 
186     mJavaArgumentsClassName = "Arguments";
187     mJavaArgumentsNClassName = "Arguments";
188     const string trunk = capitalize(permutation.getNameTrunk());
189     mJavaCheckMethodName = "check" + trunk;
190     mJavaVerifyMethodName = "verifyResults" + trunk;
191 
192     for (auto p : permutation.getParams()) {
193         mAllInputsAndOutputs.push_back(p);
194         if (mFirstInputParam == nullptr && !p->isOutParameter) {
195             mFirstInputParam = p;
196         }
197     }
198     mReturnParam = permutation.getReturn();
199     if (mReturnParam) {
200         mAllInputsAndOutputs.push_back(mReturnParam);
201     }
202 
203     for (auto p : mAllInputsAndOutputs) {
204         const string capitalizedRsType = capitalize(p->rsType);
205         const string capitalizedBaseType = capitalize(p->rsBaseType);
206         mRsKernelName += capitalizedRsType;
207         mJavaArgumentsClassName += capitalizedBaseType;
208         mJavaArgumentsNClassName += capitalizedBaseType;
209         if (p->mVectorSize != "1") {
210             mJavaArgumentsNClassName += "N";
211         }
212         mJavaCheckMethodName += capitalizedRsType;
213         mJavaVerifyMethodName += capitalizedRsType;
214     }
215     mJavaVerifierComputeMethodName = "compute" + trunk;
216     mJavaVerifierVerifyMethodName = "verify" + trunk;
217 }
218 
writeJavaSection(set<string> * javaGeneratedArgumentClasses) const219 void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const {
220     // By default, we test the results using item by item comparison.
221     const string test = mPermutation.getTest();
222     if (test == "scalar" || test == "limited") {
223         writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
224         writeJavaCheckMethod(true);
225         writeJavaVerifyScalarMethod(false);
226     } else if (test == "custom") {
227         writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
228         writeJavaCheckMethod(true);
229         writeJavaVerifyScalarMethod(true);
230     } else if (test == "vector") {
231         writeJavaArgumentClass(false, javaGeneratedArgumentClasses);
232         writeJavaCheckMethod(true);
233         writeJavaVerifyVectorMethod();
234     } else if (test == "noverify") {
235         writeJavaCheckMethod(false);
236     }
237 }
238 
writeJavaArgumentClass(bool scalar,set<string> * javaGeneratedArgumentClasses) const239 void PermutationWriter::writeJavaArgumentClass(bool scalar,
240                                                set<string>* javaGeneratedArgumentClasses) const {
241     string name;
242     if (scalar) {
243         name = mJavaArgumentsClassName;
244     } else {
245         name = mJavaArgumentsNClassName;
246     }
247 
248     // Make sure we have not generated the argument class already.
249     if (!testAndSet(name, javaGeneratedArgumentClasses)) {
250         mJava->indent() << "public class " << name;
251         mJava->startBlock();
252 
253         for (auto p : mAllInputsAndOutputs) {
254             bool isFieldArray = !scalar && p->mVectorSize != "1";
255             bool isFloatyField = p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom";
256 
257             mJava->indent() << "public ";
258             if (isFloatyField) {
259                 *mJava << "Target.Floaty";
260             } else {
261                 *mJava << p->javaBaseType;
262             }
263             if (isFieldArray) {
264                 *mJava << "[]";
265             }
266             *mJava << " " << p->variableName << ";\n";
267 
268             // For Float16 parameters, add an extra 'double' field in the class
269             // to hold the Double value converted from the input.
270             if (p->isFloat16Parameter() && !isFloatyField) {
271                 mJava->indent() << "public double";
272                 if (isFieldArray) {
273                     *mJava << "[]";
274                 }
275                 *mJava << " " + p->variableName << "Double;\n";
276             }
277         }
278         mJava->endBlock();
279         *mJava << "\n";
280     }
281 }
282 
writeJavaCheckMethod(bool generateCallToVerifier) const283 void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const {
284     mJava->indent() << "private void " << mJavaCheckMethodName << "()";
285     mJava->startBlock();
286 
287     // Generate the input allocations and initialization.
288     for (auto p : mAllInputsAndOutputs) {
289         if (!p->isOutParameter) {
290             writeJavaInputAllocationDefinition(*p);
291         }
292     }
293     // Generate code to enforce ordering between two allocations if needed.
294     for (auto p : mAllInputsAndOutputs) {
295         if (!p->isOutParameter && !p->smallerParameter.empty()) {
296             string smallerAlloc = "in" + capitalize(p->smallerParameter);
297             mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName
298                             << ");\n";
299         }
300     }
301 
302     // Generate code to check the full and relaxed scripts.
303     writeJavaCallToRs(false, generateCallToVerifier);
304     writeJavaCallToRs(true, generateCallToVerifier);
305 
306     mJava->endBlock();
307     *mJava << "\n";
308 }
309 
writeJavaInputAllocationDefinition(const ParameterDefinition & param) const310 void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const {
311     string dataType;
312     char vectorSize;
313     convertToRsType(param.rsType, &dataType, &vectorSize);
314 
315     const string seed = hashString(mJavaCheckMethodName + param.javaAllocName);
316     mJava->indent() << "Allocation " << param.javaAllocName << " = ";
317     if (param.compatibleTypeIndex >= 0) {
318         if (TYPES[param.typeIndex].kind == FLOATING_POINT) {
319             writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize,
320                                                      TYPES[param.compatibleTypeIndex],
321                                                      TYPES[param.typeIndex]);
322         } else {
323             writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize,
324                                                        TYPES[param.compatibleTypeIndex],
325                                                        TYPES[param.typeIndex]);
326         }
327     } else if (!param.minValue.empty()) {
328         *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", "
329                << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue
330                << ")";
331     } else {
332         /* TODO Instead of passing always false, check whether we are doing a limited test.
333          * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true")
334          */
335         *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize
336                << ", " << seed << ", false)";
337     }
338     *mJava << ";\n";
339 }
340 
writeJavaRandomCompatibleFloatAllocation(const string & dataType,const string & seed,char vectorSize,const NumericalType & compatibleType,const NumericalType & generatedType) const341 void PermutationWriter::writeJavaRandomCompatibleFloatAllocation(
342             const string& dataType, const string& seed, char vectorSize,
343             const NumericalType& compatibleType, const NumericalType& generatedType) const {
344     *mJava << "createRandomFloatAllocation"
345            << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
346     double minValue = 0.0;
347     double maxValue = 0.0;
348     switch (compatibleType.kind) {
349         case FLOATING_POINT: {
350             // We're generating floating point values.  We just worry about the exponent.
351             // Subtract 1 for the exponent sign.
352             int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1;
353             maxValue = ldexp(0.95, (1 << bits) - 1);
354             minValue = -maxValue;
355             break;
356         }
357         case UNSIGNED_INTEGER:
358             maxValue = maxDoubleForInteger(compatibleType.significantBits,
359                                            generatedType.significantBits);
360             minValue = 0.0;
361             break;
362         case SIGNED_INTEGER:
363             maxValue = maxDoubleForInteger(compatibleType.significantBits,
364                                            generatedType.significantBits);
365             minValue = -maxValue - 1.0;
366             break;
367     }
368     *mJava << scientific << std::setprecision(19);
369     *mJava << minValue << ", " << maxValue << ")";
370     mJava->unsetf(ios_base::floatfield);
371 }
372 
writeJavaRandomCompatibleIntegerAllocation(const string & dataType,const string & seed,char vectorSize,const NumericalType & compatibleType,const NumericalType & generatedType) const373 void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation(
374             const string& dataType, const string& seed, char vectorSize,
375             const NumericalType& compatibleType, const NumericalType& generatedType) const {
376     *mJava << "createRandomIntegerAllocation"
377            << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
378 
379     if (compatibleType.kind == FLOATING_POINT) {
380         // Currently, all floating points can take any number we generate.
381         bool isSigned = generatedType.kind == SIGNED_INTEGER;
382         *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits;
383     } else {
384         bool isSigned =
385                     compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER;
386         *mJava << (isSigned ? "true" : "false") << ", "
387                << min(compatibleType.significantBits, generatedType.significantBits);
388     }
389     *mJava << ")";
390 }
391 
writeJavaOutputAllocationDefinition(const ParameterDefinition & param) const392 void PermutationWriter::writeJavaOutputAllocationDefinition(
393             const ParameterDefinition& param) const {
394     string dataType;
395     char vectorSize;
396     convertToRsType(param.rsType, &dataType, &vectorSize);
397     mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, "
398                     << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize
399                     << "), INPUTSIZE);\n";
400 }
401 
writeJavaVerifyScalarMethod(bool verifierValidates) const402 void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const {
403     writeJavaVerifyMethodHeader();
404     mJava->startBlock();
405 
406     string vectorSize = "1";
407     for (auto p : mAllInputsAndOutputs) {
408         writeJavaArrayInitialization(*p);
409         if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) {
410             if (vectorSize == "1") {
411                 vectorSize = p->mVectorSize;
412             } else {
413                 cerr << "Error.  Had vector " << vectorSize << " and " << p->mVectorSize << "\n";
414             }
415         }
416     }
417 
418     mJava->indent() << "StringBuilder message = new StringBuilder();\n";
419     mJava->indent() << "boolean errorFound = false;\n";
420     mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
421     mJava->startBlock();
422 
423     mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)";
424     mJava->startBlock();
425 
426     mJava->indent() << "// Extract the inputs.\n";
427     mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName
428                     << "();\n";
429     for (auto p : mAllInputsAndOutputs) {
430         if (!p->isOutParameter) {
431             mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i";
432             if (p->vectorWidth != "1") {
433                 *mJava << " * " << p->vectorWidth << " + j";
434             }
435             *mJava << "];\n";
436 
437             // Convert the Float16 parameter to double and store it in the appropriate field in the
438             // Arguments class.
439             if (p->isFloat16Parameter()) {
440                 mJava->indent() << "args." << p->doubleVariableName
441                                 << " = Float16Utils.convertFloat16ToDouble(args."
442                                 << p->variableName << ");\n";
443             }
444         }
445     }
446     const bool hasFloat = mPermutation.hasFloatAnswers();
447     if (verifierValidates) {
448         mJava->indent() << "// Extract the outputs.\n";
449         for (auto p : mAllInputsAndOutputs) {
450             if (p->isOutParameter) {
451                 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName
452                                 << "[i * " << p->vectorWidth << " + j];\n";
453                 if (p->isFloat16Parameter()) {
454                     mJava->indent() << "args." << p->doubleVariableName
455                                     << " = Float16Utils.convertFloat16ToDouble(args."
456                                     << p->variableName << ");\n";
457                 }
458             }
459         }
460         mJava->indent() << "// Ask the CoreMathVerifier to validate.\n";
461         if (hasFloat) {
462             writeJavaCreateTarget();
463         }
464         mJava->indent() << "String errorMessage = CoreMathVerifier."
465                         << mJavaVerifierVerifyMethodName << "(args";
466         if (hasFloat) {
467             *mJava << ", target";
468         }
469         *mJava << ");\n";
470         mJava->indent() << "boolean valid = errorMessage == null;\n";
471     } else {
472         mJava->indent() << "// Figure out what the outputs should have been.\n";
473         if (hasFloat) {
474             writeJavaCreateTarget();
475         }
476         mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args";
477         if (hasFloat) {
478             *mJava << ", target";
479         }
480         *mJava << ");\n";
481         mJava->indent() << "// Validate the outputs.\n";
482         mJava->indent() << "boolean valid = true;\n";
483         for (auto p : mAllInputsAndOutputs) {
484             if (p->isOutParameter) {
485                 writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]");
486             }
487         }
488     }
489 
490     mJava->indent() << "if (!valid)";
491     mJava->startBlock();
492     mJava->indent() << "if (!errorFound)";
493     mJava->startBlock();
494     mJava->indent() << "errorFound = true;\n";
495 
496     for (auto p : mAllInputsAndOutputs) {
497         if (p->isOutParameter) {
498             writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]",
499                                            verifierValidates);
500         } else {
501             writeJavaAppendInputToMessage(*p, "args." + p->variableName);
502         }
503     }
504     if (verifierValidates) {
505         mJava->indent() << "message.append(errorMessage);\n";
506     }
507     mJava->indent() << "message.append(\"Errors at\");\n";
508     mJava->endBlock();
509 
510     mJava->indent() << "message.append(\" [\");\n";
511     mJava->indent() << "message.append(Integer.toString(i));\n";
512     mJava->indent() << "message.append(\", \");\n";
513     mJava->indent() << "message.append(Integer.toString(j));\n";
514     mJava->indent() << "message.append(\"]\");\n";
515 
516     mJava->endBlock();
517     mJava->endBlock();
518     mJava->endBlock();
519 
520     mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
521     mJava->indentPlus()
522                 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";
523 
524     mJava->endBlock();
525     *mJava << "\n";
526 }
527 
writeJavaVerifyVectorMethod() const528 void PermutationWriter::writeJavaVerifyVectorMethod() const {
529     writeJavaVerifyMethodHeader();
530     mJava->startBlock();
531 
532     for (auto p : mAllInputsAndOutputs) {
533         writeJavaArrayInitialization(*p);
534     }
535     mJava->indent() << "StringBuilder message = new StringBuilder();\n";
536     mJava->indent() << "boolean errorFound = false;\n";
537     mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
538     mJava->startBlock();
539 
540     mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName
541                     << "();\n";
542 
543     mJava->indent() << "// Create the appropriate sized arrays in args\n";
544     for (auto p : mAllInputsAndOutputs) {
545         if (p->mVectorSize != "1") {
546             string type = p->javaBaseType;
547             if (p->isOutParameter && p->isFloatType) {
548                 type = "Target.Floaty";
549             }
550             mJava->indent() << "args." << p->variableName << " = new " << type << "["
551                             << p->mVectorSize << "];\n";
552             if (p->isFloat16Parameter() && !p->isOutParameter) {
553                 mJava->indent() << "args." << p->variableName << "Double = new double["
554                                 << p->mVectorSize << "];\n";
555             }
556         }
557     }
558 
559     mJava->indent() << "// Fill args with the input values\n";
560     for (auto p : mAllInputsAndOutputs) {
561         if (!p->isOutParameter) {
562             if (p->mVectorSize == "1") {
563                 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]"
564                                 << ";\n";
565                 // Convert the Float16 parameter to double and store it in the appropriate field in
566                 // the Arguments class.
567                 if (p->isFloat16Parameter()) {
568                     mJava->indent() << "args." << p->doubleVariableName << " = "
569                                     << "Float16Utils.convertFloat16ToDouble(args."
570                                     << p->variableName << ");\n";
571                 }
572             } else {
573                 mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)";
574                 mJava->startBlock();
575                 mJava->indent() << "args." << p->variableName << "[j] = "
576                                 << p->javaArrayName << "[i * " << p->vectorWidth << " + j]"
577                                 << ";\n";
578 
579                 // Convert the Float16 parameter to double and store it in the appropriate field in
580                 // the Arguments class.
581                 if (p->isFloat16Parameter()) {
582                     mJava->indent() << "args." << p->doubleVariableName << "[j] = "
583                                     << "Float16Utils.convertFloat16ToDouble(args."
584                                     << p->variableName << "[j]);\n";
585                 }
586                 mJava->endBlock();
587             }
588         }
589     }
590     writeJavaCreateTarget();
591     mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName
592                     << "(args, target);\n\n";
593 
594     mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n";
595     mJava->indent() << "boolean valid = true;\n";
596     for (auto p : mAllInputsAndOutputs) {
597         if (p->isOutParameter) {
598             writeJavaVectorComparison(*p);
599         }
600     }
601 
602     mJava->indent() << "if (!valid)";
603     mJava->startBlock();
604     mJava->indent() << "if (!errorFound)";
605     mJava->startBlock();
606     mJava->indent() << "errorFound = true;\n";
607 
608     for (auto p : mAllInputsAndOutputs) {
609         if (p->isOutParameter) {
610             writeJavaAppendVectorOutputToMessage(*p);
611         } else {
612             writeJavaAppendVectorInputToMessage(*p);
613         }
614     }
615     mJava->indent() << "message.append(\"Errors at\");\n";
616     mJava->endBlock();
617 
618     mJava->indent() << "message.append(\" [\");\n";
619     mJava->indent() << "message.append(Integer.toString(i));\n";
620     mJava->indent() << "message.append(\"]\");\n";
621 
622     mJava->endBlock();
623     mJava->endBlock();
624 
625     mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
626     mJava->indentPlus()
627                 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";
628 
629     mJava->endBlock();
630     *mJava << "\n";
631 }
632 
633 
writeJavaCreateTarget() const634 void PermutationWriter::writeJavaCreateTarget() const {
635     string name = mPermutation.getName();
636 
637     const char* functionType = "NORMAL";
638     size_t end = name.find('_');
639     if (end != string::npos) {
640         if (name.compare(0, end, "native") == 0) {
641             functionType = "NATIVE";
642         } else if (name.compare(0, end, "half") == 0) {
643             functionType = "HALF";
644         } else if (name.compare(0, end, "fast") == 0) {
645             functionType = "FAST";
646         }
647     }
648 
649     string floatType = mReturnParam->specType;
650     const char* precisionStr = "";
651     if (floatType.compare("f16") == 0) {
652         precisionStr = "HALF";
653     } else if (floatType.compare("f32") == 0) {
654         precisionStr = "FLOAT";
655     } else if (floatType.compare("f64") == 0) {
656         precisionStr = "DOUBLE";
657     } else {
658         cerr << "Error. Unreachable.  Return type is not floating point\n";
659     }
660 
661     mJava->indent() << "Target target = new Target(Target.FunctionType." <<
662                     functionType << ", Target.ReturnType." << precisionStr <<
663                     ", relaxed);\n";
664 }
665 
writeJavaVerifyMethodHeader() const666 void PermutationWriter::writeJavaVerifyMethodHeader() const {
667     mJava->indent() << "private void " << mJavaVerifyMethodName << "(";
668     for (auto p : mAllInputsAndOutputs) {
669         *mJava << "Allocation " << p->javaAllocName << ", ";
670     }
671     *mJava << "boolean relaxed)";
672 }
673 
writeJavaArrayInitialization(const ParameterDefinition & p) const674 void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const {
675     mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType
676                     << "[INPUTSIZE * " << p.vectorWidth << "];\n";
677 
678     /* For basic types, populate the array with values, to help understand failures.  We have had
679      * bugs where the output buffer was all 0.  We were not sure if there was a failed copy or
680      * the GPU driver was copying zeroes.
681      */
682     if (p.typeIndex >= 0) {
683         mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType
684                         << ") 42);\n";
685     }
686 
687     mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n";
688 }
689 
writeJavaTestAndSetValid(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex) const690 void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p,
691                                                  const string& argsIndex,
692                                                  const string& actualIndex) const {
693     writeJavaTestOneValue(p, argsIndex, actualIndex);
694     mJava->startBlock();
695     mJava->indent() << "valid = false;\n";
696     mJava->endBlock();
697 }
698 
writeJavaTestOneValue(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex) const699 void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
700                                               const string& actualIndex) const {
701     string actualOut;
702     if (p.isFloat16Parameter()) {
703         // For Float16 values, the output needs to be converted to Double.
704         actualOut = "Float16Utils.convertFloat16ToDouble(" + p.javaArrayName + actualIndex + ")";
705     } else {
706         actualOut = p.javaArrayName + actualIndex;
707     }
708 
709     mJava->indent() << "if (";
710     if (p.isFloatType) {
711         *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << actualOut;
712         const string s = mPermutation.getPrecisionLimit();
713         if (!s.empty()) {
714             *mJava << ", " << s;
715         }
716         *mJava << ")";
717     } else {
718         *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName
719                << actualIndex;
720     }
721 
722     if (p.undefinedIfOutIsNan && mReturnParam) {
723         *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()";
724     }
725     *mJava << ")";
726 }
727 
writeJavaVectorComparison(const ParameterDefinition & p) const728 void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const {
729     if (p.mVectorSize == "1") {
730         writeJavaTestAndSetValid(p, "", "[i]");
731     } else {
732         mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
733         mJava->startBlock();
734         writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]");
735         mJava->endBlock();
736     }
737 }
738 
writeJavaAppendOutputToMessage(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex,bool verifierValidates) const739 void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p,
740                                                        const string& argsIndex,
741                                                        const string& actualIndex,
742                                                        bool verifierValidates) const {
743     if (verifierValidates) {
744         mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n";
745         mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
746                         << ");\n";
747         writeJavaAppendNewLineToMessage();
748         if (p.isFloat16Parameter()) {
749             writeJavaAppendNewLineToMessage();
750             mJava->indent() << "message.append(\"Output " << p.variableName
751                             << " (in double): \");\n";
752             mJava->indent() << "appendVariableToMessage(message, args." << p.doubleVariableName
753                             << ");\n";
754             writeJavaAppendNewLineToMessage();
755         }
756     } else {
757         mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n";
758         mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
759                         << ");\n";
760         writeJavaAppendNewLineToMessage();
761 
762         mJava->indent() << "message.append(\"Actual   output " << p.variableName << ": \");\n";
763         mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex
764                         << ");\n";
765 
766         if (p.isFloat16Parameter()) {
767             writeJavaAppendNewLineToMessage();
768             mJava->indent() << "message.append(\"Actual   output " << p.variableName
769                             << " (in double): \");\n";
770             mJava->indent() << "appendVariableToMessage(message, Float16Utils.convertFloat16ToDouble("
771                             << p.javaArrayName << actualIndex << "));\n";
772         }
773 
774         writeJavaTestOneValue(p, argsIndex, actualIndex);
775         mJava->startBlock();
776         mJava->indent() << "message.append(\" FAIL\");\n";
777         mJava->endBlock();
778         writeJavaAppendNewLineToMessage();
779     }
780 }
781 
writeJavaAppendInputToMessage(const ParameterDefinition & p,const string & actual) const782 void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p,
783                                                       const string& actual) const {
784     mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n";
785     mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n";
786     writeJavaAppendNewLineToMessage();
787 }
788 
writeJavaAppendNewLineToMessage() const789 void PermutationWriter::writeJavaAppendNewLineToMessage() const {
790     mJava->indent() << "message.append(\"\\n\");\n";
791 }
792 
writeJavaAppendVectorInputToMessage(const ParameterDefinition & p) const793 void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const {
794     if (p.mVectorSize == "1") {
795         writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]");
796     } else {
797         mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
798         mJava->startBlock();
799         writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]");
800         mJava->endBlock();
801     }
802 }
803 
writeJavaAppendVectorOutputToMessage(const ParameterDefinition & p) const804 void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const {
805     if (p.mVectorSize == "1") {
806         writeJavaAppendOutputToMessage(p, "", "[i]", false);
807     } else {
808         mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
809         mJava->startBlock();
810         writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false);
811         mJava->endBlock();
812     }
813 }
814 
writeJavaCallToRs(bool relaxed,bool generateCallToVerifier) const815 void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const {
816     string script = "script";
817     if (relaxed) {
818         script += "Relaxed";
819     }
820 
821     mJava->indent() << "try";
822     mJava->startBlock();
823 
824     for (auto p : mAllInputsAndOutputs) {
825         if (p->isOutParameter) {
826             writeJavaOutputAllocationDefinition(*p);
827         }
828     }
829 
830     for (auto p : mPermutation.getParams()) {
831         if (p != mFirstInputParam) {
832             mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName
833                             << ");\n";
834         }
835     }
836 
837     mJava->indent() << script << ".forEach_" << mRsKernelName << "(";
838     bool needComma = false;
839     if (mFirstInputParam) {
840         *mJava << mFirstInputParam->javaAllocName;
841         needComma = true;
842     }
843     if (mReturnParam) {
844         if (needComma) {
845             *mJava << ", ";
846         }
847         *mJava << mReturnParam->variableName << ");\n";
848     }
849 
850     if (generateCallToVerifier) {
851         mJava->indent() << mJavaVerifyMethodName << "(";
852         for (auto p : mAllInputsAndOutputs) {
853             *mJava << p->variableName << ", ";
854         }
855 
856         if (relaxed) {
857             *mJava << "true";
858         } else {
859             *mJava << "false";
860         }
861         *mJava << ");\n";
862     }
863     mJava->decreaseIndent();
864     mJava->indent() << "} catch (Exception e) {\n";
865     mJava->increaseIndent();
866     mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_"
867                     << mRsKernelName << ": \" + e.toString());\n";
868     mJava->endBlock();
869 }
870 
/* Write the section of the .rs file for this permutation.
 *
 * We communicate the extra input and output parameters via global allocations.
 * For example, if we have a function that takes three arguments, two for input
 * and one for output:
 *
 * start:
 * name: gamn
 * ret: float3
 * arg: float3 a
 * arg: int b
 * arg: float3 *c
 * end:
 *
 * We'll produce:
 *
 * rs_allocation gAllocInB;
 * rs_allocation gAllocOutC;
 *
 * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) {
 *    int inB = rsGetElementAt_int(gAllocInB, x);
 *    float3 outC = 0;
 *    float3 out = gamn(inA, inB, &outC);
 *    rsSetElementAt_float3(gAllocOutC, outC, x);
 *    return out;
 * }
 *
 * We avoid re-using x and y from the definition because these have reserved
 * meanings in a .rs file.
 */
writeRsSection(set<string> * rsAllocationsGenerated) const903 void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const {
904     // Write the allocation declarations we'll need.
905     for (auto p : mPermutation.getParams()) {
906         // Don't need allocation for one input and one return value.
907         if (p != mFirstInputParam) {
908             writeRsAllocationDefinition(*p, rsAllocationsGenerated);
909         }
910     }
911     *mRs << "\n";
912 
913     // Write the function header.
914     if (mReturnParam) {
915         *mRs << mReturnParam->rsType;
916     } else {
917         *mRs << "void";
918     }
919     *mRs << " __attribute__((kernel)) " << mRsKernelName;
920     *mRs << "(";
921     bool needComma = false;
922     if (mFirstInputParam) {
923         *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName;
924         needComma = true;
925     }
926     if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) {
927         if (needComma) {
928             *mRs << ", ";
929         }
930         *mRs << "unsigned int x";
931     }
932     *mRs << ")";
933     mRs->startBlock();
934 
935     // Write the local variable declarations and initializations.
936     for (auto p : mPermutation.getParams()) {
937         if (p == mFirstInputParam) {
938             continue;
939         }
940         mRs->indent() << p->rsType << " " << p->variableName;
941         if (p->isOutParameter) {
942             *mRs << " = 0;\n";
943         } else {
944             *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n";
945         }
946     }
947 
948     // Write the function call.
949     if (mReturnParam) {
950         if (mPermutation.getOutputCount() > 1) {
951             mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = ";
952         } else {
953             mRs->indent() << "return ";
954         }
955     }
956     *mRs << mPermutation.getName() << "(";
957     needComma = false;
958     for (auto p : mPermutation.getParams()) {
959         if (needComma) {
960             *mRs << ", ";
961         }
962         if (p->isOutParameter) {
963             *mRs << "&";
964         }
965         *mRs << p->variableName;
966         needComma = true;
967     }
968     *mRs << ");\n";
969 
970     if (mPermutation.getOutputCount() > 1) {
971         // Write setting the extra out parameters into the allocations.
972         for (auto p : mPermutation.getParams()) {
973             if (p->isOutParameter) {
974                 mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", ";
975                 // Check if we need to use '&' for this type of argument.
976                 char lastChar = p->variableName.back();
977                 if (lastChar >= '0' && lastChar <= '9') {
978                     *mRs << "&";
979                 }
980                 *mRs << p->variableName << ", x);\n";
981             }
982         }
983         if (mReturnParam) {
984             mRs->indent() << "return " << mReturnParam->variableName << ";\n";
985         }
986     }
987     mRs->endBlock();
988 }
989 
writeRsAllocationDefinition(const ParameterDefinition & param,set<string> * rsAllocationsGenerated) const990 void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param,
991                                                     set<string>* rsAllocationsGenerated) const {
992     if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) {
993         *mRs << "rs_allocation " << param.rsAllocName << ";\n";
994     }
995 }
996 
// Opens the Java test file and writes its header: package, imports, class declaration, script fields and setUp().
startJavaFile(GeneratedFile * file,const Function & function,const string & directory,const string & testName,const string & relaxedTestName)998 static bool startJavaFile(GeneratedFile* file, const Function& function, const string& directory,
999                           const string& testName, const string& relaxedTestName) {
1000     const string fileName = testName + ".java";
1001     if (!file->start(directory, fileName)) {
1002         return false;
1003     }
1004     file->writeNotices();
1005 
1006     *file << "package android.renderscript.cts;\n\n";
1007 
1008     *file << "import android.renderscript.Allocation;\n";
1009     *file << "import android.renderscript.RSRuntimeException;\n";
1010     *file << "import android.renderscript.Element;\n";
1011     *file << "import android.renderscript.cts.Target;\n\n";
1012     *file << "import java.util.Arrays;\n\n";
1013 
1014     *file << "public class " << testName << " extends RSBaseCompute";
1015     file->startBlock();  // The corresponding endBlock() is in finishJavaFile()
1016     *file << "\n";
1017 
1018     file->indent() << "private ScriptC_" << testName << " script;\n";
1019     file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n";
1020 
1021     file->indent() << "@Override\n";
1022     file->indent() << "protected void setUp() throws Exception";
1023     file->startBlock();
1024 
1025     file->indent() << "super.setUp();\n";
1026     file->indent() << "script = new ScriptC_" << testName << "(mRS);\n";
1027     file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n";
1028 
1029     file->endBlock();
1030     *file << "\n";
1031     return true;
1032 }
1033 
// Writes the test method that calls all the generated check methods, then closes the class block opened in startJavaFile().
finishJavaFile(GeneratedFile * file,const Function & function,const vector<string> & javaCheckMethods)1035 static void finishJavaFile(GeneratedFile* file, const Function& function,
1036                            const vector<string>& javaCheckMethods) {
1037     file->indent() << "public void test" << function.getCapitalizedName() << "()";
1038     file->startBlock();
1039     for (auto m : javaCheckMethods) {
1040         file->indent() << m << "();\n";
1041     }
1042     file->endBlock();
1043 
1044     file->endBlock();
1045 }
1046 
1047 // Open the script file and write its header.
startRsFile(GeneratedFile * file,const Function & function,const string & directory,const string & testName)1048 static bool startRsFile(GeneratedFile* file, const Function& function, const string& directory,
1049                         const string& testName) {
1050     string fileName = testName + ".rs";
1051     if (!file->start(directory, fileName)) {
1052         return false;
1053     }
1054     file->writeNotices();
1055 
1056     *file << "#pragma version(1)\n";
1057     *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n";
1058     return true;
1059 }
1060 
1061 // Write the entire *Relaxed.rs test file, as it only depends on the name.
writeRelaxedRsFile(const Function & function,const string & directory,const string & testName,const string & relaxedTestName)1062 static bool writeRelaxedRsFile(const Function& function, const string& directory,
1063                                const string& testName, const string& relaxedTestName) {
1064     string name = relaxedTestName + ".rs";
1065 
1066     GeneratedFile file;
1067     if (!file.start(directory, name)) {
1068         return false;
1069     }
1070     file.writeNotices();
1071 
1072     file << "#include \"" << testName << ".rs\"\n";
1073     file << "#pragma rs_fp_relaxed\n";
1074     file.close();
1075     return true;
1076 }
1077 
1078 /* Write the .java and the two .rs test files.  versionOfTestFiles is used to restrict which API
1079  * to test.
1080  */
writeTestFilesForFunction(const Function & function,const string & directory,unsigned int versionOfTestFiles)1081 static bool writeTestFilesForFunction(const Function& function, const string& directory,
1082                                       unsigned int versionOfTestFiles) {
1083     // Avoid creating empty files if we're not testing this function.
1084     if (!needTestFiles(function, versionOfTestFiles)) {
1085         return true;
1086     }
1087 
1088     const string testName = "Test" + function.getCapitalizedName();
1089     const string relaxedTestName = testName + "Relaxed";
1090 
1091     if (!writeRelaxedRsFile(function, directory, testName, relaxedTestName)) {
1092         return false;
1093     }
1094 
1095     GeneratedFile rsFile;    // The Renderscript test file we're generating.
1096     GeneratedFile javaFile;  // The Jave test file we're generating.
1097     if (!startRsFile(&rsFile, function, directory, testName)) {
1098         return false;
1099     }
1100 
1101     if (!startJavaFile(&javaFile, function, directory, testName, relaxedTestName)) {
1102         return false;
1103     }
1104 
1105     /* We keep track of the allocations generated in the .rs file and the argument classes defined
1106      * in the Java file, as we share these between the functions created for each specification.
1107      */
1108     set<string> rsAllocationsGenerated;
1109     set<string> javaGeneratedArgumentClasses;
1110     // Lines of Java code to invoke the check methods.
1111     vector<string> javaCheckMethods;
1112 
1113     for (auto spec : function.getSpecifications()) {
1114         if (spec->hasTests(versionOfTestFiles)) {
1115             for (auto permutation : spec->getPermutations()) {
1116                 PermutationWriter w(*permutation, &rsFile, &javaFile);
1117                 w.writeRsSection(&rsAllocationsGenerated);
1118                 w.writeJavaSection(&javaGeneratedArgumentClasses);
1119 
1120                 // Store the check method to be called.
1121                 javaCheckMethods.push_back(w.getJavaCheckMethodName());
1122             }
1123         }
1124     }
1125 
1126     finishJavaFile(&javaFile, function, javaCheckMethods);
1127     // There's no work to wrap-up in the .rs file.
1128 
1129     rsFile.close();
1130     javaFile.close();
1131     return true;
1132 }
1133 
generateTestFiles(const string & directory,unsigned int versionOfTestFiles)1134 bool generateTestFiles(const string& directory, unsigned int versionOfTestFiles) {
1135     bool success = true;
1136     for (auto f : systemSpecification.getFunctions()) {
1137         if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) {
1138             success = false;
1139         }
1140     }
1141     return success;
1142 }
1143