1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLConstantFolder.h"
9 #include "src/sksl/ir/SkSLConstructor.h"
10 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
11 #include "src/sksl/ir/SkSLConstructorSplat.h"
12 #include "src/sksl/ir/SkSLSwizzle.h"
13 
14 namespace SkSL {
15 
Convert(const Context & context,std::unique_ptr<Expression> base,ComponentArray inComponents)16 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
17                                              std::unique_ptr<Expression> base,
18                                              ComponentArray inComponents) {
19     const int offset = base->fOffset;
20     const Type& baseType = base->type();
21 
22     // The IRGenerator is responsible for enforcing these invariants.
23     SkASSERTF(baseType.isVector() || baseType.isScalar(),
24               "cannot swizzle type '%s'", baseType.description().c_str());
25     SkASSERT(inComponents.count() >= 1 && inComponents.count() <= 4);
26 
27     ComponentArray maskComponents;
28     for (int8_t component : inComponents) {
29         switch (component) {
30             case SwizzleComponent::ZERO:
31             case SwizzleComponent::ONE:
32                 // Skip over constant fields for now.
33                 break;
34             case SwizzleComponent::X:
35                 maskComponents.push_back(SwizzleComponent::X);
36                 break;
37             case SwizzleComponent::Y:
38                 if (baseType.columns() >= 2) {
39                     maskComponents.push_back(SwizzleComponent::Y);
40                     break;
41                 }
42                 [[fallthrough]];
43             case SwizzleComponent::Z:
44                 if (baseType.columns() >= 3) {
45                     maskComponents.push_back(SwizzleComponent::Z);
46                     break;
47                 }
48                 [[fallthrough]];
49             case SwizzleComponent::W:
50                 if (baseType.columns() >= 4) {
51                     maskComponents.push_back(SwizzleComponent::W);
52                     break;
53                 }
54                 [[fallthrough]];
55             default:
56                 SkDEBUGFAILF("invalid swizzle component %d", component);
57                 return nullptr;
58         }
59     }
60 
61     // First, we need a vector expression that is the non-constant portion of the swizzle, packed:
62     //   scalar.xxx  -> type3(scalar)
63     //   scalar.x0x0 -> type2(scalar)
64     //   vector.zyx  -> vector.zyx
65     //   vector.x0y0 -> vector.xy
66     std::unique_ptr<Expression> expr = Swizzle::Make(context, std::move(base), maskComponents);
67 
68     // If we have processed the entire swizzle, we're done.
69     if (maskComponents.count() == inComponents.count()) {
70         return expr;
71     }
72 
73     // Now we create a constructor that has the correct number of elements for the final swizzle,
74     // with all fields at the start. It's not finished yet; constants we need will be added below.
75     //   scalar.x0x0 -> type4(type2(x), ...)
76     //   vector.y111 -> type4(vector.y, ...)
77     //   vector.z10x -> type4(vector.zx, ...)
78     //
79     // The constructor will have at most three arguments: { base expr, constant 0, constant 1 }
80     ExpressionArray constructorArgs;
81     constructorArgs.reserve_back(3);
82     constructorArgs.push_back(std::move(expr));
83 
84     // Apply another swizzle to shuffle the constants into the correct place. Any constant values we
85     // need are also tacked on to the end of the constructor.
86     //   scalar.x0x0 -> type4(type2(x), 0).xyxy
87     //   vector.y111 -> type4(vector.y, 1).xyyy
88     //   vector.z10x -> type4(vector.zx, 1, 0).xzwy
89     const Type* numberType = &baseType.componentType();
90     ComponentArray swizzleComponents;
91     int maskFieldIdx = 0;
92     int constantFieldIdx = maskComponents.size();
93     int constantZeroIdx = -1, constantOneIdx = -1;
94 
95     for (int i = 0; i < inComponents.count(); i++) {
96         switch (inComponents[i]) {
97             case SwizzleComponent::ZERO:
98                 if (constantZeroIdx == -1) {
99                     // Synthesize a 'type(0)' argument at the end of the constructor.
100                     constructorArgs.push_back(ConstructorScalarCast::Make(
101                             context, offset, *numberType,
102                             IntLiteral::Make(context, offset, /*value=*/0)));
103                     constantZeroIdx = constantFieldIdx++;
104                 }
105                 swizzleComponents.push_back(constantZeroIdx);
106                 break;
107             case SwizzleComponent::ONE:
108                 if (constantOneIdx == -1) {
109                     // Synthesize a 'type(1)' argument at the end of the constructor.
110                     constructorArgs.push_back(ConstructorScalarCast::Make(
111                             context, offset, *numberType,
112                             IntLiteral::Make(context, offset, /*value=*/1)));
113                     constantOneIdx = constantFieldIdx++;
114                 }
115                 swizzleComponents.push_back(constantOneIdx);
116                 break;
117             default:
118                 // The non-constant fields are already in the expected order.
119                 swizzleComponents.push_back(maskFieldIdx++);
120                 break;
121         }
122     }
123 
124     expr = Constructor::Convert(context, offset,
125                                 numberType->toCompound(context, constantFieldIdx, /*rows=*/1),
126                                 std::move(constructorArgs));
127     if (!expr) {
128         return nullptr;
129     }
130 
131     return Swizzle::Make(context, std::move(expr), swizzleComponents);
132 }
133 
Make(const Context & context,std::unique_ptr<Expression> expr,ComponentArray components)134 std::unique_ptr<Expression> Swizzle::Make(const Context& context,
135                                           std::unique_ptr<Expression> expr,
136                                           ComponentArray components) {
137     const Type& exprType = expr->type();
138     SkASSERTF(exprType.isVector() || exprType.isScalar(),
139               "cannot swizzle type '%s'", exprType.description().c_str());
140     SkASSERT(components.count() >= 1 && components.count() <= 4);
141 
142     // Confirm that the component array only contains X/Y/Z/W. (Call MakeWith01 if you want support
143     // for ZERO and ONE. Once initial IR generation is complete, no swizzles should have zeros or
144     // ones in them.)
145     SkASSERT(std::all_of(components.begin(), components.end(), [](int8_t component) {
146         return component >= SwizzleComponent::X &&
147                component <= SwizzleComponent::W;
148     }));
149 
150     // SkSL supports splatting a scalar via `scalar.xxxx`, but not all versions of GLSL allow this.
151     // Replace swizzles with equivalent splat constructors (`scalar.xxx` --> `half3(value)`).
152     if (exprType.isScalar()) {
153         int offset = expr->fOffset;
154         return ConstructorSplat::Make(context, offset,
155                                       exprType.toCompound(context, components.size(), /*rows=*/1),
156                                       std::move(expr));
157     }
158 
159     if (context.fConfig->fSettings.fOptimize) {
160         // Detect identity swizzles like `color.rgba` and return the base-expression as-is.
161         if (components.count() == exprType.columns()) {
162             bool identity = true;
163             for (int i = 0; i < components.count(); ++i) {
164                 if (components[i] != i) {
165                     identity = false;
166                     break;
167                 }
168             }
169             if (identity) {
170                 return expr;
171             }
172         }
173 
174         // Optimize swizzles of swizzles, e.g. replace `foo.argb.rggg` with `foo.arrr`.
175         if (expr->is<Swizzle>()) {
176             Swizzle& base = expr->as<Swizzle>();
177             ComponentArray combined;
178             for (int8_t c : components) {
179                 combined.push_back(base.components()[c]);
180             }
181 
182             // It may actually be possible to further simplify this swizzle. Go again.
183             // (e.g. `color.abgr.abgr` --> `color.rgba` --> `color`.)
184             return Swizzle::Make(context, std::move(base.base()), combined);
185         }
186 
187         // If we are swizzling a constant expression, we can use its value instead here (so that
188         // swizzles like `colorWhite.x` can be simplified to `1`).
189         const Expression* value = ConstantFolder::GetConstantValueForVariable(*expr);
190 
191         // `half4(scalar).zyy` can be optimized to `half3(scalar)`, and `half3(scalar).y` can be
192         // optimized to just `scalar`. The swizzle components don't actually matter, as every field
193         // in a splat constructor holds the same value.
194         if (value->is<ConstructorSplat>()) {
195             const ConstructorSplat& splat = value->as<ConstructorSplat>();
196             return ConstructorSplat::Make(
197                     context, splat.fOffset,
198                     splat.type().componentType().toCompound(context, components.size(), /*rows=*/1),
199                     splat.argument()->clone());
200         }
201 
202         // Optimize swizzles of constructors.
203         if (value->isAnyConstructor()) {
204             const AnyConstructor& base = value->asAnyConstructor();
205             auto baseArguments = base.argumentSpan();
206             std::unique_ptr<Expression> replacement;
207             const Type& componentType = exprType.componentType();
208             int swizzleSize = components.size();
209 
210             // Swizzles can duplicate some elements and discard others, e.g.
211             // `half4(1, 2, 3, 4).xxz` --> `half3(1, 1, 3)`. However, there are constraints:
212             // - Expressions with side effects need to occur exactly once, even if they
213             //   would otherwise be swizzle-eliminated
214             // - Non-trivial expressions should not be repeated, but elimination is OK.
215             //
216             // Look up the argument for the constructor at each index. This is typically simple
217             // but for weird cases like `half4(bar.yz, half2(foo))`, it can be harder than it
218             // seems. This example would result in:
219             //     argMap[0] = {.fArgIndex = 0, .fComponent = 0}   (bar.yz     .x)
220             //     argMap[1] = {.fArgIndex = 0, .fComponent = 1}   (bar.yz     .y)
221             //     argMap[2] = {.fArgIndex = 1, .fComponent = 0}   (half2(foo) .x)
222             //     argMap[3] = {.fArgIndex = 1, .fComponent = 1}   (half2(foo) .y)
223             struct ConstructorArgMap {
224                 int8_t fArgIndex;
225                 int8_t fComponent;
226             };
227 
228             int numConstructorArgs = base.type().columns();
229             ConstructorArgMap argMap[4] = {};
230             int writeIdx = 0;
231             for (int argIdx = 0; argIdx < (int)baseArguments.size(); ++argIdx) {
232                 const Expression& arg = *baseArguments[argIdx];
233                 int argWidth = arg.type().columns();
234                 for (int componentIdx = 0; componentIdx < argWidth; ++componentIdx) {
235                     argMap[writeIdx].fArgIndex = argIdx;
236                     argMap[writeIdx].fComponent = componentIdx;
237                     ++writeIdx;
238                 }
239             }
240             SkASSERT(writeIdx == numConstructorArgs);
241 
242             // Count up the number of times each constructor argument is used by the
243             // swizzle.
244             //    `half4(bar.yz, half2(foo)).xwxy` -> { 3, 1 }
245             // - bar.yz    is referenced 3 times, by `.x_xy`
246             // - half(foo) is referenced 1 time,  by `._w__`
247             int8_t exprUsed[4] = {};
248             for (int8_t c : components) {
249                 exprUsed[argMap[c].fArgIndex]++;
250             }
251 
252             bool safeToOptimize = true;
253             for (int index = 0; index < numConstructorArgs; ++index) {
254                 int8_t constructorArgIndex = argMap[index].fArgIndex;
255                 const Expression& baseArg = *baseArguments[constructorArgIndex];
256 
257                 // Check that non-trivial expressions are not swizzled in more than once.
258                 if (exprUsed[constructorArgIndex] > 1 && !Analysis::IsTrivialExpression(baseArg)) {
259                     safeToOptimize = false;
260                     break;
261                 }
262                 // Check that side-effect-bearing expressions are swizzled in exactly once.
263                 if (exprUsed[constructorArgIndex] != 1 && baseArg.hasSideEffects()) {
264                     safeToOptimize = false;
265                     break;
266                 }
267             }
268 
269             if (safeToOptimize) {
270                 struct ReorderedArgument {
271                     int8_t fArgIndex;
272                     ComponentArray fComponents;
273                 };
274                 SkSTArray<4, ReorderedArgument> reorderedArgs;
275                 for (int8_t c : components) {
276                     const ConstructorArgMap& argument = argMap[c];
277                     const Expression& baseArg = *baseArguments[argument.fArgIndex];
278 
279                     if (baseArg.type().isScalar()) {
280                         // This argument is a scalar; add it to the list as-is.
281                         SkASSERT(argument.fComponent == 0);
282                         reorderedArgs.push_back({argument.fArgIndex,
283                                                  ComponentArray{}});
284                     } else {
285                         // This argument is a component from a vector.
286                         SkASSERT(argument.fComponent < baseArg.type().columns());
287                         if (reorderedArgs.empty() ||
288                             reorderedArgs.back().fArgIndex != argument.fArgIndex) {
289                             // This can't be combined with the previous argument. Add a new one.
290                             reorderedArgs.push_back({argument.fArgIndex,
291                                                      ComponentArray{argument.fComponent}});
292                         } else {
293                             // Since we know this argument uses components, it should already
294                             // have at least one component set.
295                             SkASSERT(!reorderedArgs.back().fComponents.empty());
296                             // Build up the current argument with one more component.
297                             reorderedArgs.back().fComponents.push_back(argument.fComponent);
298                         }
299                     }
300                 }
301 
302                 // Convert our reordered argument list to an actual array of expressions, with
303                 // the new order and any new inner swizzles that need to be applied.
304                 ExpressionArray newArgs;
305                 newArgs.reserve_back(swizzleSize);
306                 for (const ReorderedArgument& reorderedArg : reorderedArgs) {
307                     std::unique_ptr<Expression> newArg =
308                             baseArguments[reorderedArg.fArgIndex]->clone();
309 
310                     if (reorderedArg.fComponents.empty()) {
311                         newArgs.push_back(std::move(newArg));
312                     } else {
313                         newArgs.push_back(Swizzle::Make(context, std::move(newArg),
314                                                         reorderedArg.fComponents));
315                     }
316                 }
317 
318                 // Wrap the new argument list in a constructor.
319                 auto ctor = Constructor::Convert(
320                         context, base.fOffset,
321                         componentType.toCompound(context, swizzleSize, /*rows=*/1),
322                         std::move(newArgs));
323                 SkASSERT(ctor);
324                 return ctor;
325             }
326         }
327     }
328 
329     // The swizzle could not be simplified, so apply the requested swizzle to the base expression.
330     return std::make_unique<Swizzle>(context, std::move(expr), components);
331 }
332 
333 }  // namespace SkSL
334