1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef BERBERIS_GUEST_ABI_GUEST_ARGUMENTS_ARCH_H_
18 #define BERBERIS_GUEST_ABI_GUEST_ARGUMENTS_ARCH_H_
19 
20 #include <array>
21 #include <tuple>
22 
23 #include "berberis/base/dependent_false.h"
24 #include "berberis/calling_conventions/calling_conventions_riscv64.h"
25 #include "berberis/guest_abi/guest_abi_arch.h"
26 
27 namespace berberis {
28 
// Flat storage for the arguments and results of a guest call, accessed through the typed
// GuestArgumentsAndResult wrapper: integer register slots, floating-point register slots and a
// trailing stack area.
// NOTE(review): field order and types define the in-memory layout; code outside this header
// (e.g. call trampolines) presumably depends on it — confirm before changing anything here.
struct GuestArgumentBuffer {
  int argc;        // Argument count, in integer registers.
  int resc;        // Result count, in integer registers.
  int fp_argc;     // Argument count, in float registers.
  int fp_resc;     // Result count, in float registers.
  int stack_argc;  // Size of the stack argument area, in bytes.

  // Integer register slots, indexed by register number (8 slots, presumably a0-a7 on riscv64).
  // Integer-class results are presumably returned through these slots as well.
  uint64_t argv[8];
  // Floating-point register slots, indexed by register number; presumably also used for
  // floating-point results.
  uint64_t fp_argv[8];
  // Stack argument area, addressed by byte offset. Declared with a single element but used as
  // variable-length trailing storage; the buffer is presumably allocated with room for
  // stack_argc bytes — confirm with the allocating code.
  uint64_t stack_argv[1];  // VLA.
};
40 
41 template <typename, GuestAbi::CallingConventionsVariant = GuestAbi::kDefaultAbi>
42 class GuestArgumentsAndResult;
43 
44 // GuestArguments is a typesafe wrapper around GuestArgumentBuffer.
45 // Usage looks like this:
46 //   GuestArguments<double(int, double, int, double)> args(*buf);
47 //   int x = args.Arguments<0>();
48 //   float y = args.Arguments<1>();
49 //   args.Result() = x * y;
50 
51 template <typename ResultType,
52           typename... ArgumentType,
53           bool kNoexcept,
54           GuestAbi::CallingConventionsVariant kCallingConventionsVariant>
55 class GuestArgumentsAndResult<ResultType(ArgumentType...) noexcept(kNoexcept),
56                               kCallingConventionsVariant> : GuestAbi {
57  public:
GuestArgumentsAndResult(GuestArgumentBuffer * buffer)58   GuestArgumentsAndResult(GuestArgumentBuffer* buffer) : buffer_(buffer) {}
59 
60   template <size_t index>
GuestArgument()61   auto& GuestArgument() const {
62     static_assert(index < sizeof...(ArgumentType));
63     using Type = std::tuple_element_t<index, std::tuple<ArgumentType...>>;
64     using ArgumentInfo = GuestArgumentInfo<Type, kCallingConventionsVariant>;
65     using CastType = typename ArgumentInfo::GuestType;
66     return Reference<ArgumentInfo, CastType>(kArgumentsLocations[index]);
67   }
68 
69   template <size_t index>
HostArgument()70   auto& HostArgument() const {
71     static_assert(index < sizeof...(ArgumentType));
72     using Type = std::tuple_element_t<index, std::tuple<ArgumentType...>>;
73     using ArgumentInfo = GuestArgumentInfo<Type, kCallingConventionsVariant>;
74     using CastType = typename ArgumentInfo::HostType;
75     return Reference<ArgumentInfo, CastType>(kArgumentsLocations[index]);
76   }
77 
GuestResult()78   auto& GuestResult() const {
79     static_assert(!std::is_same_v<ResultType, void>);
80     using ArgumentInfo = GuestArgumentInfo<ResultType, kCallingConventionsVariant>;
81     using CastType = typename ArgumentInfo::GuestType;
82     return Reference<ArgumentInfo, CastType>(kResultLocation);
83   }
84 
HostResult()85   auto& HostResult() const {
86     static_assert(!std::is_same_v<ResultType, void>);
87     using ArgumentInfo = GuestArgumentInfo<ResultType, kCallingConventionsVariant>;
88     using CastType = typename ArgumentInfo::HostType;
89     return Reference<ArgumentInfo, CastType>(kResultLocation);
90   }
91 
92  private:
93   template <typename ArgumentInfo, typename CastType>
Reference(riscv64::ArgLocation loc)94   auto& Reference(riscv64::ArgLocation loc) const {
95     if constexpr (ArgumentInfo::kArgumentClass == ArgumentClass::kLargeStruct) {
96       return **reinterpret_cast<CastType**>(ArgLocationAddress(loc));
97     } else {
98       return *reinterpret_cast<CastType*>(ArgLocationAddress(loc));
99     }
100   }
101 
102   constexpr static const std::tuple<riscv64::ArgLocation,
103                                     std::array<riscv64::ArgLocation, sizeof...(ArgumentType)>>
ArgumentsInfoHelper()104   ArgumentsInfoHelper() {
105     struct {
106       const ArgumentClass kArgumentClass;
107       const unsigned kSize;
108       const unsigned kAlignment;
109     } const kArgumentsInfo[] = {
110         {.kArgumentClass =
111              GuestArgumentInfo<ArgumentType, kCallingConventionsVariant>::kArgumentClass,
112          .kSize = GuestArgumentInfo<ArgumentType, kCallingConventionsVariant>::kSize,
113          .kAlignment = GuestArgumentInfo<ArgumentType, kCallingConventionsVariant>::kAlignment}...};
114 
115     riscv64::CallingConventions conv;
116     // The result location must be allocated before any arguments to ensure that the implicit a0
117     // argument for functions with large structure return types is reserved.
118     riscv64::ArgLocation result_loc = ResultInfoHelper(conv);
119     std::array<riscv64::ArgLocation, sizeof...(ArgumentType)> arg_locs{};
120     for (const auto& kArgInfo : kArgumentsInfo) {
121       if (kArgInfo.kArgumentClass == ArgumentClass::kInteger ||
122           kArgInfo.kArgumentClass == ArgumentClass::kLargeStruct) {
123         arg_locs[&kArgInfo - kArgumentsInfo] =
124             conv.GetNextIntArgLoc(kArgInfo.kSize, kArgInfo.kAlignment);
125       } else if (kArgInfo.kArgumentClass == ArgumentClass::kFp) {
126         arg_locs[&kArgInfo - kArgumentsInfo] =
127             conv.GetNextFpArgLoc(kArgInfo.kSize, kArgInfo.kAlignment);
128       } else {
129         LOG_ALWAYS_FATAL("Unsupported ArgumentClass");
130       }
131     }
132 
133     return {result_loc, arg_locs};
134   }
135 
ResultInfoHelper(riscv64::CallingConventions & conv)136   constexpr static riscv64::ArgLocation ResultInfoHelper(riscv64::CallingConventions& conv) {
137     using ResultInfo = GuestArgumentInfo<ResultType, kCallingConventionsVariant>;
138     if constexpr (std::is_same_v<ResultType, void>) {
139       return {riscv64::kArgLocationNone, 0};
140     } else if constexpr (ResultInfo::kArgumentClass == ArgumentClass::kInteger) {
141       return conv.GetIntResLoc(ResultInfo::kSize);
142     } else if constexpr (ResultInfo::kArgumentClass == ArgumentClass::kFp) {
143       return conv.GetFpResLoc(ResultInfo::kSize);
144     } else if constexpr (ResultInfo::kArgumentClass == ArgumentClass::kLargeStruct) {
145       // The caller allocates memory for large structure return values and passes the address in a0
146       // as an implicit parameter.  If the return type is a large structure, we must reserve a0 for
147       // this implicit parameter.
148       return conv.GetNextIntArgLoc(ResultInfo::kSize, ResultInfo::kAlignment);
149     } else {
150       static_assert(kDependentTypeFalse<ResultType>, "Unsupported ArgumentClass");
151     }
152   }
153 
ArgLocationAddress(riscv64::ArgLocation loc)154   constexpr void* ArgLocationAddress(riscv64::ArgLocation loc) const {
155     if (loc.kind == riscv64::kArgLocationStack) {
156       return reinterpret_cast<char*>(buffer_->stack_argv) + loc.offset;
157     } else if (loc.kind == riscv64::kArgLocationInt) {
158       return buffer_->argv + loc.offset;
159     } else if (loc.kind == riscv64::kArgLocationFp) {
160       return buffer_->fp_argv + loc.offset;
161     } else {
162       CHECK(false);
163     }
164   }
165 
166   constexpr static riscv64::ArgLocation kResultLocation = std::get<0>(ArgumentsInfoHelper());
167 
168   constexpr static std::array<riscv64::ArgLocation, sizeof...(ArgumentType)> kArgumentsLocations =
169       std::get<1>(ArgumentsInfoHelper());
170 
171   GuestArgumentBuffer* const buffer_;
172 };
173 
174 // Partial specialization for GuestArgumentsAndResult<FunctionToPointer> - it acts the same as the
175 // corresponding GuestArgumentsAndResult<Function>.
176 template <typename ResultType,
177           typename... ArgumentType,
178           bool kNoexcept,
179           GuestAbi::CallingConventionsVariant kCallingConventionsVariant>
180 class GuestArgumentsAndResult<ResultType (*)(ArgumentType...) noexcept(kNoexcept),
181                               kCallingConventionsVariant>
182     : public GuestArgumentsAndResult<ResultType(ArgumentType...) noexcept(kNoexcept),
183                                      kCallingConventionsVariant> {
184  public:
GuestArgumentsAndResult(GuestArgumentBuffer * buffer)185   GuestArgumentsAndResult(GuestArgumentBuffer* buffer)
186       : GuestArgumentsAndResult<ResultType(ArgumentType...) noexcept(kNoexcept),
187                                 kCallingConventionsVariant>(buffer) {}
188 };
189 
190 }  // namespace berberis
191 
192 #endif  // BERBERIS_GUEST_ABI_GUEST_ARGUMENTS_ARCH_H_
193