1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
#include "src/sksl/SkSLAnalysis.h"
#include "src/sksl/SkSLConstantFolder.h"
#include "src/sksl/SkSLContext.h"
#include "src/sksl/SkSLProgramSettings.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructorCompound.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
#include "src/sksl/ir/SkSLFunctionCall.h"
#include "src/sksl/ir/SkSLIntLiteral.h"

#include <cmath>
#include <cstdlib>
#include <limits>
#include <type_traits>
17
18 namespace SkSL {
19
has_compile_time_constant_arguments(const ExpressionArray & arguments)20 static bool has_compile_time_constant_arguments(const ExpressionArray& arguments) {
21 for (const std::unique_ptr<Expression>& arg : arguments) {
22 const Expression* expr = ConstantFolder::GetConstantValueForVariable(*arg);
23 if (!expr->isCompileTimeConstant()) {
24 return false;
25 }
26 }
27 return true;
28 }
29
coalesce_bool_vector(const ExpressionArray & arguments,bool startingState,const std::function<bool (bool,bool)> & coalesce)30 static std::unique_ptr<Expression> coalesce_bool_vector(
31 const ExpressionArray& arguments,
32 bool startingState,
33 const std::function<bool(bool, bool)>& coalesce) {
34 SkASSERT(arguments.size() == 1);
35 const Expression* arg = ConstantFolder::GetConstantValueForVariable(*arguments.front());
36 const Type& type = arg->type();
37 SkASSERT(type.isVector());
38 SkASSERT(type.componentType().isBoolean());
39
40 bool value = startingState;
41 for (int index = 0; index < type.columns(); ++index) {
42 const Expression* subexpression = arg->getConstantSubexpression(index);
43 SkASSERT(subexpression);
44 value = coalesce(value, subexpression->as<BoolLiteral>().value());
45 }
46
47 return BoolLiteral::Make(arg->fOffset, value, &type.componentType());
48 }
49
50 template <typename LITERAL, typename FN>
optimize_comparison_of_type(const Context & context,const Expression & left,const Expression & right,const FN & compare)51 static std::unique_ptr<Expression> optimize_comparison_of_type(const Context& context,
52 const Expression& left,
53 const Expression& right,
54 const FN& compare) {
55 const Type& type = left.type();
56 SkASSERT(type.isVector());
57 SkASSERT(type.componentType().isNumber());
58 SkASSERT(type == right.type());
59
60 ExpressionArray result;
61 result.reserve_back(type.columns());
62
63 for (int index = 0; index < type.columns(); ++index) {
64 const Expression* leftSubexpr = left.getConstantSubexpression(index);
65 const Expression* rightSubexpr = right.getConstantSubexpression(index);
66 SkASSERT(leftSubexpr);
67 SkASSERT(rightSubexpr);
68 bool value = compare(leftSubexpr->as<LITERAL>().value(),
69 rightSubexpr->as<LITERAL>().value());
70 result.push_back(BoolLiteral::Make(context, leftSubexpr->fOffset, value));
71 }
72
73 const Type& bvecType = context.fTypes.fBool->toCompound(context, type.columns(), /*rows=*/1);
74 return ConstructorCompound::Make(context, left.fOffset, bvecType, std::move(result));
75 }
76
77 template <typename FN>
optimize_comparison(const Context & context,const ExpressionArray & arguments,const FN & compare)78 static std::unique_ptr<Expression> optimize_comparison(const Context& context,
79 const ExpressionArray& arguments,
80 const FN& compare) {
81 SkASSERT(arguments.size() == 2);
82 const Expression* left = ConstantFolder::GetConstantValueForVariable(*arguments[0]);
83 const Expression* right = ConstantFolder::GetConstantValueForVariable(*arguments[1]);
84 const Type& type = left->type().componentType();
85
86 if (type.isFloat()) {
87 return optimize_comparison_of_type<FloatLiteral>(context, *left, *right, compare);
88 }
89 if (type.isInteger()) {
90 return optimize_comparison_of_type<IntLiteral>(context, *left, *right, compare);
91 }
92 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
93 return nullptr;
94 }
95
96 template <typename LITERAL, typename FN>
evaluate_intrinsic_1_of_type(const Context & context,const Expression * arg,const FN & evaluate)97 static std::unique_ptr<Expression> evaluate_intrinsic_1_of_type(const Context& context,
98 const Expression* arg,
99 const FN& evaluate) {
100 const Type& vecType = arg->type();
101 const Type& type = vecType.componentType();
102 SkASSERT(type.isScalar());
103
104 ExpressionArray result;
105 result.reserve_back(vecType.columns());
106
107 for (int index = 0; index < vecType.columns(); ++index) {
108 const Expression* subexpr = arg->getConstantSubexpression(index);
109 SkASSERT(subexpr);
110 auto value = evaluate(subexpr->as<LITERAL>().value());
111 if constexpr (std::is_floating_point<decltype(value)>::value) {
112 // If evaluation of the intrinsic yields a non-finite value, bail on optimization.
113 if (!isfinite(value)) {
114 return nullptr;
115 }
116 }
117 result.push_back(LITERAL::Make(subexpr->fOffset, value, &type));
118 }
119
120 return ConstructorCompound::Make(context, arg->fOffset, vecType, std::move(result));
121 }
122
123 template <typename FN,
124 bool kSupportsFloat = true,
125 bool kSupportsInt = true,
126 bool kSupportsBool = false>
evaluate_intrinsic_generic1(const Context & context,const ExpressionArray & arguments,const FN & evaluate)127 static std::unique_ptr<Expression> evaluate_intrinsic_generic1(const Context& context,
128 const ExpressionArray& arguments,
129 const FN& evaluate) {
130 SkASSERT(arguments.size() == 1);
131 const Expression* arg = ConstantFolder::GetConstantValueForVariable(*arguments.front());
132 const Type& type = arg->type().componentType();
133
134 if constexpr (kSupportsFloat) {
135 if (type.isFloat()) {
136 return evaluate_intrinsic_1_of_type<FloatLiteral>(context, arg, evaluate);
137 }
138 }
139 if constexpr (kSupportsInt) {
140 if (type.isInteger()) {
141 return evaluate_intrinsic_1_of_type<IntLiteral>(context, arg, evaluate);
142 }
143 }
144 if constexpr (kSupportsBool) {
145 if (type.isBoolean()) {
146 return evaluate_intrinsic_1_of_type<BoolLiteral>(context, arg, evaluate);
147 }
148 }
149 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
150 return nullptr;
151 }
152
153 template <typename FN>
evaluate_intrinsic_float1(const Context & context,const ExpressionArray & arguments,const FN & evaluate)154 static std::unique_ptr<Expression> evaluate_intrinsic_float1(const Context& context,
155 const ExpressionArray& arguments,
156 const FN& evaluate) {
157 return evaluate_intrinsic_generic1<FN,
158 /*kSupportsFloat=*/true,
159 /*kSupportsInt=*/false,
160 /*kSupportsBool=*/false>(context, arguments, evaluate);
161 }
162
163 template <typename FN>
evaluate_intrinsic_bool1(const Context & context,const ExpressionArray & arguments,const FN & evaluate)164 static std::unique_ptr<Expression> evaluate_intrinsic_bool1(const Context& context,
165 const ExpressionArray& arguments,
166 const FN& evaluate) {
167 return evaluate_intrinsic_generic1<FN,
168 /*kSupportsFloat=*/false,
169 /*kSupportsInt=*/false,
170 /*kSupportsBool=*/true>(context, arguments, evaluate);
171 }
172
optimize_intrinsic_call(const Context & context,IntrinsicKind intrinsic,const ExpressionArray & arguments)173 static std::unique_ptr<Expression> optimize_intrinsic_call(const Context& context,
174 IntrinsicKind intrinsic,
175 const ExpressionArray& arguments) {
176 switch (intrinsic) {
177 case k_all_IntrinsicKind:
178 return coalesce_bool_vector(arguments, /*startingState=*/true,
179 [](bool a, bool b) { return a && b; });
180 case k_any_IntrinsicKind:
181 return coalesce_bool_vector(arguments, /*startingState=*/false,
182 [](bool a, bool b) { return a || b; });
183 case k_not_IntrinsicKind:
184 return evaluate_intrinsic_bool1(context, arguments, [](bool a) { return !a; });
185
186 case k_greaterThan_IntrinsicKind:
187 return optimize_comparison(context, arguments, [](auto a, auto b) { return a > b; });
188
189 case k_greaterThanEqual_IntrinsicKind:
190 return optimize_comparison(context, arguments, [](auto a, auto b) { return a >= b; });
191
192 case k_lessThan_IntrinsicKind:
193 return optimize_comparison(context, arguments, [](auto a, auto b) { return a < b; });
194
195 case k_lessThanEqual_IntrinsicKind:
196 return optimize_comparison(context, arguments, [](auto a, auto b) { return a <= b; });
197
198 case k_equal_IntrinsicKind:
199 return optimize_comparison(context, arguments, [](auto a, auto b) { return a == b; });
200
201 case k_notEqual_IntrinsicKind:
202 return optimize_comparison(context, arguments, [](auto a, auto b) { return a != b; });
203
204 case k_abs_IntrinsicKind:
205 return evaluate_intrinsic_generic1(context, arguments, [](auto a) { return abs(a); });
206
207 case k_sign_IntrinsicKind:
208 return evaluate_intrinsic_generic1(context, arguments,
209 [](auto a) { return (a > 0) - (a < 0); });
210 case k_sin_IntrinsicKind:
211 return evaluate_intrinsic_float1(context, arguments, [](float a) { return sin(a); });
212
213 case k_cos_IntrinsicKind:
214 return evaluate_intrinsic_float1(context, arguments, [](float a) { return cos(a); });
215
216 case k_tan_IntrinsicKind:
217 return evaluate_intrinsic_float1(context, arguments, [](float a) { return tan(a); });
218
219 case k_asin_IntrinsicKind:
220 return evaluate_intrinsic_float1(context, arguments, [](float a) { return asin(a); });
221
222 case k_acos_IntrinsicKind:
223 return evaluate_intrinsic_float1(context, arguments, [](float a) { return acos(a); });
224
225 case k_sinh_IntrinsicKind:
226 return evaluate_intrinsic_float1(context, arguments, [](float a) { return sinh(a); });
227
228 case k_cosh_IntrinsicKind:
229 return evaluate_intrinsic_float1(context, arguments, [](float a) { return cosh(a); });
230
231 case k_tanh_IntrinsicKind:
232 return evaluate_intrinsic_float1(context, arguments, [](float a) { return tanh(a); });
233
234 case k_ceil_IntrinsicKind:
235 return evaluate_intrinsic_float1(context, arguments, [](float a) { return ceil(a); });
236
237 case k_floor_IntrinsicKind:
238 return evaluate_intrinsic_float1(context, arguments, [](float a) { return floor(a); });
239
240 case k_fract_IntrinsicKind:
241 return evaluate_intrinsic_float1(context, arguments,
242 [](float a) { return a - floor(a); });
243 case k_trunc_IntrinsicKind:
244 return evaluate_intrinsic_float1(context, arguments, [](float a) { return trunc(a); });
245
246 case k_exp_IntrinsicKind:
247 return evaluate_intrinsic_float1(context, arguments, [](float a) { return exp(a); });
248
249 case k_log_IntrinsicKind:
250 return evaluate_intrinsic_float1(context, arguments, [](float a) { return log(a); });
251
252 case k_exp2_IntrinsicKind:
253 return evaluate_intrinsic_float1(context, arguments, [](float a) { return exp2(a); });
254
255 case k_log2_IntrinsicKind:
256 return evaluate_intrinsic_float1(context, arguments, [](float a) { return log2(a); });
257
258 case k_saturate_IntrinsicKind:
259 return evaluate_intrinsic_float1(context, arguments,
260 [](float a) { return (a < 0) ? 0 : (a > 1) ? 1 : a; });
261 case k_round_IntrinsicKind: // GLSL `round` documents its rounding mode as unspecified
262 case k_roundEven_IntrinsicKind: // and is allowed to behave identically to `roundEven`.
263 return evaluate_intrinsic_float1(context, arguments,
264 [](float a) { return round(a / 2) * 2; });
265 case k_inversesqrt_IntrinsicKind:
266 return evaluate_intrinsic_float1(context, arguments,
267 [](float a) { return 1 / sqrt(a); });
268 case k_radians_IntrinsicKind:
269 return evaluate_intrinsic_float1(context, arguments,
270 [](float a) { return a * 0.0174532925; });
271 case k_degrees_IntrinsicKind:
272 return evaluate_intrinsic_float1(context, arguments,
273 [](float a) { return a * 57.2957795; });
274 default:
275 return nullptr;
276 }
277 }
278
hasProperty(Property property) const279 bool FunctionCall::hasProperty(Property property) const {
280 if (property == Property::kSideEffects &&
281 (this->function().modifiers().fFlags & Modifiers::kHasSideEffects_Flag)) {
282 return true;
283 }
284 for (const auto& arg : this->arguments()) {
285 if (arg->hasProperty(property)) {
286 return true;
287 }
288 }
289 return false;
290 }
291
clone() const292 std::unique_ptr<Expression> FunctionCall::clone() const {
293 ExpressionArray cloned;
294 cloned.reserve_back(this->arguments().size());
295 for (const std::unique_ptr<Expression>& arg : this->arguments()) {
296 cloned.push_back(arg->clone());
297 }
298 return std::make_unique<FunctionCall>(
299 fOffset, &this->type(), &this->function(), std::move(cloned));
300 }
301
description() const302 String FunctionCall::description() const {
303 String result = String(this->function().name()) + "(";
304 String separator;
305 for (const std::unique_ptr<Expression>& arg : this->arguments()) {
306 result += separator;
307 result += arg->description();
308 separator = ", ";
309 }
310 result += ")";
311 return result;
312 }
313
Convert(const Context & context,int offset,const FunctionDeclaration & function,ExpressionArray arguments)314 std::unique_ptr<Expression> FunctionCall::Convert(const Context& context,
315 int offset,
316 const FunctionDeclaration& function,
317 ExpressionArray arguments) {
318 // Reject function calls with the wrong number of arguments.
319 if (function.parameters().size() != arguments.size()) {
320 String msg = "call to '" + function.name() + "' expected " +
321 to_string((int)function.parameters().size()) + " argument";
322 if (function.parameters().size() != 1) {
323 msg += "s";
324 }
325 msg += ", but found " + to_string(arguments.count());
326 context.fErrors.error(offset, msg);
327 return nullptr;
328 }
329
330 // GLSL ES 1.0 requires static recursion be rejected by the compiler. Also, our CPU back-end
331 // cannot handle recursion (and is tied to strictES2Mode front-ends). The safest way to reject
332 // all (potentially) recursive code is to disallow calls to functions before they're defined.
333 if (context.fConfig->strictES2Mode() && !function.definition() && !function.isBuiltin()) {
334 context.fErrors.error(offset, "call to undefined function '" + function.name() + "'");
335 return nullptr;
336 }
337
338 // Resolve generic types.
339 FunctionDeclaration::ParamTypes types;
340 const Type* returnType;
341 if (!function.determineFinalTypes(arguments, &types, &returnType)) {
342 String msg = "no match for " + function.name() + "(";
343 String separator;
344 for (const std::unique_ptr<Expression>& arg : arguments) {
345 msg += separator;
346 msg += arg->type().displayName();
347 separator = ", ";
348 }
349 msg += ")";
350 context.fErrors.error(offset, msg);
351 return nullptr;
352 }
353
354 for (size_t i = 0; i < arguments.size(); i++) {
355 // Coerce each argument to the proper type.
356 arguments[i] = types[i]->coerceExpression(std::move(arguments[i]), context);
357 if (!arguments[i]) {
358 return nullptr;
359 }
360 // Update the refKind on out-parameters, and ensure that they are actually assignable.
361 const Modifiers& paramModifiers = function.parameters()[i]->modifiers();
362 if (paramModifiers.fFlags & Modifiers::kOut_Flag) {
363 const VariableRefKind refKind = paramModifiers.fFlags & Modifiers::kIn_Flag
364 ? VariableReference::RefKind::kReadWrite
365 : VariableReference::RefKind::kPointer;
366 if (!Analysis::MakeAssignmentExpr(arguments[i].get(), refKind, &context.fErrors)) {
367 return nullptr;
368 }
369 }
370 }
371
372 return Make(context, offset, returnType, function, std::move(arguments));
373 }
374
Make(const Context & context,int offset,const Type * returnType,const FunctionDeclaration & function,ExpressionArray arguments)375 std::unique_ptr<Expression> FunctionCall::Make(const Context& context,
376 int offset,
377 const Type* returnType,
378 const FunctionDeclaration& function,
379 ExpressionArray arguments) {
380 SkASSERT(function.parameters().size() == arguments.size());
381 SkASSERT(function.definition() || function.isBuiltin() || !context.fConfig->strictES2Mode());
382
383 if (context.fConfig->fSettings.fOptimize) {
384 // We might be able to optimize built-in intrinsics.
385 if (function.isIntrinsic() && has_compile_time_constant_arguments(arguments)) {
386 // The function is an intrinsic and all inputs are compile-time constants. Optimize it.
387 if (std::unique_ptr<Expression> expr =
388 optimize_intrinsic_call(context, function.intrinsicKind(), arguments)) {
389 return expr;
390 }
391 }
392 }
393
394 return std::make_unique<FunctionCall>(offset, returnType, &function, std::move(arguments));
395 }
396
397 } // namespace SkSL
398