1 /*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <iomanip>
18 #include <iostream>
19 #include <cmath>
20 #include <sstream>
21
22 #include "Generator.h"
23 #include "Specification.h"
24 #include "Utilities.h"
25
26 using namespace std;
27
28 // Converts float2 to FLOAT_32 and 2, etc.
convertToRsType(const string & name,string * dataType,char * vectorSize)29 static void convertToRsType(const string& name, string* dataType, char* vectorSize) {
30 string s = name;
31 int last = s.size() - 1;
32 char lastChar = s[last];
33 if (lastChar >= '1' && lastChar <= '4') {
34 s.erase(last);
35 *vectorSize = lastChar;
36 } else {
37 *vectorSize = '1';
38 }
39 dataType->clear();
40 for (int i = 0; i < NUM_TYPES; i++) {
41 if (s == TYPES[i].cType) {
42 *dataType = TYPES[i].rsDataType;
43 break;
44 }
45 }
46 }
47
48 // Returns true if any permutation of the function have tests to b
needTestFiles(const Function & function,unsigned int versionOfTestFiles)49 static bool needTestFiles(const Function& function, unsigned int versionOfTestFiles) {
50 for (auto spec : function.getSpecifications()) {
51 if (spec->hasTests(versionOfTestFiles)) {
52 return true;
53 }
54 }
55 return false;
56 }
57
58 /* One instance of this class is generated for each permutation of a function for which
59 * we are generating test code. This instance will generate both the script and the Java
60 * section of the test files for this permutation. The class is mostly used to keep track
61 * of the various names shared between script and Java files.
62 * WARNING: Because the constructor keeps a reference to the FunctionPermutation, PermutationWriter
63 * should not exceed the lifetime of FunctionPermutation.
64 */
65 class PermutationWriter {
66 private:
67 FunctionPermutation& mPermutation;
68
69 string mRsKernelName;
70 string mJavaArgumentsClassName;
71 string mJavaArgumentsNClassName;
72 string mJavaVerifierComputeMethodName;
73 string mJavaVerifierVerifyMethodName;
74 string mJavaCheckMethodName;
75 string mJavaVerifyMethodName;
76
77 // Pointer to the files we are generating. Handy to avoid always passing them in the calls.
78 GeneratedFile* mRs;
79 GeneratedFile* mJava;
80
81 /* Shortcuts to the return parameter and the first input parameter of the function
82 * specification.
83 */
84 const ParameterDefinition* mReturnParam; // Can be nullptr. NOT OWNED.
85 const ParameterDefinition* mFirstInputParam; // Can be nullptr. NOT OWNED.
86
87 /* All the parameters plus the return param, if present. Collecting them together
88 * simplifies code generation. NOT OWNED.
89 */
90 vector<const ParameterDefinition*> mAllInputsAndOutputs;
91
92 /* We use a class to pass the arguments between the generated code and the CoreVerifier. This
93 * method generates this class. The set keeps track if we've generated this class already
94 * for this test file, as more than one permutation may use the same argument class.
95 */
96 void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const;
97
98 // Generate the Check* method that invokes the script and calls the verifier.
99 void writeJavaCheckMethod(bool generateCallToVerifier) const;
100
101 // Generate code to define and randomly initialize the input allocation.
102 void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const;
103
104 /* Generate code that instantiate an allocation of floats or integers and fills it with
105 * random data. This random data must be compatible with the specified type. This is
106 * used for the convert_* tests, as converting values that don't fit yield undefined results.
107 */
108 void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed,
109 char vectorSize,
110 const NumericalType& compatibleType,
111 const NumericalType& generatedType) const;
112 void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed,
113 char vectorSize,
114 const NumericalType& compatibleType,
115 const NumericalType& generatedType) const;
116
117 // Generate code that defines an output allocation.
118 void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const;
119
120 /* Generate the code that verifies the results for RenderScript functions where each entry
121 * of a vector is evaluated independently. If verifierValidates is true, CoreMathVerifier
122 * does the actual validation instead of more commonly returning the range of acceptable values.
123 */
124 void writeJavaVerifyScalarMethod(bool verifierValidates) const;
125
126 /* Generate the code that verify the results for a RenderScript function where a vector
127 * is a point in n-dimensional space.
128 */
129 void writeJavaVerifyVectorMethod() const;
130
131 // Generate the line that creates the Target.
132 void writeJavaCreateTarget() const;
133
134 // Generate the method header of the verify function.
135 void writeJavaVerifyMethodHeader() const;
136
137 // Generate codes that copies the content of an allocation to an array.
138 void writeJavaArrayInitialization(const ParameterDefinition& p) const;
139
140 // Generate code that tests one value returned from the script.
141 void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex,
142 const string& actualIndex) const;
143 void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
144 const string& actualIndex) const;
145 // For test:vector cases, generate code that compares returned vector vs. expected value.
146 void writeJavaVectorComparison(const ParameterDefinition& p) const;
147
148 // Muliple functions that generates code to build the error message if an error is found.
149 void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex,
150 const string& actualIndex, bool verifierValidates) const;
151 void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const;
152 void writeJavaAppendNewLineToMessage() const;
153 void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const;
154 void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const;
155
156 // Generate the set of instructions to call the script.
157 void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const;
158
159 // Write an allocation definition if not already emitted in the .rs file.
160 void writeRsAllocationDefinition(const ParameterDefinition& param,
161 set<string>* rsAllocationsGenerated) const;
162
163 public:
164 /* NOTE: We keep pointers to the permutation and the files. This object should not
165 * outlive the arguments.
166 */
167 PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
168 GeneratedFile* javaFile);
getJavaCheckMethodName() const169 string getJavaCheckMethodName() const { return mJavaCheckMethodName; }
170
171 // Write the script test function for this permutation.
172 void writeRsSection(set<string>* rsAllocationsGenerated) const;
173 // Write the section of the Java code that calls the script and validates the results
174 void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const;
175 };
176
PermutationWriter(FunctionPermutation & permutation,GeneratedFile * rsFile,GeneratedFile * javaFile)177 PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
178 GeneratedFile* javaFile)
179 : mPermutation(permutation),
180 mRs(rsFile),
181 mJava(javaFile),
182 mReturnParam(nullptr),
183 mFirstInputParam(nullptr) {
184 mRsKernelName = "test" + capitalize(permutation.getName());
185
186 mJavaArgumentsClassName = "Arguments";
187 mJavaArgumentsNClassName = "Arguments";
188 const string trunk = capitalize(permutation.getNameTrunk());
189 mJavaCheckMethodName = "check" + trunk;
190 mJavaVerifyMethodName = "verifyResults" + trunk;
191
192 for (auto p : permutation.getParams()) {
193 mAllInputsAndOutputs.push_back(p);
194 if (mFirstInputParam == nullptr && !p->isOutParameter) {
195 mFirstInputParam = p;
196 }
197 }
198 mReturnParam = permutation.getReturn();
199 if (mReturnParam) {
200 mAllInputsAndOutputs.push_back(mReturnParam);
201 }
202
203 for (auto p : mAllInputsAndOutputs) {
204 const string capitalizedRsType = capitalize(p->rsType);
205 const string capitalizedBaseType = capitalize(p->rsBaseType);
206 mRsKernelName += capitalizedRsType;
207 mJavaArgumentsClassName += capitalizedBaseType;
208 mJavaArgumentsNClassName += capitalizedBaseType;
209 if (p->mVectorSize != "1") {
210 mJavaArgumentsNClassName += "N";
211 }
212 mJavaCheckMethodName += capitalizedRsType;
213 mJavaVerifyMethodName += capitalizedRsType;
214 }
215 mJavaVerifierComputeMethodName = "compute" + trunk;
216 mJavaVerifierVerifyMethodName = "verify" + trunk;
217 }
218
writeJavaSection(set<string> * javaGeneratedArgumentClasses) const219 void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const {
220 // By default, we test the results using item by item comparison.
221 const string test = mPermutation.getTest();
222 if (test == "scalar" || test == "limited") {
223 writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
224 writeJavaCheckMethod(true);
225 writeJavaVerifyScalarMethod(false);
226 } else if (test == "custom") {
227 writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
228 writeJavaCheckMethod(true);
229 writeJavaVerifyScalarMethod(true);
230 } else if (test == "vector") {
231 writeJavaArgumentClass(false, javaGeneratedArgumentClasses);
232 writeJavaCheckMethod(true);
233 writeJavaVerifyVectorMethod();
234 } else if (test == "noverify") {
235 writeJavaCheckMethod(false);
236 }
237 }
238
writeJavaArgumentClass(bool scalar,set<string> * javaGeneratedArgumentClasses) const239 void PermutationWriter::writeJavaArgumentClass(bool scalar,
240 set<string>* javaGeneratedArgumentClasses) const {
241 string name;
242 if (scalar) {
243 name = mJavaArgumentsClassName;
244 } else {
245 name = mJavaArgumentsNClassName;
246 }
247
248 // Make sure we have not generated the argument class already.
249 if (!testAndSet(name, javaGeneratedArgumentClasses)) {
250 mJava->indent() << "public class " << name;
251 mJava->startBlock();
252
253 for (auto p : mAllInputsAndOutputs) {
254 bool isFieldArray = !scalar && p->mVectorSize != "1";
255 bool isFloatyField = p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom";
256
257 mJava->indent() << "public ";
258 if (isFloatyField) {
259 *mJava << "Target.Floaty";
260 } else {
261 *mJava << p->javaBaseType;
262 }
263 if (isFieldArray) {
264 *mJava << "[]";
265 }
266 *mJava << " " << p->variableName << ";\n";
267
268 // For Float16 parameters, add an extra 'double' field in the class
269 // to hold the Double value converted from the input.
270 if (p->isFloat16Parameter() && !isFloatyField) {
271 mJava->indent() << "public double";
272 if (isFieldArray) {
273 *mJava << "[]";
274 }
275 *mJava << " " + p->variableName << "Double;\n";
276 }
277 }
278 mJava->endBlock();
279 *mJava << "\n";
280 }
281 }
282
writeJavaCheckMethod(bool generateCallToVerifier) const283 void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const {
284 mJava->indent() << "private void " << mJavaCheckMethodName << "()";
285 mJava->startBlock();
286
287 // Generate the input allocations and initialization.
288 for (auto p : mAllInputsAndOutputs) {
289 if (!p->isOutParameter) {
290 writeJavaInputAllocationDefinition(*p);
291 }
292 }
293 // Generate code to enforce ordering between two allocations if needed.
294 for (auto p : mAllInputsAndOutputs) {
295 if (!p->isOutParameter && !p->smallerParameter.empty()) {
296 string smallerAlloc = "in" + capitalize(p->smallerParameter);
297 mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName
298 << ");\n";
299 }
300 }
301
302 // Generate code to check the full and relaxed scripts.
303 writeJavaCallToRs(false, generateCallToVerifier);
304 writeJavaCallToRs(true, generateCallToVerifier);
305
306 // Generate code to destroy input Allocations.
307 for (auto p : mAllInputsAndOutputs) {
308 if (!p->isOutParameter) {
309 mJava->indent() << p->javaAllocName << ".destroy();\n";
310 }
311 }
312
313 mJava->endBlock();
314 *mJava << "\n";
315 }
316
writeJavaInputAllocationDefinition(const ParameterDefinition & param) const317 void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const {
318 string dataType;
319 char vectorSize;
320 convertToRsType(param.rsType, &dataType, &vectorSize);
321
322 const string seed = hashString(mJavaCheckMethodName + param.javaAllocName);
323 mJava->indent() << "Allocation " << param.javaAllocName << " = ";
324 if (param.compatibleTypeIndex >= 0) {
325 if (TYPES[param.typeIndex].kind == FLOATING_POINT) {
326 writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize,
327 TYPES[param.compatibleTypeIndex],
328 TYPES[param.typeIndex]);
329 } else {
330 writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize,
331 TYPES[param.compatibleTypeIndex],
332 TYPES[param.typeIndex]);
333 }
334 } else if (!param.minValue.empty()) {
335 *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", "
336 << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue
337 << ")";
338 } else {
339 /* TODO Instead of passing always false, check whether we are doing a limited test.
340 * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true")
341 */
342 *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize
343 << ", " << seed << ", false)";
344 }
345 *mJava << ";\n";
346 }
347
writeJavaRandomCompatibleFloatAllocation(const string & dataType,const string & seed,char vectorSize,const NumericalType & compatibleType,const NumericalType & generatedType) const348 void PermutationWriter::writeJavaRandomCompatibleFloatAllocation(
349 const string& dataType, const string& seed, char vectorSize,
350 const NumericalType& compatibleType, const NumericalType& generatedType) const {
351 *mJava << "createRandomFloatAllocation"
352 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
353 double minValue = 0.0;
354 double maxValue = 0.0;
355 switch (compatibleType.kind) {
356 case FLOATING_POINT: {
357 // We're generating floating point values. We just worry about the exponent.
358 // Subtract 1 for the exponent sign.
359 int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1;
360 maxValue = ldexp(0.95, (1 << bits) - 1);
361 minValue = -maxValue;
362 break;
363 }
364 case UNSIGNED_INTEGER:
365 maxValue = maxDoubleForInteger(compatibleType.significantBits,
366 generatedType.significantBits);
367 minValue = 0.0;
368 break;
369 case SIGNED_INTEGER:
370 maxValue = maxDoubleForInteger(compatibleType.significantBits,
371 generatedType.significantBits);
372 minValue = -maxValue - 1.0;
373 break;
374 }
375 *mJava << scientific << std::setprecision(19);
376 *mJava << minValue << ", " << maxValue << ")";
377 mJava->unsetf(ios_base::floatfield);
378 }
379
writeJavaRandomCompatibleIntegerAllocation(const string & dataType,const string & seed,char vectorSize,const NumericalType & compatibleType,const NumericalType & generatedType) const380 void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation(
381 const string& dataType, const string& seed, char vectorSize,
382 const NumericalType& compatibleType, const NumericalType& generatedType) const {
383 *mJava << "createRandomIntegerAllocation"
384 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
385
386 if (compatibleType.kind == FLOATING_POINT) {
387 // Currently, all floating points can take any number we generate.
388 bool isSigned = generatedType.kind == SIGNED_INTEGER;
389 *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits;
390 } else {
391 bool isSigned =
392 compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER;
393 *mJava << (isSigned ? "true" : "false") << ", "
394 << min(compatibleType.significantBits, generatedType.significantBits);
395 }
396 *mJava << ")";
397 }
398
writeJavaOutputAllocationDefinition(const ParameterDefinition & param) const399 void PermutationWriter::writeJavaOutputAllocationDefinition(
400 const ParameterDefinition& param) const {
401 string dataType;
402 char vectorSize;
403 convertToRsType(param.rsType, &dataType, &vectorSize);
404 mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, "
405 << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize
406 << "), INPUTSIZE);\n";
407 }
408
writeJavaVerifyScalarMethod(bool verifierValidates) const409 void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const {
410 writeJavaVerifyMethodHeader();
411 mJava->startBlock();
412
413 string vectorSize = "1";
414 for (auto p : mAllInputsAndOutputs) {
415 writeJavaArrayInitialization(*p);
416 if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) {
417 if (vectorSize == "1") {
418 vectorSize = p->mVectorSize;
419 } else {
420 cerr << "Error. Had vector " << vectorSize << " and " << p->mVectorSize << "\n";
421 }
422 }
423 }
424
425 mJava->indent() << "StringBuilder message = new StringBuilder();\n";
426 mJava->indent() << "boolean errorFound = false;\n";
427 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
428 mJava->startBlock();
429
430 mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)";
431 mJava->startBlock();
432
433 mJava->indent() << "// Extract the inputs.\n";
434 mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName
435 << "();\n";
436 for (auto p : mAllInputsAndOutputs) {
437 if (!p->isOutParameter) {
438 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i";
439 if (p->vectorWidth != "1") {
440 *mJava << " * " << p->vectorWidth << " + j";
441 }
442 *mJava << "];\n";
443
444 // Convert the Float16 parameter to double and store it in the appropriate field in the
445 // Arguments class.
446 if (p->isFloat16Parameter()) {
447 mJava->indent() << "args." << p->doubleVariableName
448 << " = Float16Utils.convertFloat16ToDouble(args."
449 << p->variableName << ");\n";
450 }
451 }
452 }
453 const bool hasFloat = mPermutation.hasFloatAnswers();
454 if (verifierValidates) {
455 mJava->indent() << "// Extract the outputs.\n";
456 for (auto p : mAllInputsAndOutputs) {
457 if (p->isOutParameter) {
458 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName
459 << "[i * " << p->vectorWidth << " + j];\n";
460 if (p->isFloat16Parameter()) {
461 mJava->indent() << "args." << p->doubleVariableName
462 << " = Float16Utils.convertFloat16ToDouble(args."
463 << p->variableName << ");\n";
464 }
465 }
466 }
467 mJava->indent() << "// Ask the CoreMathVerifier to validate.\n";
468 if (hasFloat) {
469 writeJavaCreateTarget();
470 }
471 mJava->indent() << "String errorMessage = CoreMathVerifier."
472 << mJavaVerifierVerifyMethodName << "(args";
473 if (hasFloat) {
474 *mJava << ", target";
475 }
476 *mJava << ");\n";
477 mJava->indent() << "boolean valid = errorMessage == null;\n";
478 } else {
479 mJava->indent() << "// Figure out what the outputs should have been.\n";
480 if (hasFloat) {
481 writeJavaCreateTarget();
482 }
483 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args";
484 if (hasFloat) {
485 *mJava << ", target";
486 }
487 *mJava << ");\n";
488 mJava->indent() << "// Validate the outputs.\n";
489 mJava->indent() << "boolean valid = true;\n";
490 for (auto p : mAllInputsAndOutputs) {
491 if (p->isOutParameter) {
492 writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]");
493 }
494 }
495 }
496
497 mJava->indent() << "if (!valid)";
498 mJava->startBlock();
499 mJava->indent() << "if (!errorFound)";
500 mJava->startBlock();
501 mJava->indent() << "errorFound = true;\n";
502
503 for (auto p : mAllInputsAndOutputs) {
504 if (p->isOutParameter) {
505 writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]",
506 verifierValidates);
507 } else {
508 writeJavaAppendInputToMessage(*p, "args." + p->variableName);
509 }
510 }
511 if (verifierValidates) {
512 mJava->indent() << "message.append(errorMessage);\n";
513 }
514 mJava->indent() << "message.append(\"Errors at\");\n";
515 mJava->endBlock();
516
517 mJava->indent() << "message.append(\" [\");\n";
518 mJava->indent() << "message.append(Integer.toString(i));\n";
519 mJava->indent() << "message.append(\", \");\n";
520 mJava->indent() << "message.append(Integer.toString(j));\n";
521 mJava->indent() << "message.append(\"]\");\n";
522
523 mJava->endBlock();
524 mJava->endBlock();
525 mJava->endBlock();
526
527 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
528 mJava->indentPlus()
529 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";
530
531 mJava->endBlock();
532 *mJava << "\n";
533 }
534
writeJavaVerifyVectorMethod() const535 void PermutationWriter::writeJavaVerifyVectorMethod() const {
536 writeJavaVerifyMethodHeader();
537 mJava->startBlock();
538
539 for (auto p : mAllInputsAndOutputs) {
540 writeJavaArrayInitialization(*p);
541 }
542 mJava->indent() << "StringBuilder message = new StringBuilder();\n";
543 mJava->indent() << "boolean errorFound = false;\n";
544 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
545 mJava->startBlock();
546
547 mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName
548 << "();\n";
549
550 mJava->indent() << "// Create the appropriate sized arrays in args\n";
551 for (auto p : mAllInputsAndOutputs) {
552 if (p->mVectorSize != "1") {
553 string type = p->javaBaseType;
554 if (p->isOutParameter && p->isFloatType) {
555 type = "Target.Floaty";
556 }
557 mJava->indent() << "args." << p->variableName << " = new " << type << "["
558 << p->mVectorSize << "];\n";
559 if (p->isFloat16Parameter() && !p->isOutParameter) {
560 mJava->indent() << "args." << p->variableName << "Double = new double["
561 << p->mVectorSize << "];\n";
562 }
563 }
564 }
565
566 mJava->indent() << "// Fill args with the input values\n";
567 for (auto p : mAllInputsAndOutputs) {
568 if (!p->isOutParameter) {
569 if (p->mVectorSize == "1") {
570 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]"
571 << ";\n";
572 // Convert the Float16 parameter to double and store it in the appropriate field in
573 // the Arguments class.
574 if (p->isFloat16Parameter()) {
575 mJava->indent() << "args." << p->doubleVariableName << " = "
576 << "Float16Utils.convertFloat16ToDouble(args."
577 << p->variableName << ");\n";
578 }
579 } else {
580 mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)";
581 mJava->startBlock();
582 mJava->indent() << "args." << p->variableName << "[j] = "
583 << p->javaArrayName << "[i * " << p->vectorWidth << " + j]"
584 << ";\n";
585
586 // Convert the Float16 parameter to double and store it in the appropriate field in
587 // the Arguments class.
588 if (p->isFloat16Parameter()) {
589 mJava->indent() << "args." << p->doubleVariableName << "[j] = "
590 << "Float16Utils.convertFloat16ToDouble(args."
591 << p->variableName << "[j]);\n";
592 }
593 mJava->endBlock();
594 }
595 }
596 }
597 writeJavaCreateTarget();
598 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName
599 << "(args, target);\n\n";
600
601 mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n";
602 mJava->indent() << "boolean valid = true;\n";
603 for (auto p : mAllInputsAndOutputs) {
604 if (p->isOutParameter) {
605 writeJavaVectorComparison(*p);
606 }
607 }
608
609 mJava->indent() << "if (!valid)";
610 mJava->startBlock();
611 mJava->indent() << "if (!errorFound)";
612 mJava->startBlock();
613 mJava->indent() << "errorFound = true;\n";
614
615 for (auto p : mAllInputsAndOutputs) {
616 if (p->isOutParameter) {
617 writeJavaAppendVectorOutputToMessage(*p);
618 } else {
619 writeJavaAppendVectorInputToMessage(*p);
620 }
621 }
622 mJava->indent() << "message.append(\"Errors at\");\n";
623 mJava->endBlock();
624
625 mJava->indent() << "message.append(\" [\");\n";
626 mJava->indent() << "message.append(Integer.toString(i));\n";
627 mJava->indent() << "message.append(\"]\");\n";
628
629 mJava->endBlock();
630 mJava->endBlock();
631
632 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
633 mJava->indentPlus()
634 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";
635
636 mJava->endBlock();
637 *mJava << "\n";
638 }
639
640
writeJavaCreateTarget() const641 void PermutationWriter::writeJavaCreateTarget() const {
642 string name = mPermutation.getName();
643
644 const char* functionType = "NORMAL";
645 size_t end = name.find('_');
646 if (end != string::npos) {
647 if (name.compare(0, end, "native") == 0) {
648 functionType = "NATIVE";
649 } else if (name.compare(0, end, "half") == 0) {
650 functionType = "HALF";
651 } else if (name.compare(0, end, "fast") == 0) {
652 functionType = "FAST";
653 }
654 }
655
656 string floatType = mReturnParam->specType;
657 const char* precisionStr = "";
658 if (floatType.compare("f16") == 0) {
659 precisionStr = "HALF";
660 } else if (floatType.compare("f32") == 0) {
661 precisionStr = "FLOAT";
662 } else if (floatType.compare("f64") == 0) {
663 precisionStr = "DOUBLE";
664 } else {
665 cerr << "Error. Unreachable. Return type is not floating point\n";
666 }
667
668 mJava->indent() << "Target target = new Target(Target.FunctionType." <<
669 functionType << ", Target.ReturnType." << precisionStr <<
670 ", relaxed);\n";
671 }
672
writeJavaVerifyMethodHeader() const673 void PermutationWriter::writeJavaVerifyMethodHeader() const {
674 mJava->indent() << "private void " << mJavaVerifyMethodName << "(";
675 for (auto p : mAllInputsAndOutputs) {
676 *mJava << "Allocation " << p->javaAllocName << ", ";
677 }
678 *mJava << "boolean relaxed)";
679 }
680
writeJavaArrayInitialization(const ParameterDefinition & p) const681 void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const {
682 mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType
683 << "[INPUTSIZE * " << p.vectorWidth << "];\n";
684
685 /* For basic types, populate the array with values, to help understand failures. We have had
686 * bugs where the output buffer was all 0. We were not sure if there was a failed copy or
687 * the GPU driver was copying zeroes.
688 */
689 if (p.typeIndex >= 0) {
690 mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType
691 << ") 42);\n";
692 }
693
694 mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n";
695 }
696
writeJavaTestAndSetValid(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex) const697 void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p,
698 const string& argsIndex,
699 const string& actualIndex) const {
700 writeJavaTestOneValue(p, argsIndex, actualIndex);
701 mJava->startBlock();
702 mJava->indent() << "valid = false;\n";
703 mJava->endBlock();
704 }
705
writeJavaTestOneValue(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex) const706 void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
707 const string& actualIndex) const {
708 string actualOut;
709 if (p.isFloat16Parameter()) {
710 // For Float16 values, the output needs to be converted to Double.
711 actualOut = "Float16Utils.convertFloat16ToDouble(" + p.javaArrayName + actualIndex + ")";
712 } else {
713 actualOut = p.javaArrayName + actualIndex;
714 }
715
716 mJava->indent() << "if (";
717 if (p.isFloatType) {
718 *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << actualOut;
719 const string s = mPermutation.getPrecisionLimit();
720 if (!s.empty()) {
721 *mJava << ", " << s;
722 }
723 *mJava << ")";
724 } else {
725 *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName
726 << actualIndex;
727 }
728
729 if (p.undefinedIfOutIsNan && mReturnParam) {
730 *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()";
731 }
732 *mJava << ")";
733 }
734
writeJavaVectorComparison(const ParameterDefinition & p) const735 void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const {
736 if (p.mVectorSize == "1") {
737 writeJavaTestAndSetValid(p, "", "[i]");
738 } else {
739 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
740 mJava->startBlock();
741 writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]");
742 mJava->endBlock();
743 }
744 }
745
writeJavaAppendOutputToMessage(const ParameterDefinition & p,const string & argsIndex,const string & actualIndex,bool verifierValidates) const746 void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p,
747 const string& argsIndex,
748 const string& actualIndex,
749 bool verifierValidates) const {
750 if (verifierValidates) {
751 mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n";
752 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
753 << ");\n";
754 writeJavaAppendNewLineToMessage();
755 if (p.isFloat16Parameter()) {
756 writeJavaAppendNewLineToMessage();
757 mJava->indent() << "message.append(\"Output " << p.variableName
758 << " (in double): \");\n";
759 mJava->indent() << "appendVariableToMessage(message, args." << p.doubleVariableName
760 << ");\n";
761 writeJavaAppendNewLineToMessage();
762 }
763 } else {
764 mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n";
765 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
766 << ");\n";
767 writeJavaAppendNewLineToMessage();
768
769 mJava->indent() << "message.append(\"Actual output " << p.variableName << ": \");\n";
770 mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex
771 << ");\n";
772
773 if (p.isFloat16Parameter()) {
774 writeJavaAppendNewLineToMessage();
775 mJava->indent() << "message.append(\"Actual output " << p.variableName
776 << " (in double): \");\n";
777 mJava->indent() << "appendVariableToMessage(message, Float16Utils.convertFloat16ToDouble("
778 << p.javaArrayName << actualIndex << "));\n";
779 }
780
781 writeJavaTestOneValue(p, argsIndex, actualIndex);
782 mJava->startBlock();
783 mJava->indent() << "message.append(\" FAIL\");\n";
784 mJava->endBlock();
785 writeJavaAppendNewLineToMessage();
786 }
787 }
788
writeJavaAppendInputToMessage(const ParameterDefinition & p,const string & actual) const789 void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p,
790 const string& actual) const {
791 mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n";
792 mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n";
793 writeJavaAppendNewLineToMessage();
794 }
795
writeJavaAppendNewLineToMessage() const796 void PermutationWriter::writeJavaAppendNewLineToMessage() const {
797 mJava->indent() << "message.append(\"\\n\");\n";
798 }
799
writeJavaAppendVectorInputToMessage(const ParameterDefinition & p) const800 void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const {
801 if (p.mVectorSize == "1") {
802 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]");
803 } else {
804 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
805 mJava->startBlock();
806 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]");
807 mJava->endBlock();
808 }
809 }
810
writeJavaAppendVectorOutputToMessage(const ParameterDefinition & p) const811 void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const {
812 if (p.mVectorSize == "1") {
813 writeJavaAppendOutputToMessage(p, "", "[i]", false);
814 } else {
815 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
816 mJava->startBlock();
817 writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false);
818 mJava->endBlock();
819 }
820 }
821
writeJavaCallToRs(bool relaxed,bool generateCallToVerifier) const822 void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const {
823 string script = "script";
824 if (relaxed) {
825 script += "Relaxed";
826 }
827
828 mJava->indent() << "try";
829 mJava->startBlock();
830
831 for (auto p : mAllInputsAndOutputs) {
832 if (p->isOutParameter) {
833 writeJavaOutputAllocationDefinition(*p);
834 }
835 }
836
837 for (auto p : mPermutation.getParams()) {
838 if (p != mFirstInputParam) {
839 mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName
840 << ");\n";
841 }
842 }
843
844 mJava->indent() << script << ".forEach_" << mRsKernelName << "(";
845 bool needComma = false;
846 if (mFirstInputParam) {
847 *mJava << mFirstInputParam->javaAllocName;
848 needComma = true;
849 }
850 if (mReturnParam) {
851 if (needComma) {
852 *mJava << ", ";
853 }
854 *mJava << mReturnParam->variableName << ");\n";
855 }
856
857 if (generateCallToVerifier) {
858 mJava->indent() << mJavaVerifyMethodName << "(";
859 for (auto p : mAllInputsAndOutputs) {
860 *mJava << p->variableName << ", ";
861 }
862
863 if (relaxed) {
864 *mJava << "true";
865 } else {
866 *mJava << "false";
867 }
868 *mJava << ");\n";
869 }
870
871 // Generate code to destroy output Allocations.
872 for (auto p : mAllInputsAndOutputs) {
873 if (p->isOutParameter) {
874 mJava->indent() << p->javaAllocName << ".destroy();\n";
875 }
876 }
877
878 mJava->decreaseIndent();
879 mJava->indent() << "} catch (Exception e) {\n";
880 mJava->increaseIndent();
881 mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_"
882 << mRsKernelName << ": \" + e.toString());\n";
883 mJava->endBlock();
884 }
885
886 /* Write the section of the .rs file for this permutation.
887 *
888 * We communicate the extra input and output parameters via global allocations.
889 * For example, if we have a function that takes three arguments, two for input
890 * and one for output:
891 *
892 * start:
893 * name: gamn
894 * ret: float3
895 * arg: float3 a
896 * arg: int b
897 * arg: float3 *c
898 * end:
899 *
900 * We'll produce:
901 *
902 * rs_allocation gAllocInB;
903 * rs_allocation gAllocOutC;
904 *
905 * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) {
906 * int inB;
907 * float3 outC;
908 * float2 out;
909 * inB = rsGetElementAt_int(gAllocInB, x);
910 * out = gamn(a, in_b, &outC);
911 * rsSetElementAt_float4(gAllocOutC, &outC, x);
912 * return out;
913 * }
914 *
915 * We avoid re-using x and y from the definition because these have reserved
916 * meanings in a .rs file.
917 */
writeRsSection(set<string> * rsAllocationsGenerated) const918 void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const {
919 // Write the allocation declarations we'll need.
920 for (auto p : mPermutation.getParams()) {
921 // Don't need allocation for one input and one return value.
922 if (p != mFirstInputParam) {
923 writeRsAllocationDefinition(*p, rsAllocationsGenerated);
924 }
925 }
926 *mRs << "\n";
927
928 // Write the function header.
929 if (mReturnParam) {
930 *mRs << mReturnParam->rsType;
931 } else {
932 *mRs << "void";
933 }
934 *mRs << " __attribute__((kernel)) " << mRsKernelName;
935 *mRs << "(";
936 bool needComma = false;
937 if (mFirstInputParam) {
938 *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName;
939 needComma = true;
940 }
941 if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) {
942 if (needComma) {
943 *mRs << ", ";
944 }
945 *mRs << "unsigned int x";
946 }
947 *mRs << ")";
948 mRs->startBlock();
949
950 // Write the local variable declarations and initializations.
951 for (auto p : mPermutation.getParams()) {
952 if (p == mFirstInputParam) {
953 continue;
954 }
955 mRs->indent() << p->rsType << " " << p->variableName;
956 if (p->isOutParameter) {
957 *mRs << " = 0;\n";
958 } else {
959 *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n";
960 }
961 }
962
963 // Write the function call.
964 if (mReturnParam) {
965 if (mPermutation.getOutputCount() > 1) {
966 mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = ";
967 } else {
968 mRs->indent() << "return ";
969 }
970 }
971 *mRs << mPermutation.getName() << "(";
972 needComma = false;
973 for (auto p : mPermutation.getParams()) {
974 if (needComma) {
975 *mRs << ", ";
976 }
977 if (p->isOutParameter) {
978 *mRs << "&";
979 }
980 *mRs << p->variableName;
981 needComma = true;
982 }
983 *mRs << ");\n";
984
985 if (mPermutation.getOutputCount() > 1) {
986 // Write setting the extra out parameters into the allocations.
987 for (auto p : mPermutation.getParams()) {
988 if (p->isOutParameter) {
989 mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", ";
990 // Check if we need to use '&' for this type of argument.
991 char lastChar = p->variableName.back();
992 if (lastChar >= '0' && lastChar <= '9') {
993 *mRs << "&";
994 }
995 *mRs << p->variableName << ", x);\n";
996 }
997 }
998 if (mReturnParam) {
999 mRs->indent() << "return " << mReturnParam->variableName << ";\n";
1000 }
1001 }
1002 mRs->endBlock();
1003 }
1004
writeRsAllocationDefinition(const ParameterDefinition & param,set<string> * rsAllocationsGenerated) const1005 void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param,
1006 set<string>* rsAllocationsGenerated) const {
1007 if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) {
1008 *mRs << "rs_allocation " << param.rsAllocName << ";\n";
1009 }
1010 }
1011
1012 // Open the mJavaFile and writes the header.
startJavaFile(GeneratedFile * file,const Function & function,const string & directory,const string & testName,const string & relaxedTestName)1013 static bool startJavaFile(GeneratedFile* file, const Function& function, const string& directory,
1014 const string& testName, const string& relaxedTestName) {
1015 const string fileName = testName + ".java";
1016 if (!file->start(directory, fileName)) {
1017 return false;
1018 }
1019 file->writeNotices();
1020
1021 *file << "package android.renderscript.cts;\n\n";
1022
1023 *file << "import android.renderscript.Allocation;\n";
1024 *file << "import android.renderscript.RSRuntimeException;\n";
1025 *file << "import android.renderscript.Element;\n";
1026 *file << "import android.renderscript.cts.Target;\n\n";
1027 *file << "import java.util.Arrays;\n\n";
1028
1029 *file << "public class " << testName << " extends RSBaseCompute";
1030 file->startBlock(); // The corresponding endBlock() is in finishJavaFile()
1031 *file << "\n";
1032
1033 file->indent() << "private ScriptC_" << testName << " script;\n";
1034 file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n";
1035
1036 file->indent() << "@Override\n";
1037 file->indent() << "protected void setUp() throws Exception";
1038 file->startBlock();
1039
1040 file->indent() << "super.setUp();\n";
1041 file->indent() << "script = new ScriptC_" << testName << "(mRS);\n";
1042 file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n";
1043
1044 file->endBlock();
1045 *file << "\n";
1046
1047 file->indent() << "@Override\n";
1048 file->indent() << "protected void tearDown() throws Exception";
1049 file->startBlock();
1050
1051 file->indent() << "script.destroy();\n";
1052 file->indent() << "scriptRelaxed.destroy();\n";
1053 file->indent() << "super.tearDown();\n";
1054
1055 file->endBlock();
1056 *file << "\n";
1057
1058 return true;
1059 }
1060
1061 // Write the test method that calls all the generated Check methods.
finishJavaFile(GeneratedFile * file,const Function & function,const vector<string> & javaCheckMethods)1062 static void finishJavaFile(GeneratedFile* file, const Function& function,
1063 const vector<string>& javaCheckMethods) {
1064 file->indent() << "public void test" << function.getCapitalizedName() << "()";
1065 file->startBlock();
1066 for (auto m : javaCheckMethods) {
1067 file->indent() << m << "();\n";
1068 }
1069 file->endBlock();
1070
1071 file->endBlock();
1072 }
1073
1074 // Open the script file and write its header.
startRsFile(GeneratedFile * file,const Function & function,const string & directory,const string & testName)1075 static bool startRsFile(GeneratedFile* file, const Function& function, const string& directory,
1076 const string& testName) {
1077 string fileName = testName + ".rs";
1078 if (!file->start(directory, fileName)) {
1079 return false;
1080 }
1081 file->writeNotices();
1082
1083 *file << "#pragma version(1)\n";
1084 *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n";
1085 return true;
1086 }
1087
1088 // Write the entire *Relaxed.rs test file, as it only depends on the name.
writeRelaxedRsFile(const Function & function,const string & directory,const string & testName,const string & relaxedTestName)1089 static bool writeRelaxedRsFile(const Function& function, const string& directory,
1090 const string& testName, const string& relaxedTestName) {
1091 string name = relaxedTestName + ".rs";
1092
1093 GeneratedFile file;
1094 if (!file.start(directory, name)) {
1095 return false;
1096 }
1097 file.writeNotices();
1098
1099 file << "#include \"" << testName << ".rs\"\n";
1100 file << "#pragma rs_fp_relaxed\n";
1101 file.close();
1102 return true;
1103 }
1104
1105 /* Write the .java and the two .rs test files. versionOfTestFiles is used to restrict which API
1106 * to test.
1107 */
writeTestFilesForFunction(const Function & function,const string & directory,unsigned int versionOfTestFiles)1108 static bool writeTestFilesForFunction(const Function& function, const string& directory,
1109 unsigned int versionOfTestFiles) {
1110 // Avoid creating empty files if we're not testing this function.
1111 if (!needTestFiles(function, versionOfTestFiles)) {
1112 return true;
1113 }
1114
1115 const string testName = "Test" + function.getCapitalizedName();
1116 const string relaxedTestName = testName + "Relaxed";
1117
1118 if (!writeRelaxedRsFile(function, directory, testName, relaxedTestName)) {
1119 return false;
1120 }
1121
1122 GeneratedFile rsFile; // The Renderscript test file we're generating.
1123 GeneratedFile javaFile; // The Jave test file we're generating.
1124 if (!startRsFile(&rsFile, function, directory, testName)) {
1125 return false;
1126 }
1127
1128 if (!startJavaFile(&javaFile, function, directory, testName, relaxedTestName)) {
1129 return false;
1130 }
1131
1132 /* We keep track of the allocations generated in the .rs file and the argument classes defined
1133 * in the Java file, as we share these between the functions created for each specification.
1134 */
1135 set<string> rsAllocationsGenerated;
1136 set<string> javaGeneratedArgumentClasses;
1137 // Lines of Java code to invoke the check methods.
1138 vector<string> javaCheckMethods;
1139
1140 for (auto spec : function.getSpecifications()) {
1141 if (spec->hasTests(versionOfTestFiles)) {
1142 for (auto permutation : spec->getPermutations()) {
1143 PermutationWriter w(*permutation, &rsFile, &javaFile);
1144 w.writeRsSection(&rsAllocationsGenerated);
1145 w.writeJavaSection(&javaGeneratedArgumentClasses);
1146
1147 // Store the check method to be called.
1148 javaCheckMethods.push_back(w.getJavaCheckMethodName());
1149 }
1150 }
1151 }
1152
1153 finishJavaFile(&javaFile, function, javaCheckMethods);
1154 // There's no work to wrap-up in the .rs file.
1155
1156 rsFile.close();
1157 javaFile.close();
1158 return true;
1159 }
1160
generateTestFiles(const string & directory,unsigned int versionOfTestFiles)1161 bool generateTestFiles(const string& directory, unsigned int versionOfTestFiles) {
1162 bool success = true;
1163 for (auto f : systemSpecification.getFunctions()) {
1164 if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) {
1165 success = false;
1166 }
1167 }
1168 return success;
1169 }
1170