/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_H_
#define RUY_RUY_KERNEL_H_

#include "ruy/kernel_common.h"
#include "ruy/mul_params.h"
#include "ruy/platform.h"
#include "ruy/trace.h"

// IWYU pragma: begin_exports
#if RUY_PLATFORM_NEON
#include "ruy/kernel_arm.h"
#elif RUY_PLATFORM_X86
#include "ruy/kernel_x86.h"
#endif
// IWYU pragma: end_exports

namespace ruy {

// KernelArgs is a helper to access the template parameter values from a Kernel
// template instantiation.
template <typename KernelType>
struct KernelArgs {};

template <Path tPath, typename tLhsScalar, typename tRhsScalar,
          typename tAccumScalar, typename tDstScalar>
struct KernelArgs<
    Kernel<tPath, tLhsScalar, tRhsScalar, tAccumScalar, tDstScalar>> {
  static constexpr Path kPath = tPath;
  using LhsScalar = tLhsScalar;
  using RhsScalar = tRhsScalar;
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;
};
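
// For illustration only (a hypothetical usage, not part of ruy's public API):
// KernelArgs lets generic code recover a kernel's template parameters, e.g.
//
//   using K = Kernel<Path::kStandardCpp, float, float, float, float>;
//   static_assert(KernelArgs<K>::kPath == Path::kStandardCpp, "");
//   static_assert(  // requires <type_traits>
//       std::is_same<typename KernelArgs<K>::DstScalar, float>::value, "");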

// RunKernel::Run() is the only place that directly invokes Kernel::Run().
// It performs the type un-erasure, and factoring all Kernel::Run() calls
// through this function also gives a single place to conditionally
// implement RUY_OPT(FAT_KERNEL). This should be a function, but it is a class
// in order to hide and share some boilerplate (see the member types, and the
// RunTyped method that also uses them).
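//
// (An illustrative note about surrounding machinery, not a contract of this
// file: elsewhere in ruy, RunKernel<KernelType>::Run is stored behind a
// type-erased function pointer, which is why it takes type-erased PEMat/EMat
// arguments here and recovers the typed PMat/Mat views before invoking the
// typed kernel.)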
template <typename KernelType>
class RunKernel final {
 public:
  static void Run(Tuning tuning, const SidePair<PEMat>& src,
                  const void* mul_params, const SidePair<int>& start,
                  const SidePair<int>& end, EMat* dst) {
    RUY_TRACE_SCOPE_NAME("RunKernel");
    const auto& unerased_lhs = UneraseType<LhsScalar>(src[Side::kLhs]);
    const auto& unerased_rhs = UneraseType<RhsScalar>(src[Side::kRhs]);
    auto unerased_dst = UneraseType<DstScalar>(*dst);
    RUY_TRACE_INFO(RUN_KERNEL);
    RunTyped(tuning, unerased_lhs, unerased_rhs,
             *static_cast<const MulParamsType*>(mul_params), start, end,
             &unerased_dst);
  }

 private:
  using Args = KernelArgs<KernelType>;
  using LhsScalar = typename Args::LhsScalar;
  using RhsScalar = typename Args::RhsScalar;
  using AccumScalar = typename Args::AccumScalar;
  using DstScalar = typename Args::DstScalar;
  using MulParamsType = MulParams<AccumScalar, DstScalar>;
  static void RunTyped(Tuning tuning, const PMat<LhsScalar>& lhs,
                       const PMat<RhsScalar>& rhs,
                       const MulParamsType& mul_params,
                       const SidePair<int>& start, const SidePair<int>& end,
                       Mat<DstScalar>* dst) {
    const int start_row = start[Side::kLhs];
    const int start_col = start[Side::kRhs];
    const int end_row = end[Side::kLhs];
    const int end_col = end[Side::kRhs];
    KernelType kernel(tuning);
    using LhsLayout = typename KernelType::LhsLayout;
    using RhsLayout = typename KernelType::RhsLayout;
    // This is a good place to validate kernel layouts. The Kernel class
    // template itself isn't a good place to do that because it has
    // specializations.
    // The kRows of both sides have to match: in TrMul, kRows is the depth
    // dimension, on which LHS and RHS have to agree for the matrix
    // multiplication to be defined at all, so requiring the corresponding
    // dimension of the kernel layouts to also match is reasonable. If it didn't
    // match, then the packed matrices could have mismatching depth dimensions
    // even with the source matrices agreeing.
    static_assert(LhsLayout::kRows == RhsLayout::kRows, "");
    // The kernel layout dimensions have to be powers of two. This simplifies
    // BlockMap logic considerably. It also avoids leaking fine performance
    // optimization details up the stack. For instance, if one of the dimensions
    // were 6, then users might notice that optimal performance is achieved with
    // matrix dimensions that are multiples of 6, and might start contorting
    // their own application code to match that requirement, in a way that would
    // not be future-proof.
    static_assert(is_pot(LhsLayout::kRows), "");
    static_assert(is_pot(LhsLayout::kCols), "");
    static_assert(is_pot(RhsLayout::kRows), "");
    static_assert(is_pot(RhsLayout::kCols), "");
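    // For example (illustrative values, not the layout of any particular
    // optimized kernel): FixedKernelLayout<Order::kColMajor, 4, 8> passes
    // these checks since 4 and 8 are powers of two, whereas a hypothetical
    // 6-wide layout would be rejected here.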
    // end_row and end_col may be larger than the dst dimensions.
    // That is because kernels write directly to the destination matrix, whose
    // dimensions may not be a multiple of the kernel block dimensions, and we
    // try to keep this annoyance localized as an implementation detail in
    // kernels, by allowing rounded-up values to be passed down as far as
    // possible. The assertions below encode the contract.
    RUY_DCHECK_LE(0, start_row);
    RUY_DCHECK_LE(start_row, end_row);
    RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
    RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
    RUY_DCHECK_LE(0, start_col);
    RUY_DCHECK_LE(start_col, end_col);
    RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
    RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
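    // For instance (illustrative numbers): with dst->layout.rows == 10,
    // start_row == 0 and LhsLayout::kCols == 8, TrMul may pass end_row == 16.
    // That satisfies the checks above (16 < 10 + 8, and 16 - 0 is a multiple
    // of 8); the generic Kernel::Run below then clamps to the actual 10 rows
    // before writing to dst.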
#if RUY_OPT(FAT_KERNEL)
    // "Fat kernel" mode: hand the whole requested range to a single kernel
    // invocation.
    kernel.Run(lhs, rhs, mul_params, start_row, start_col, end_row, end_col,
               dst);
#else
    // Default mode: split the requested range into blocks of
    // LhsLayout::kCols rows by RhsLayout::kCols columns and invoke the kernel
    // once per block.
    for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
      int block_end_col = std::min(col + RhsLayout::kCols, end_col);
      for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
        int block_end_row = std::min(row + LhsLayout::kCols, end_row);
        kernel.Run(lhs, rhs, mul_params, row, col, block_end_row,
                   block_end_col, dst);
      }
    }
#endif
  }
};

// StandardCppKernelLayout maps each standard-C++ Path to the fixed kernel
// block layouts used by the generic Kernel implementation below.
template <Path ThePath>
struct StandardCppKernelLayout {};

template <>
struct StandardCppKernelLayout<Path::kStandardCpp> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
};

// A variant exercising RowMajor square blocks
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant1> {
  using Lhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
  using Rhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
};

// A variant with a rectangular layout: 4x8
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant2> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 4>;
  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 8>;
};

// A variant with different block orders in LHS vs RHS.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant3> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 2, 16>;
  using Rhs = FixedKernelLayout<Order::kRowMajor, 2, 8>;
};

// General implementation of the Kernel template, overridden by template
// specializations for specific SIMD code paths. This general implementation
// covers Path::kStandardCpp and its internal test-only variants.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
struct Kernel {
  // Each Kernel specialization defines kPath as the ground-truth path that it
  // implements. This is used in assertions. As we support fallbacks between
  // paths (see RUY_INHERIT_KERNEL), unless a specialization for a specific set
  // of template parameters was defined, it is normal for template
  // instantiations of the form Kernel<SomePath, ...> to have kPath != SomePath.
  // Assertions that kPath == SomePath are used in places where we know that we
  // should be using a template specialization for a specific path rather than
  // a fallback.
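  // For instance (an illustrative case): if there is no optimized NEON kernel
  // specialization for a given combination of scalar types, then
  // Kernel<Path::kNeon, ...> falls back (via RUY_INHERIT_KERNEL) to this
  // generic implementation, and its kPath is then Path::kStandardCpp rather
  // than Path::kNeon.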
  static constexpr Path kPath = ThePath;
  using MulParamsType = MulParams<AccumScalar, DstScalar>;
  using LhsLayout = typename StandardCppKernelLayout<ThePath>::Lhs;
  using RhsLayout = typename StandardCppKernelLayout<ThePath>::Rhs;
  explicit Kernel(Tuning) {}
  void Run(const PMat<LhsScalar>& lhs, const PMat<RhsScalar>& rhs,
           const MulParamsType& mul_params, int start_row, int start_col,
           int end_row, int end_col, Mat<DstScalar>* dst) const {
    // See the comment in RunKernel::RunTyped. end_row may be larger than
    // dst->layout.rows. It's the responsibility of the kernel to avoid
    // overrunning dst boundaries, which we do here by computing
    // clamped_end_row.
    int clamped_end_row = std::min(end_row, dst->layout.rows);
    int clamped_end_col = std::min(end_col, dst->layout.cols);
    RUY_DCHECK_LE(0, start_row);
    RUY_DCHECK_LE(start_row, clamped_end_row);
    RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
    RUY_DCHECK_LE(clamped_end_row, end_row);
    RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
    RUY_DCHECK_LE(0, start_col);
    RUY_DCHECK_LE(start_col, clamped_end_col);
    RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
    RUY_DCHECK_LE(clamped_end_col, end_col);
    RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
    profiler::ScopeLabel label("Kernel (Standard Cpp)");
    const int depth = lhs.layout.rows;
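    // What follows is a straightforward reference implementation: for each
    // destination entry, accumulate over the depth dimension, then apply the
    // bias, zero-point corrections, the output multiplier and clamping. In
    // pseudo-math (a summary of the code below, not an additional contract):
    //   dst(i, j) = clamp(dst_zero_point +
    //                     multiplier(channel) * (sum_k lhs(k, i) * rhs(k, j)
    //                                            + bias(channel)
    //                                            - zero-point corrections))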
    for (int i = start_row; i < clamped_end_row; i++) {
      for (int j = start_col; j < clamped_end_col; j++) {
        AccumScalar accum = 0;
        for (int k = 0; k < depth; k++) {
          AccumScalar lhs_val = Element(lhs, k, i);
          AccumScalar rhs_val = Element(rhs, k, j);
          accum += lhs_val * rhs_val;
        }
        int channel =
            mul_params.channel_dimension() == ChannelDimension::kRow ? i : j;
        if (mul_params.bias()) {
          accum += mul_params.bias()[channel];
        }
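        // Zero-point corrections, restating the standard quantized-matmul
        // identity that the next three blocks implement: the packed data holds
        // raw quantized values, so the desired accumulator
        //   sum_k (lhs(k,i) - lhs_zp) * (rhs(k,j) - rhs_zp)
        // equals
        //   sum_k lhs(k,i) * rhs(k,j) - lhs_zp * sum_k rhs(k,j)
        //     - rhs_zp * sum_k lhs(k,i) + depth * lhs_zp * rhs_zp,
        // where the per-column/per-row sums (rhs.sums / lhs.sums) were
        // precomputed during packing.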
        if (lhs.zero_point) {
          accum -= lhs.zero_point * rhs.sums[j];
        }
        if (rhs.zero_point) {
          accum -= rhs.zero_point * lhs.sums[i];
        }
        if (lhs.zero_point && rhs.zero_point) {
          accum += lhs.zero_point * rhs.zero_point * depth;
        }
        ApplyMultiplier(mul_params, channel, &accum);
        accum += dst->zero_point;
        accum = std::min<AccumScalar>(accum, mul_params.clamp_max());
        accum = std::max<AccumScalar>(accum, mul_params.clamp_min());
        *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
      }
    }
  }
};

}  // namespace ruy

#endif  // RUY_RUY_KERNEL_H_