1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/TranslatorMetalDirect/AddExplicitTypeCasts.h"
8 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
9 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
10
11 using namespace sh;
12
13 namespace
14 {
15
16 class Rewriter : public TIntermRebuild
17 {
18 SymbolEnv &mSymbolEnv;
19
20 public:
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv)21 Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv)
22 : TIntermRebuild(compiler, false, true), mSymbolEnv(symbolEnv)
23 {}
24
visitAggregatePost(TIntermAggregate & callNode)25 PostResult visitAggregatePost(TIntermAggregate &callNode) override
26 {
27 const size_t argCount = callNode.getChildCount();
28 const TType &retType = callNode.getType();
29
30 if (callNode.isConstructor())
31 {
32 if (IsScalarBasicType(retType))
33 {
34 if (argCount == 1)
35 {
36 TIntermTyped &arg = GetArg(callNode, 0);
37 const TType argType = arg.getType();
38 if (argType.isVector())
39 {
40 return CoerceSimple(retType, SubVector(arg, 0, 1));
41 }
42 }
43 }
44 else if (retType.isVector())
45 {
46 if (argCount == 1)
47 {
48 TIntermTyped &arg = GetArg(callNode, 0);
49 const TType argType = arg.getType();
50 if (argType.isVector())
51 {
52 return CoerceSimple(retType, SubVector(arg, 0, retType.getNominalSize()));
53 }
54 }
55 for (size_t i = 0; i < argCount; ++i)
56 {
57 TIntermTyped &arg = GetArg(callNode, i);
58 SetArg(callNode, i, CoerceSimple(retType.getBasicType(), arg));
59 }
60 }
61 else if (retType.isMatrix())
62 {
63 if (argCount == 1)
64 {
65 TIntermTyped &arg = GetArg(callNode, 0);
66 const TType argType = arg.getType();
67 if (argType.isMatrix())
68 {
69 if (retType.getCols() != argType.getCols() ||
70 retType.getRows() != argType.getRows())
71 {
72 TemplateArg templateArgs[] = {retType.getCols(), retType.getRows()};
73 return mSymbolEnv.callFunctionOverload(
74 Name("cast"), retType, *new TIntermSequence{&arg}, 2, templateArgs);
75 }
76 }
77 }
78 }
79 }
80
81 return callNode;
82 }
83 };
84
85 } // anonymous namespace
86
AddExplicitTypeCasts(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)87 bool sh::AddExplicitTypeCasts(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv)
88 {
89 Rewriter rewriter(compiler, symbolEnv);
90 if (!rewriter.rebuildRoot(root))
91 {
92 return false;
93 }
94 return true;
95 }
96