1 /*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <iomanip>
18 #include <iostream>
19 #include <cmath>
20 #include <sstream>
21
22 #include "Generator.h"
23 #include "Specification.h"
24 #include "Utilities.h"
25
26 using namespace std;
27
28 // Converts float2 to FLOAT_32 and 2, etc.
convertToRsType(const string & name,string * dataType,char * vectorSize)29 static void convertToRsType(const string& name, string* dataType, char* vectorSize) {
30 string s = name;
31 int last = s.size() - 1;
32 char lastChar = s[last];
33 if (lastChar >= '1' && lastChar <= '4') {
34 s.erase(last);
35 *vectorSize = lastChar;
36 } else {
37 *vectorSize = '1';
38 }
39 dataType->clear();
40 for (int i = 0; i < NUM_TYPES; i++) {
41 if (s == TYPES[i].cType) {
42 *dataType = TYPES[i].rsDataType;
43 break;
44 }
45 }
46 }
47
48 // Returns true if any permutation of the function have tests to b
needTestFiles(const Function & function,unsigned int versionOfTestFiles)49 static bool needTestFiles(const Function& function, unsigned int versionOfTestFiles) {
50 for (auto spec : function.getSpecifications()) {
51 if (spec->hasTests(versionOfTestFiles)) {
52 return true;
53 }
54 }
55 return false;
56 }
57
/* One instance of this class is created for each permutation of a function for which
 * we are generating test code. This instance generates both the script and the Java
 * sections of the test files for this permutation. The class is mostly used to keep track
 * of the various names shared between the script and Java files.
 * WARNING: Because the constructor keeps a reference to the FunctionPermutation, a
 * PermutationWriter must not outlive that FunctionPermutation.
 */
class PermutationWriter {
private:
    FunctionPermutation& mPermutation;

    // Generated identifiers. The constructor builds them from the permutation's name and the
    // types of its parameters and return value.
    string mRsKernelName;                   // Script kernel name, "test..." + parameter types.
    string mJavaArgumentsClassName;         // Java argument class used for scalar checks.
    string mJavaArgumentsNClassName;        // Variant with array fields for vector parameters.
    string mJavaVerifierComputeMethodName;  // "compute..." method of CoreMathVerifier.
    string mJavaVerifierVerifyMethodName;   // "verify..." method of CoreMathVerifier.
    string mJavaCheckMethodName;            // Java "check..." method that runs the scripts.
    string mJavaVerifyMethodName;           // Java "verifyResults..." method.

    // Pointers to the files we are generating. Handy to avoid always passing them in the calls.
    // NOT OWNED.
    GeneratedFile* mRs;
    GeneratedFile* mJava;

    /* Shortcuts to the return parameter and the first input parameter of the function
     * specification.
     */
    const ParameterDefinition* mReturnParam;      // Can be nullptr. NOT OWNED.
    const ParameterDefinition* mFirstInputParam;  // Can be nullptr. NOT OWNED.

    /* All the parameters plus the return param, if present (the return param comes last).
     * Collecting them together simplifies code generation. NOT OWNED.
     */
    vector<const ParameterDefinition*> mAllInputsAndOutputs;

    /* We use a class to pass the arguments between the generated code and the CoreVerifier.
     * This method generates that class. The set records the classes already generated for
     * this test file, because more than one permutation may use the same argument class.
     */
    void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const;

    // Generate the Check* method that invokes the script and calls the verifier.
    void writeJavaCheckMethod(bool generateCallToVerifier) const;

    // Generate code to define and randomly initialize the input allocation.
    void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const;

    /* Generate code that instantiates an allocation of floats or integers and fills it with
     * random data. This random data must be compatible with the specified type. This is
     * used for the convert_* tests, as converting values that don't fit yields undefined
     * results.
     */
    void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed,
                                                  char vectorSize,
                                                  const NumericalType& compatibleType,
                                                  const NumericalType& generatedType) const;
    void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed,
                                                    char vectorSize,
                                                    const NumericalType& compatibleType,
                                                    const NumericalType& generatedType) const;

    // Generate code that defines an output allocation.
    void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const;

    /* Generate the code that verifies the results for RenderScript functions where each entry
     * of a vector is evaluated independently. If verifierValidates is true, CoreMathVerifier
     * does the actual validation instead of the more common case of returning the range of
     * acceptable values.
     */
    void writeJavaVerifyScalarMethod(bool verifierValidates) const;

    /* Generate the code that verifies the results for a RenderScript function where a vector
     * is a point in n-dimensional space.
     */
    void writeJavaVerifyVectorMethod() const;

    // Generate the line that creates the Target.
    void writeJavaCreateTarget() const;

    // Generate the method header of the verify function.
    void writeJavaVerifyMethodHeader() const;

    // Generate code that copies the content of an allocation to an array.
    void writeJavaArrayInitialization(const ParameterDefinition& p) const;

    // Generate code that tests one value returned from the script. writeJavaTestOneValue emits
    // only the "if (<mismatch>)" header; writeJavaTestAndSetValid adds a body that clears
    // the generated "valid" flag.
    void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex,
                                  const string& actualIndex) const;
    void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
                               const string& actualIndex) const;
    // For test:vector cases, generate code that compares returned vector vs. expected value.
    void writeJavaVectorComparison(const ParameterDefinition& p) const;

    // Multiple functions that generate code to build the error message if an error is found.
    void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex,
                                        const string& actualIndex, bool verifierValidates) const;
    void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const;
    void writeJavaAppendNewLineToMessage() const;
    void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const;
    void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const;

    // Generate the set of instructions to call the script.
    void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const;

    // Write an allocation definition if not already emitted in the .rs file.
    void writeRsAllocationDefinition(const ParameterDefinition& param,
                                     set<string>* rsAllocationsGenerated) const;

