1 
2 //
3 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
4 // Use of this source code is governed by a BSD-style license that can be
5 // found in the LICENSE file.
6 //
7 #include "compiler/translator/TranslatorMetalDirect/FixTypeConstructors.h"
8 #include <unordered_map>
9 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
10 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
11 #include "compiler/translator/tree_ops/SimplifyLoopConditions.h"
12 using namespace sh;
13 ////////////////////////////////////////////////////////////////////////////////
14 namespace
15 {
16 class FixTypeTraverser : public TIntermTraverser
17 {
18   public:
FixTypeTraverser()19     FixTypeTraverser() : TIntermTraverser(false, false, true) {}
20 
visitAggregate(Visit visit,TIntermAggregate * aggregateNode)21     bool visitAggregate(Visit visit, TIntermAggregate *aggregateNode) override
22     {
23         if (visit != Visit::PostVisit)
24         {
25             return true;
26         }
27         if (aggregateNode->isConstructor())
28         {
29             const TType &retType = aggregateNode->getType();
30             if (retType.isScalar())
31             {
32                 // No-op.
33             }
34             else if (retType.isVector())
35             {
36                 size_t primarySize    = retType.getNominalSize() * retType.getArraySizeProduct();
37                 TIntermSequence *args = aggregateNode->getSequence();
38                 size_t argsSize       = 0;
39                 size_t beforeSize     = 0;
40                 TIntermNode *lastArg  = nullptr;
41                 for (TIntermNode *&arg : *args)
42                 {
43                     TIntermTyped *targ = arg->getAsTyped();
44                     lastArg            = arg;
45                     if (targ)
46                     {
47                         argsSize += targ->getNominalSize();
48                     }
49                     if (argsSize <= primarySize)
50                     {
51                         beforeSize += targ->getNominalSize();
52                     }
53                 }
54                 if (argsSize > primarySize)
55                 {
56                     size_t swizzleSize         = primarySize - beforeSize;
57                     TIntermTyped *targ         = lastArg->getAsTyped();
58                     TIntermSwizzle *newSwizzle = nullptr;
59                     switch (swizzleSize)
60                     {
61                         case 1:
62                             newSwizzle = new TIntermSwizzle(targ->deepCopy(), {0});
63                             break;
64                         case 2:
65                             newSwizzle = new TIntermSwizzle(targ->deepCopy(), {0, 1});
66                             break;
67                         case 3:
68                             newSwizzle = new TIntermSwizzle(targ->deepCopy(), {0, 1, 2});
69                             break;
70                         default:
71                             UNREACHABLE();  // Should not be reached in case of 0, or 4
72                     }
73                     if (newSwizzle)
74                     {
75                         this->queueReplacementWithParent(aggregateNode, lastArg, newSwizzle,
76                                                          OriginalNode::IS_DROPPED);
77                     }
78                 }
79             }
80             else if (retType.isMatrix())
81             {
82                 // TBD if issues
83             }
84         }
85         return true;
86     }
87 };
88 
89 }  // anonymous namespace
90 
91 ////////////////////////////////////////////////////////////////////////////////
92 
FixTypeConstructors(TCompiler & compiler,SymbolEnv & symbolEnv,TIntermBlock & root)93 bool sh::FixTypeConstructors(TCompiler &compiler, SymbolEnv &symbolEnv, TIntermBlock &root)
94 {
95     FixTypeTraverser traverser;
96     root.traverse(&traverser);
97     if (!traverser.updateTree(&compiler, &root))
98     {
99         return false;
100     }
101     return true;
102 }
103