1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/TranslatorMetalDirect/AddExplicitTypeCasts.h"
8 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
9 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
10 
11 using namespace sh;
12 
13 namespace
14 {
15 
16 class Rewriter : public TIntermRebuild
17 {
18     SymbolEnv &mSymbolEnv;
19 
20   public:
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv)21     Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv)
22         : TIntermRebuild(compiler, false, true), mSymbolEnv(symbolEnv)
23     {}
24 
visitAggregatePost(TIntermAggregate & callNode)25     PostResult visitAggregatePost(TIntermAggregate &callNode) override
26     {
27         const size_t argCount = callNode.getChildCount();
28         const TType &retType  = callNode.getType();
29 
30         if (callNode.isConstructor())
31         {
32             if (IsScalarBasicType(retType))
33             {
34                 if (argCount == 1)
35                 {
36                     TIntermTyped &arg   = GetArg(callNode, 0);
37                     const TType argType = arg.getType();
38                     if (argType.isVector())
39                     {
40                         return CoerceSimple(retType, SubVector(arg, 0, 1));
41                     }
42                 }
43             }
44             else if (retType.isVector())
45             {
46                 if (argCount == 1)
47                 {
48                     TIntermTyped &arg   = GetArg(callNode, 0);
49                     const TType argType = arg.getType();
50                     if (argType.isVector())
51                     {
52                         return CoerceSimple(retType, SubVector(arg, 0, retType.getNominalSize()));
53                     }
54                 }
55                 for (size_t i = 0; i < argCount; ++i)
56                 {
57                     TIntermTyped &arg = GetArg(callNode, i);
58                     SetArg(callNode, i, CoerceSimple(retType.getBasicType(), arg));
59                 }
60             }
61             else if (retType.isMatrix())
62             {
63                 if (argCount == 1)
64                 {
65                     TIntermTyped &arg   = GetArg(callNode, 0);
66                     const TType argType = arg.getType();
67                     if (argType.isMatrix())
68                     {
69                         if (retType.getCols() != argType.getCols() ||
70                             retType.getRows() != argType.getRows())
71                         {
72                             TemplateArg templateArgs[] = {retType.getCols(), retType.getRows()};
73                             return mSymbolEnv.callFunctionOverload(
74                                 Name("cast"), retType, *new TIntermSequence{&arg}, 2, templateArgs);
75                         }
76                     }
77                 }
78             }
79         }
80 
81         return callNode;
82     }
83 };
84 
85 }  // anonymous namespace
86 
AddExplicitTypeCasts(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)87 bool sh::AddExplicitTypeCasts(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv)
88 {
89     Rewriter rewriter(compiler, symbolEnv);
90     if (!rewriter.rebuildRoot(root))
91     {
92         return false;
93     }
94     return true;
95 }
96