public:
    /* NOTE: We keep pointers to the permutation and the files. This object should not
     * outlive the arguments.
     */
    PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
                      GeneratedFile* javaFile);
    string getJavaCheckMethodName() const { return mJavaCheckMethodName; }

    // Write the script test function for this permutation.
    void writeRsSection(set<string>* rsAllocationsGenerated) const;
    // Write the section of the Java code that calls the script and validates the results.
    void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const;
};
176
PermutationWriter(FunctionPermutation & permutation,GeneratedFile * rsFile,GeneratedFile * javaFile)177 PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
178 GeneratedFile* javaFile)
179 : mPermutation(permutation),
180 mRs(rsFile),
181 mJava(javaFile),
182 mReturnParam(nullptr),
183 mFirstInputParam(nullptr) {
184 mRsKernelName = "test" + capitalize(permutation.getName());
185
186 mJavaArgumentsClassName = "Arguments";
187 mJavaArgumentsNClassName = "Arguments";
188 const string trunk = capitalize(permutation.getNameTrunk());
189 mJavaCheckMethodName = "check" + trunk;
190 mJavaVerifyMethodName = "verifyResults" + trunk;
191
192 for (auto p : permutation.getParams()) {
193 mAllInputsAndOutputs.push_back(p);
194 if (mFirstInputParam == nullptr && !p->isOutParameter) {
195 mFirstInputParam = p;
196 }
197 }
198 mReturnParam = permutation.getReturn();
199 if (mReturnParam) {
200 mAllInputsAndOutputs.push_back(mReturnParam);
201 }
202
203 for (auto p : mAllInputsAndOutputs) {
204 const string capitalizedRsType = capitalize(p->rsType);
205 const string capitalizedBaseType = capitalize(p->rsBaseType);
206 mRsKernelName += capitalizedRsType;
207 mJavaArgumentsClassName += capitalizedBaseType;
208 mJavaArgumentsNClassName += capitalizedBaseType;
209 if (p->mVectorSize != "1") {
210 mJavaArgumentsNClassName += "N";
211 }
212 mJavaCheckMethodName += capitalizedRsType;
213 mJavaVerifyMethodName += capitalizedRsType;
214 }
215 mJavaVerifierComputeMethodName = "compute" + trunk;
216 mJavaVerifierVerifyMethodName = "verify" + trunk;
217 }
218
writeJavaSection(set<string> * javaGeneratedArgumentClasses) const219 void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const {
220 // By default, we test the results using item by item comparison.
221 const string test = mPermutation.getTest();
222 if (test == "scalar" || test == "limited") {
223 writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
224 writeJavaCheckMethod(true);
225 writeJavaVerifyScalarMethod(false);
226 } else if (test == "custom") {
227 writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
228 writeJavaCheckMethod(true);
229 writeJavaVerifyScalarMethod(true);
230 } else if (test == "vector") {
231 writeJavaArgumentClass(false, javaGeneratedArgumentClasses);
232 writeJavaCheckMethod(true);
233 writeJavaVerifyVectorMethod();
234 } else if (test == "noverify") {
235 writeJavaCheckMethod(false);
236 }
237 }
238
writeJavaArgumentClass(bool scalar,set<string> * javaGeneratedArgumentClasses) const239 void PermutationWriter::writeJavaArgumentClass(bool scalar,
240 set<string>* javaGeneratedArgumentClasses) const {
241 string name;
242 if (scalar) {
243 name = mJavaArgumentsClassName;
244 } else {
245 name = mJavaArgumentsNClassName;
246 }
247
248 // Make sure we have not generated the argument class already.
249 if (!testAndSet(name, javaGeneratedArgumentClasses)) {
250 mJava->indent() << "public class " << name;
251 mJava->startBlock();
252
253 for (auto p : mAllInputsAndOutputs) {
254 bool isFieldArray = !scalar && p->mVectorSize != "1";
255 bool isFloatyField = p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom";
256
257 mJava->indent() << "public ";
258 if (isFloatyField) {
259 *mJava << "Target.Floaty";
260 } else {
261 *mJava << p->javaBaseType;
262 }
263 if (isFieldArray) {
264 *mJava << "[]";
265 }
266 *mJava << " " << p->variableName << ";\n";
267
268 // For Float16 parameters, add an extra 'double' field in the class
269 // to hold the Double value converted from the input.
270 if (p->isFloat16Parameter() && !isFloatyField) {
271 mJava->indent() << "public double";
272 if (isFieldArray) {
273 *mJava << "[]";
274 }
275 *mJava << " " + p->variableName << "Double;\n";
276 }
277 }
278 mJava->endBlock();
279 *mJava << "\n";
280 }
281 }
282
writeJavaCheckMethod(bool generateCallToVerifier) const283 void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const {
284 mJava->indent() << "private void " << mJavaCheckMethodName << "()";
285 mJava->startBlock();
286
287 // Generate the input allocations and initialization.
288 for (auto p : mAllInputsAndOutputs) {
289 if (!p->isOutParameter) {
290 writeJavaInputAllocationDefinition(*p);
291 }
292 }
293 // Generate code to enforce ordering between two allocations if needed.
294 for (auto p : mAllInputsAndOutputs) {
295 if (!p->isOutParameter && !p->smallerParameter.empty()) {
296 string smallerAlloc = "in" + capitalize(p->smallerParameter);
297 mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName
298 << ");\n";
299 }
300 }
301
302 // Generate code to check the full and relaxed scripts.
303 writeJavaCallToRs(false, generateCallToVerifier);
304 writeJavaCallToRs(true, generateCallToVerifier);
305
306 mJava->endBlock();
307 *mJava << "\n";
308 }
309
writeJavaInputAllocationDefinition(const ParameterDefinition & param) const310 void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const {
311 string dataType;
312 char vectorSize;
313 convertToRsType(param.rsType, &dataType, &vectorSize);
314
315 const string seed = hashString(mJavaCheckMethodName + param.javaAllocName);
316 mJava->indent() << "Allocation " << param.javaAllocName << " = ";
317 if (param.compatibleTypeIndex >= 0) {
318 if (TYPES[param.typeIndex].kind == FLOATING_POINT) {
319 writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize,
320 TYPES[param.compatibleTypeIndex],
321 TYPES[param.typeIndex]);
322 } else {
323 writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize,
324 TYPES[param.compatibleTypeIndex],
325 TYPES[param.typeIndex]);
326 }
327 } else if (!param.minValue.empty()) {
328 *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", "
329 << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue
330 << ")";
331 } else {
332 /* TODO Instead of passing always false, check whether we are doing a limited test.
333 * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true")
334 */
335 *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize
336 << ", " << seed << ", false)";
337 }
338 *mJava << ";\n";
339 }
340
writeJavaRandomCompatibleFloatAllocation(const string & dataType,const string & seed,char vectorSize,const NumericalType & compatibleType,const NumericalType & generatedType) const341 void PermutationWriter::writeJavaRandomCompatibleFloatAllocation(
342 const string& dataType, const string& seed, char vectorSize,
343 const NumericalType& compatibleType, const NumericalType& generatedType) const {
344 *mJava << "createRandomFloatAllocation"
345 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
346 double minValue = 0.0;
347 double maxValue = 0.0;
348 switch (compatibleType.kind) {
349 case FLOATING_POINT: {
350 // We're generating floating point values. We just worry about the exponent.
351 // Subtract 1 for the exponent sign.
352 int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1;
353 maxValue = ldexp(0.95, (1 << bits) - 1);
354 minValue = -maxValue;
355 break;
356 }
357 case UNSIGNED_INTEGER:
358 maxValue = maxDoubleForInteger(compatibleType.significantBits,
359 generatedType.significantBits);
360 minValue = 0.0;
361 break;
362 case SIGNED_INTEGER:
363 maxValue = maxDoubleForInteger(compatibleType.significantBits,
364 generatedType.significantBits);
365 minValue = -maxValue - 1.0;
366 break;
367 }
368 *mJava << scientific << std::setprecision(19);
369 *mJava << minValue << ", " << maxValue << ")";
370 mJava->unsetf(ios_base::floatfield);
371 }
372
writeJavaRandomCompatibleIntegerAllocation(const string & dataType,const string & seed,char vectorSize,const NumericalType & compatibleType,const NumericalType & generatedType) const373 void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation(
374 const string& dataType, const string& seed, char vectorSize,
375 const NumericalType& compatibleType, const NumericalType& generatedType) const {
376 *mJava << "createRandomIntegerAllocation"
377 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
378
379 if (compatibleType.kind == FLOATING_POINT) {
380 // Currently, all floating points can take any number we generate.
381 bool isSigned = generatedType.kind == SIGNED_INTEGER;
382 *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits;
383 } else {
384 bool isSigned =
385 compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER;
386 *mJava << (isSigned ? "true" : "false") << ", "
387 << min(compatibleType.significantBits, generatedType.significantBits);
388 }
389 *mJava << ")";
390 }
391
writeJavaOutputAllocationDefinition(const ParameterDefinition & param) const392 void PermutationWriter::writeJavaOutputAllocationDefinition(
393 const ParameterDefinition& param) const {
394 string dataType;
395 char vectorSize;
396 convertToRsType(param.rsType, &dataType, &vectorSize);
397 mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, "
398 << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize
399 << "), INPUTSIZE);\n";
400 }
401
/* Writes the Java method that verifies results for functions where each element of a vector
 * is computed independently: element j of every entry i is checked on its own.
 * If verifierValidates is true (the "custom" test kind), the outputs are stored in args and
 * passed to CoreMathVerifier's verify method, whose String result is null when the values
 * are valid. Otherwise CoreMathVerifier's compute method fills args with the expected
 * values, and each output is compared against them.
 * On failure, the inputs and outputs of the first mismatch are added to the message, the
 * [i, j] position of every mismatch is appended, and the generated method asserts.
 */
void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const {
    writeJavaVerifyMethodHeader();
    mJava->startBlock();

    // Copy every allocation to a Java array, and find the vector size to iterate over.
    // All vector parameters are expected to have the same size; a mismatch is reported.
    string vectorSize = "1";
    for (auto p : mAllInputsAndOutputs) {
        writeJavaArrayInitialization(*p);
        if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) {
            if (vectorSize == "1") {
                vectorSize = p->mVectorSize;
            } else {
                cerr << "Error. Had vector " << vectorSize << " and " << p->mVectorSize << "\n";
            }
        }
    }

    mJava->indent() << "StringBuilder message = new StringBuilder();\n";
    mJava->indent() << "boolean errorFound = false;\n";
    mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
    mJava->startBlock();

    mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)";
    mJava->startBlock();

    mJava->indent() << "// Extract the inputs.\n";
    mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName
                    << "();\n";
    for (auto p : mAllInputsAndOutputs) {
        if (!p->isOutParameter) {
            // Scalar inputs are read once per entry i. Vector inputs are read per element,
            // using vectorWidth as the number of array slots per entry.
            mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i";
            if (p->vectorWidth != "1") {
                *mJava << " * " << p->vectorWidth << " + j";
            }
            *mJava << "];\n";

            // Convert the Float16 parameter to double and store it in the appropriate field in the
            // Arguments class.
            if (p->isFloat16Parameter()) {
                mJava->indent() << "args." << p->doubleVariableName
                                << " = Float16Utils.convertFloat16ToDouble(args."
                                << p->variableName << ");\n";
            }
        }
    }
    // A Target is only created (and passed to the verifier) when there are float answers.
    const bool hasFloat = mPermutation.hasFloatAnswers();
    if (verifierValidates) {
        // The verifier needs the actual outputs, so copy them into args too.
        mJava->indent() << "// Extract the outputs.\n";
        for (auto p : mAllInputsAndOutputs) {
            if (p->isOutParameter) {
                mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName
                                << "[i * " << p->vectorWidth << " + j];\n";
                if (p->isFloat16Parameter()) {
                    mJava->indent() << "args." << p->doubleVariableName
                                    << " = Float16Utils.convertFloat16ToDouble(args."
                                    << p->variableName << ");\n";
                }
            }
        }
        mJava->indent() << "// Ask the CoreMathVerifier to validate.\n";
        if (hasFloat) {
            writeJavaCreateTarget();
        }
        mJava->indent() << "String errorMessage = CoreMathVerifier."
                        << mJavaVerifierVerifyMethodName << "(args";
        if (hasFloat) {
            *mJava << ", target";
        }
        *mJava << ");\n";
        mJava->indent() << "boolean valid = errorMessage == null;\n";
    } else {
        // The verifier fills args with the expected outputs; compare them to the actual ones.
        mJava->indent() << "// Figure out what the outputs should have been.\n";
        if (hasFloat) {
            writeJavaCreateTarget();
        }
        mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args";
        if (hasFloat) {
            *mJava << ", target";
        }
        *mJava << ");\n";
        mJava->indent() << "// Validate the outputs.\n";
        mJava->indent() << "boolean valid = true;\n";
        for (auto p : mAllInputsAndOutputs) {
            if (p->isOutParameter) {
                writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]");
            }
        }
    }

    // Describe only the first failing element in detail; list the position of every failure.
    mJava->indent() << "if (!valid)";
    mJava->startBlock();
    mJava->indent() << "if (!errorFound)";
    mJava->startBlock();
    mJava->indent() << "errorFound = true;\n";

    for (auto p : mAllInputsAndOutputs) {
        if (p->isOutParameter) {
            writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]",
                                           verifierValidates);
        } else {
            writeJavaAppendInputToMessage(*p, "args." + p->variableName);
        }
    }
    if (verifierValidates) {
        mJava->indent() << "message.append(errorMessage);\n";
    }
    mJava->indent() << "message.append(\"Errors at\");\n";
    mJava->endBlock();

    mJava->indent() << "message.append(\" [\");\n";
    mJava->indent() << "message.append(Integer.toString(i));\n";
    mJava->indent() << "message.append(\", \");\n";
    mJava->indent() << "message.append(Integer.toString(j));\n";
    mJava->indent() << "message.append(\"]\");\n";

    // Close the "if (!valid)" block and the j and i loops.
    mJava->endBlock();
    mJava->endBlock();
    mJava->endBlock();

    mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
    mJava->indentPlus()
            << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";

    mJava->endBlock();
    *mJava << "\n";
}
527
writeJavaVerifyVectorMethod() const528 void PermutationWriter::writeJavaVerifyVectorMethod() const {
529 writeJavaVerifyMethodHeader();
530 mJava->startBlock();
531
532 for (auto p : mAllInputsAndOutputs) {
533 writeJavaArrayInitialization(*p);
534 }
535 mJava->indent() << "StringBuilder message = new StringBuilder();\n";
536 mJava->indent() << "boolean errorFound = false;\n";
537 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
538 mJava->startBlock();
539
540 mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName
541 << "();\n";
542
543 mJava->indent() << "// Create the appropriate sized arrays in args\n";
544 for (auto p : mAllInputsAndOutputs) {
545 if (p->mVectorSize != "1") {
546 string type = p->javaBaseType;
547 if (p->isOutParameter && p->isFloatType) {
548 type = "Target.Floaty";
549 }
550 mJava->indent() << "args." << p->variableName << " = new " << type << "["
551 << p->mVectorSize << "];\n";
552 if (p->isFloat16Parameter() && !p->isOutParameter) {
553 mJava->indent() << "args." << p->variableName << "Double = new double["
554 << p->mVectorSize << "];\n";
555 }
556 }
557 }
558
559 mJava->indent() << "// Fill args with the input values\n";
560 for (auto p : mAllInputsAndOutputs) {
561 if (!p->isOutParameter) {
562 if (p->mVectorSize == "1") {
563 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]"
564 << ";\n";
565 // Convert the Float16 parameter to double and store it in the appropriate field in
566 // the Arguments class.
567 if (p->isFloat16Parameter()) {
568 mJava->indent() << "args." << p->doubleVariableName << " = "
569 << "Float16Utils.convertFloat16ToDouble(args."
570 << p->variableName << ");\n";
571 }
572 } else {
573 mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)";
574 mJava->startBlock();
575 mJava->indent() << "args." << p->variableName << "[j] = "
576 << p->javaArrayName << "[i * " << p->vectorWidth << " + j]"
577 << ";\n";
578
579 // Convert the Float16 parameter to double and store it in the appropriate field in
580 // the Arguments class.
581 if (p->isFloat16Parameter()) {
582 mJava->indent() << "args." << p->doubleVariableName << "[j] = "
583 << "Float16Utils.convertFloat16ToDouble(args."
584 << p->variableName << "[j]);\n";
585 }
586 mJava->endBlock();
587 }
588 }
589 }
590 writeJavaCreateTarget();
591 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName
592 << "(args, target);\n\n";
593
594 mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n";
595 mJava->indent() << "boolean valid = true;\n";
596 for (auto p : mAllInputsAndOutputs) {
597 if (p->isOutParameter) {
598 writeJavaVectorComparison(*p);
599 }
600 }
601
602 mJava->indent() << "if (!valid)";
603 mJava->startBlock();
604 mJava->indent() << "if (!errorFound)";
605 mJava->startBlock();
606 mJava->indent() << "errorFound = true;\n";
607
608 for (auto p : mAllInputsAndOutputs) {
609 if (p->isOutParameter) {
610 writeJavaAppendVectorOutputToMessage(*p);
611 } else {
612 writeJavaAppendVectorInputToMessage(*p);
613 }
614 }
615 mJava->indent() << "message.append(\"Errors at\");\n";
616 mJava->endBlock();
617
618 mJava->indent() << "message.append(\" [\");\n";
619 mJava->indent() << "message.append(Integer.toString(i));\n";
620 mJava->indent() << "message.append(\"]\");\n";
621
622 mJava->endBlock();
623 mJava->endBlock();
624
625 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
626 mJava->indentPlus()
627 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";
628
629 mJava->endBlock();
630 *mJava << "\n";
631 }
632
633
writeJavaCreateTarget() const634 void PermutationWriter::writeJavaCreateTarget() const {
635 string name = mPermutation.getName();
636
637 const char* functionType = "NORMAL";
638 size_t end = name.find('_');
639 if (end != string::npos) {
640 if (name.compare(0, end, "native") == 0) {
641 functionType = "NATIVE";
642 } else if (name.compare(0, end, "half") == 0) {
643 functionType = "HALF";
644 } else if (name.compare(0, end, "fast") == 0) {
645 functionType = "FAST";
646 }
647 }
648
649 string floatType = mReturnParam->specType;
650 const char* precisionStr = "";
651 if (floatType.compare("f16") == 0) {
652 precisionStr = "HALF";
653 } else if (floatType.compare("f32") == 0) {
654 precisionStr = "FLOAT";
655 } else if (floatType.compare("f64") == 0) {
656 precisionStr = "DOUBLE";
657 } else {
658 cerr << "Error. Unreachable. Return type is not floating point\n";
659 }
660
661 mJava->indent() << "Target target = new Target(Target.FunctionType." <<
662 functionType << ", Target.ReturnType." << precisionStr <<
663 ", relaxed);\n";
664 }
665
writeJavaVerifyMethodHeader() const666 void PermutationWriter::writeJavaVerifyMethodHeader() const {
667 mJava->indent() << "private void " << mJavaVerifyMethodName << "(";
668 for (auto p : mAllInputsAndOutputs) {
669 *mJava << "Allocation " << p->javaAllocName << ", ";
670 }
671 *mJava << "boolean relaxed)";
672 }
673
writeJavaArrayInitialization(const ParameterDefinition & p) const674 void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const {
675 mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType
676 << "[INPUTSIZE * " << p.vectorWidth << "];\n";
677
678 /* For basic types, populate the array with values, to help understand failures. We have had
679 * bugs where the output buffer was all 0. We were not sure if there was a failed copy or
680 * the GPU driver was copying zeroes.
681 */
682 if (p.typeIndex >= 0) {
683 mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType
684 << ") 42);\n";
685 }
686
687 mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n";
688 }
689
writeJavaTestAndSetValid(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex) const690 void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p,
691 const string& argsIndex,
692 const string& actualIndex) const {
693 writeJavaTestOneValue(p, argsIndex, actualIndex);
694 mJava->startBlock();
695 mJava->indent() << "valid = false;\n";
696 mJava->endBlock();
697 }
698
writeJavaTestOneValue(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex) const699 void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
700 const string& actualIndex) const {
701 string actualOut;
702 if (p.isFloat16Parameter()) {
703 // For Float16 values, the output needs to be converted to Double.
704 actualOut = "Float16Utils.convertFloat16ToDouble(" + p.javaArrayName + actualIndex + ")";
705 } else {
706 actualOut = p.javaArrayName + actualIndex;
707 }
708
709 mJava->indent() << "if (";
710 if (p.isFloatType) {
711 *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << actualOut;
712 const string s = mPermutation.getPrecisionLimit();
713 if (!s.empty()) {
714 *mJava << ", " << s;
715 }
716 *mJava << ")";
717 } else {
718 *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName
719 << actualIndex;
720 }
721
722 if (p.undefinedIfOutIsNan && mReturnParam) {
723 *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()";
724 }
725 *mJava << ")";
726 }
727
writeJavaVectorComparison(const ParameterDefinition & p) const728 void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const {
729 if (p.mVectorSize == "1") {
730 writeJavaTestAndSetValid(p, "", "[i]");
731 } else {
732 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
733 mJava->startBlock();
734 writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]");
735 mJava->endBlock();
736 }
737 }
738
writeJavaAppendOutputToMessage(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex,bool verifierValidates) const739 void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p,
740 const string& argsIndex,
741 const string& actualIndex,
742 bool verifierValidates) const {
743 if (verifierValidates) {
744 mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n";
745 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
746 << ");\n";
747 writeJavaAppendNewLineToMessage();
748 if (p.isFloat16Parameter()) {
749 writeJavaAppendNewLineToMessage();
750 mJava->indent() << "message.append(\"Output " << p.variableName
751 << " (in double): \");\n";
752 mJava->indent() << "appendVariableToMessage(message, args." << p.doubleVariableName
753 << ");\n";
754 writeJavaAppendNewLineToMessage();
755 }
756 } else {
757 mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n";
758 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
759 << ");\n";
760 writeJavaAppendNewLineToMessage();
761
762 mJava->indent() << "message.append(\"Actual output " << p.variableName << ": \");\n";
763 mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex
764 << ");\n";
765
766 if (p.isFloat16Parameter()) {
767 writeJavaAppendNewLineToMessage();
768 mJava->indent() << "message.append(\"Actual output " << p.variableName
769 << " (in double): \");\n";
770 mJava->indent() << "appendVariableToMessage(message, Float16Utils.convertFloat16ToDouble("
771 << p.javaArrayName << actualIndex << "));\n";
772 }
773
774 writeJavaTestOneValue(p, argsIndex, actualIndex);
775 mJava->startBlock();
776 mJava->indent() << "message.append(\" FAIL\");\n";
777 mJava->endBlock();
778 writeJavaAppendNewLineToMessage();
779 }
780 }
781
writeJavaAppendInputToMessage(const ParameterDefinition & p,const string & actual) const782 void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p,
783 const string& actual) const {
784 mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n";
785 mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n";
786 writeJavaAppendNewLineToMessage();
787 }
788
writeJavaAppendNewLineToMessage() const789 void PermutationWriter::writeJavaAppendNewLineToMessage() const {
790 mJava->indent() << "message.append(\"\\n\");\n";
791 }
792
writeJavaAppendVectorInputToMessage(const ParameterDefinition & p) const793 void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const {
794 if (p.mVectorSize == "1") {
795 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]");
796 } else {
797 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
798 mJava->startBlock();
799 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]");
800 mJava->endBlock();
801 }
802 }
803
writeJavaAppendVectorOutputToMessage(const ParameterDefinition & p) const804 void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const {
805 if (p.mVectorSize == "1") {
806 writeJavaAppendOutputToMessage(p, "", "[i]", false);
807 } else {
808 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
809 mJava->startBlock();
810 writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false);
811 mJava->endBlock();
812 }
813 }
814
writeJavaCallToRs(bool relaxed,bool generateCallToVerifier) const815 void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const {
816 string script = "script";
817 if (relaxed) {
818 script += "Relaxed";
819 }
820
821 mJava->indent() << "try";
822 mJava->startBlock();
823
824 for (auto p : mAllInputsAndOutputs) {
825 if (p->isOutParameter) {
826 writeJavaOutputAllocationDefinition(*p);
827 }
828 }
829
830 for (auto p : mPermutation.getParams()) {
831 if (p != mFirstInputParam) {
832 mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName
833 << ");\n";
834 }
835 }
836
837 mJava->indent() << script << ".forEach_" << mRsKernelName << "(";
838 bool needComma = false;
839 if (mFirstInputParam) {
840 *mJava << mFirstInputParam->javaAllocName;
841 needComma = true;
842 }
843 if (mReturnParam) {
844 if (needComma) {
845 *mJava << ", ";
846 }
847 *mJava << mReturnParam->variableName << ");\n";
848 }
849
850 if (generateCallToVerifier) {
851 mJava->indent() << mJavaVerifyMethodName << "(";
852 for (auto p : mAllInputsAndOutputs) {
853 *mJava << p->variableName << ", ";
854 }
855
856 if (relaxed) {
857 *mJava << "true";
858 } else {
859 *mJava << "false";
860 }
861 *mJava << ");\n";
862 }
863 mJava->decreaseIndent();
864 mJava->indent() << "} catch (Exception e) {\n";
865 mJava->increaseIndent();
866 mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_"
867 << mRsKernelName << ": \" + e.toString());\n";
868 mJava->endBlock();
869 }
870
871 /* Write the section of the .rs file for this permutation.
872 *
873 * We communicate the extra input and output parameters via global allocations.
874 * For example, if we have a function that takes three arguments, two for input
875 * and one for output:
876 *
877 * start:
878 * name: gamn
879 * ret: float3
880 * arg: float3 a
881 * arg: int b
882 * arg: float3 *c
883 * end:
884 *
885 * We'll produce:
886 *
887 * rs_allocation gAllocInB;
888 * rs_allocation gAllocOutC;
889 *
890 * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) {
891 * int inB;
892 * float3 outC;
893 * float2 out;
894 * inB = rsGetElementAt_int(gAllocInB, x);
895 * out = gamn(a, in_b, &outC);
896 * rsSetElementAt_float4(gAllocOutC, &outC, x);
897 * return out;
898 * }
899 *
900 * We avoid re-using x and y from the definition because these have reserved
901 * meanings in a .rs file.
902 */
writeRsSection(set<string> * rsAllocationsGenerated) const903 void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const {
904 // Write the allocation declarations we'll need.
905 for (auto p : mPermutation.getParams()) {
906 // Don't need allocation for one input and one return value.
907 if (p != mFirstInputParam) {
908 writeRsAllocationDefinition(*p, rsAllocationsGenerated);
909 }
910 }
911 *mRs << "\n";
912
913 // Write the function header.
914 if (mReturnParam) {
915 *mRs << mReturnParam->rsType;
916 } else {
917 *mRs << "void";
918 }
919 *mRs << " __attribute__((kernel)) " << mRsKernelName;
920 *mRs << "(";
921 bool needComma = false;
922 if (mFirstInputParam) {
923 *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName;
924 needComma = true;
925 }
926 if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) {
927 if (needComma) {
928 *mRs << ", ";
929 }
930 *mRs << "unsigned int x";
931 }
932 *mRs << ")";
933 mRs->startBlock();
934
935 // Write the local variable declarations and initializations.
936 for (auto p : mPermutation.getParams()) {
937 if (p == mFirstInputParam) {
938 continue;
939 }
940 mRs->indent() << p->rsType << " " << p->variableName;
941 if (p->isOutParameter) {
942 *mRs << " = 0;\n";
943 } else {
944 *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n";
945 }
946 }
947
948 // Write the function call.
949 if (mReturnParam) {
950 if (mPermutation.getOutputCount() > 1) {
951 mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = ";
952 } else {
953 mRs->indent() << "return ";
954 }
955 }
956 *mRs << mPermutation.getName() << "(";
957 needComma = false;
958 for (auto p : mPermutation.getParams()) {
959 if (needComma) {
960 *mRs << ", ";
961 }
962 if (p->isOutParameter) {
963 *mRs << "&";
964 }
965 *mRs << p->variableName;
966 needComma = true;
967 }
968 *mRs << ");\n";
969
970 if (mPermutation.getOutputCount() > 1) {
971 // Write setting the extra out parameters into the allocations.
972 for (auto p : mPermutation.getParams()) {
973 if (p->isOutParameter) {
974 mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", ";
975 // Check if we need to use '&' for this type of argument.
976 char lastChar = p->variableName.back();
977 if (lastChar >= '0' && lastChar <= '9') {
978 *mRs << "&";
979 }
980 *mRs << p->variableName << ", x);\n";
981 }
982 }
983 if (mReturnParam) {
984 mRs->indent() << "return " << mReturnParam->variableName << ";\n";
985 }
986 }
987 mRs->endBlock();
988 }
989
writeRsAllocationDefinition(const ParameterDefinition & param,set<string> * rsAllocationsGenerated) const990 void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param,
991 set<string>* rsAllocationsGenerated) const {
992 if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) {
993 *mRs << "rs_allocation " << param.rsAllocName << ";\n";
994 }
995 }
996
997 // Open the mJavaFile and writes the header.
startJavaFile(GeneratedFile * file,const Function & function,const string & directory,const string & testName,const string & relaxedTestName)998 static bool startJavaFile(GeneratedFile* file, const Function& function, const string& directory,
999 const string& testName, const string& relaxedTestName) {
1000 const string fileName = testName + ".java";
1001 if (!file->start(directory, fileName)) {
1002 return false;
1003 }
1004 file->writeNotices();
1005
1006 *file << "package android.renderscript.cts;\n\n";
1007
1008 *file << "import android.renderscript.Allocation;\n";
1009 *file << "import android.renderscript.RSRuntimeException;\n";
1010 *file << "import android.renderscript.Element;\n";
1011 *file << "import android.renderscript.cts.Target;\n\n";
1012 *file << "import java.util.Arrays;\n\n";
1013
1014 *file << "public class " << testName << " extends RSBaseCompute";
1015 file->startBlock(); // The corresponding endBlock() is in finishJavaFile()
1016 *file << "\n";
1017
1018 file->indent() << "private ScriptC_" << testName << " script;\n";
1019 file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n";
1020
1021 file->indent() << "@Override\n";
1022 file->indent() << "protected void setUp() throws Exception";
1023 file->startBlock();
1024
1025 file->indent() << "super.setUp();\n";
1026 file->indent() << "script = new ScriptC_" << testName << "(mRS);\n";
1027 file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n";
1028
1029 file->endBlock();
1030 *file << "\n";
1031 return true;
1032 }
1033
1034 // Write the test method that calls all the generated Check methods.
finishJavaFile(GeneratedFile * file,const Function & function,const vector<string> & javaCheckMethods)1035 static void finishJavaFile(GeneratedFile* file, const Function& function,
1036 const vector<string>& javaCheckMethods) {
1037 file->indent() << "public void test" << function.getCapitalizedName() << "()";
1038 file->startBlock();
1039 for (auto m : javaCheckMethods) {
1040 file->indent() << m << "();\n";
1041 }
1042 file->endBlock();
1043
1044 file->endBlock();
1045 }
1046
1047 // Open the script file and write its header.
startRsFile(GeneratedFile * file,const Function & function,const string & directory,const string & testName)1048 static bool startRsFile(GeneratedFile* file, const Function& function, const string& directory,
1049 const string& testName) {
1050 string fileName = testName + ".rs";
1051 if (!file->start(directory, fileName)) {
1052 return false;
1053 }
1054 file->writeNotices();
1055
1056 *file << "#pragma version(1)\n";
1057 *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n";
1058 return true;
1059 }
1060
1061 // Write the entire *Relaxed.rs test file, as it only depends on the name.
writeRelaxedRsFile(const Function & function,const string & directory,const string & testName,const string & relaxedTestName)1062 static bool writeRelaxedRsFile(const Function& function, const string& directory,
1063 const string& testName, const string& relaxedTestName) {
1064 string name = relaxedTestName + ".rs";
1065
1066 GeneratedFile file;
1067 if (!file.start(directory, name)) {
1068 return false;
1069 }
1070 file.writeNotices();
1071
1072 file << "#include \"" << testName << ".rs\"\n";
1073 file << "#pragma rs_fp_relaxed\n";
1074 file.close();
1075 return true;
1076 }
1077
1078 /* Write the .java and the two .rs test files. versionOfTestFiles is used to restrict which API
1079 * to test.
1080 */
writeTestFilesForFunction(const Function & function,const string & directory,unsigned int versionOfTestFiles)1081 static bool writeTestFilesForFunction(const Function& function, const string& directory,
1082 unsigned int versionOfTestFiles) {
1083 // Avoid creating empty files if we're not testing this function.
1084 if (!needTestFiles(function, versionOfTestFiles)) {
1085 return true;
1086 }
1087
1088 const string testName = "Test" + function.getCapitalizedName();
1089 const string relaxedTestName = testName + "Relaxed";
1090
1091 if (!writeRelaxedRsFile(function, directory, testName, relaxedTestName)) {
1092 return false;
1093 }
1094
1095 GeneratedFile rsFile; // The Renderscript test file we're generating.
1096 GeneratedFile javaFile; // The Jave test file we're generating.
1097 if (!startRsFile(&rsFile, function, directory, testName)) {
1098 return false;
1099 }
1100
1101 if (!startJavaFile(&javaFile, function, directory, testName, relaxedTestName)) {
1102 return false;
1103 }
1104
1105 /* We keep track of the allocations generated in the .rs file and the argument classes defined
1106 * in the Java file, as we share these between the functions created for each specification.
1107 */
1108 set<string> rsAllocationsGenerated;
1109 set<string> javaGeneratedArgumentClasses;
1110 // Lines of Java code to invoke the check methods.
1111 vector<string> javaCheckMethods;
1112
1113 for (auto spec : function.getSpecifications()) {
1114 if (spec->hasTests(versionOfTestFiles)) {
1115 for (auto permutation : spec->getPermutations()) {
1116 PermutationWriter w(*permutation, &rsFile, &javaFile);
1117 w.writeRsSection(&rsAllocationsGenerated);
1118 w.writeJavaSection(&javaGeneratedArgumentClasses);
1119
1120 // Store the check method to be called.
1121 javaCheckMethods.push_back(w.getJavaCheckMethodName());
1122 }
1123 }
1124 }
1125
1126 finishJavaFile(&javaFile, function, javaCheckMethods);
1127 // There's no work to wrap-up in the .rs file.
1128
1129 rsFile.close();
1130 javaFile.close();
1131 return true;
1132 }
1133
generateTestFiles(const string & directory,unsigned int versionOfTestFiles)1134 bool generateTestFiles(const string& directory, unsigned int versionOfTestFiles) {
1135 bool success = true;
1136 for (auto f : systemSpecification.getFunctions()) {
1137 if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) {
1138 success = false;
1139 }
1140 }
1141 return success;
1142 }
1143