/* Copyright 2020 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation of CreateTrMulParams, see function comment.

#ifndef RUY_RUY_CREATE_TRMUL_PARAMS_H_
#define RUY_RUY_CREATE_TRMUL_PARAMS_H_

#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ruy/allocator.h"
#include "ruy/ctx.h"
#include "ruy/kernel.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/pack.h"
#include "ruy/path.h"
#include "ruy/performance_advisory.h"
#include "ruy/trace.h"
#include "ruy/trmul_params.h"

namespace ruy {
// While the only entry point to this file is CreateTrMulParams, its templatized
// nature requires putting more code in this header than we would like. This
// internal implementation code is enclosed in namespace 'detail'.
namespace detail {

inline void CreatePackedLayout(const MatLayout& src,
                               const KernelLayout& kernel_layout,
                               PMatLayout* packed_layout) {
  // Packed matrices are always column-major, because in TrMul that is always
  // the dimension of traversal of the kernel's inner loop.
  packed_layout->order = Order::kColMajor;
  packed_layout->rows = round_up_pot(src.rows, kernel_layout.rows);
  packed_layout->cols = round_up_pot(src.cols, kernel_layout.cols);
  packed_layout->stride = packed_layout->rows;
  packed_layout->kernel = kernel_layout;
}
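
// Worked example (illustrative numbers only): for a 10x3 source matrix and an
// 8x4 kernel layout, the packed layout gets rows = round_up_pot(10, 8) = 16,
// cols = round_up_pot(3, 4) = 4, and stride = 16, i.e. the packed matrix is
// padded out to whole kernel blocks.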

template <typename Scalar, typename PackedScalar>
void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout,
                        TrMulParams* params) {
  // Ruy always uses 32-bit signed accumulators for quantized
  // matrix multiplication, so we would like to use std::int32_t
  // unconditionally for SumsType.
  // However, for floating point types, we still need a reasonable type here to
  // avoid tripping assertions elsewhere in the code.
  using SumsType =
      typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
                                std::int32_t>::type;

  const EMat& src = params->src[side];
  PEMat* packed_matrix = &params->packed_matrix[side];
  packed_matrix->data_type = Type::Create<PackedScalar>();
  packed_matrix->sums_type = Type::Create<SumsType>();
  CreatePackedLayout(src.layout, kernel_layout, &packed_matrix->layout);
  packed_matrix->zero_point = Pack<PackedScalar, Scalar>(src.zero_point);
}
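
// Informal note: PackedScalar may differ from Scalar; for instance, uint8
// sources are packed as int8 on fast paths, and Pack<PackedScalar, Scalar>()
// maps the zero_point into the packed representation accordingly (e.g. a
// uint8 zero_point of 128 would become 0 in int8, assuming the usual offset
// of 128 used by that conversion).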

template <typename KernelType>
struct CheckKernelPathImpl {
  static void Run(Path) {
    // Do nothing.
    // Path fallbacks are normal in general (see RUY_INHERIT_KERNEL).
    // That is to say that one may instantiate ruy::Mul with a weird combination
    // of types, such as LhsScalar==float and RhsScalar==double, and have it
    // work by silently falling back to Path::kStandardCpp. Only in specific
    // cases do we have dedicated kernels overriding that fallback, and that is
    // what partial specializations of this template will check.
  }
};

#if RUY_DCHECK_IS_ENABLED
template <Path ThePath, typename SrcScalar, typename AccumScalar,
          typename DstScalar>
struct CheckKernelPathImpl<Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
                                  MulParams<AccumScalar, DstScalar>>>
    final {
  using KernelType = Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
                            MulParams<AccumScalar, DstScalar>>;
  static void Run(Path expected_path) {
    // We want to assert that we are using a dedicated Kernel specialization and
    // not a fallback when we know we are in a case where such a kernel
    // specialization exists. At the moment in the current state of ruy's
    // architecture support for ARM and x86, that is when LhsScalar==RhsScalar
    // (already implied in this partial specialization) and when that type is
    // either float, int8, or uint8. Indeed, we have kernels supporting float
    // and int8, and we have the packing code converting uint8 to int8 (see
    // PackedTypeImpl).
    static constexpr bool kSrcScalarTypeSupportsFastKernels =
        std::is_same<SrcScalar, float>::value ||
        std::is_same<SrcScalar, std::int8_t>::value ||
        std::is_same<SrcScalar, std::uint8_t>::value;
    if (kSrcScalarTypeSupportsFastKernels) {
      RUY_DCHECK_EQ(expected_path, KernelType::kPath);
    }
  }
};
#endif

template <typename KernelType>
void CheckKernelPath(Path expected_path) {
  CheckKernelPathImpl<KernelType>::Run(expected_path);
}

template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void PopulateTrMulParams(TrMulParams* params) {
  RUY_TRACE_SCOPE;
  using PackedLhsScalar = PackedType<ThePath, LhsScalar>;
  using PackedRhsScalar = PackedType<ThePath, RhsScalar>;
  using Kernel =
      Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, AccumScalar, DstScalar>;
  using LhsKernelLayout = typename Kernel::LhsLayout;
  using RhsKernelLayout = typename Kernel::RhsLayout;

  params->path = ThePath;

  CreatePackedMatrix<LhsScalar, PackedLhsScalar>(
      Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params);
  CreatePackedMatrix<RhsScalar, PackedRhsScalar>(
      Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params);
  params->run_pack[Side::kLhs] =
      &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>;
  params->run_pack[Side::kRhs] =
      &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>;
  params->run_kernel = &RunKernel<Kernel>::Run;
  CheckKernelPath<Kernel>(ThePath);
  RUY_TRACE_INFO(POPULATE_TRMUL_PARAMS);
}
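
// Illustrative instantiation (example path and types only):
// PopulateTrMulParams<Path::kStandardCpp, float, float, float, float> picks
// Kernel<Path::kStandardCpp, float, float, float, float>, sizes the packed
// LHS/RHS layouts from that kernel's LhsLayout/RhsLayout, and stores function
// pointers to the matching RunPack and RunKernel instantiations in `params`.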

// PopulateTrMulParamsAllCompiledPaths calls into one of multiple
// instantiations of PopulateTrMulParams. For each bit that is set in
// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path
// corresponding to that single bit. The call to PopulateTrMulParams is
// guarded by a runtime check that it is in fact the dynamically selected path.
//
// PopulateTrMulParamsAllCompiledPaths is implemented with template
// metaprogramming by mutual recursion between PathSearchCountdown and
// PathSearchOnlyCompiledPaths.
//
// PopulateTrMulParamsAllCompiledPaths is logically implementing the following
// computation:
//
// template <Path CompiledPaths>
// void PopulateTrMulParamsAllCompiledPaths(Path the_path,
//                                            TrMulParams* params) {
//   for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1]
//     Path current_path = static_cast<Path>(1 << bit);
//     if ((CompiledPaths & current_path) != Path::kNone) { // [2]
//       if (current_path == the_path) { // [3]
//         PopulateTrMulParams<current_path, ...>(params);
//         return;
//       }
//     }
//   }
// }
//
// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is
// done in the recursion of PathSearchOnlyCompiledPaths.
// [2] - Done by PathSearchOnlyCompiledPaths's partial template
// specialization on InCompiledPaths. This is the check which necessitates
// doing the whole computation at C++ compile time.
// [3] - Done by the `if` in the main definition of
// PathSearchOnlyCompiledPaths.
//
// The template metaprogramming is necessary because:
// - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++
// compile-time constant.
// - PopulateTrMulParamsAllCompiledPaths must not instantiate
// inner loops for paths that are not in CompiledPaths, since that can result in
// bogus instantiations which cause a compile time failure.
template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename AccumScalar, typename DstScalar>
struct PathSearchCountdown;

template <Path CompiledPaths, bool InCompiledPaths, int BitNumber,
          typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
struct PathSearchOnlyCompiledPaths {
  static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
  static void Search(Path the_path, TrMulParams* params) {
    if (kCurrentPath == the_path) {
      PopulateTrMulParams<kCurrentPath, LhsScalar, RhsScalar, AccumScalar,
                          DstScalar>(params);
      return;
    }
    PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar,
                        AccumScalar, DstScalar>::Search(the_path, params);
  }
};

// Skip this iteration if CompiledPaths doesn't contain the specified path.
template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename AccumScalar, typename DstScalar>
struct PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar,
                                   RhsScalar, AccumScalar, DstScalar> {
  static void Search(Path the_path, TrMulParams* params) {
    PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar,
                        AccumScalar, DstScalar>::Search(the_path, params);
  }
};

template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename AccumScalar, typename DstScalar>
struct PathSearchCountdown {
  static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
  static void Search(Path the_path, TrMulParams* params) {
    PathSearchOnlyCompiledPaths<
        CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber,
        LhsScalar, RhsScalar, AccumScalar, DstScalar>::Search(the_path, params);
  }
};

// Termination of the countdown. If the counter reaches -1, then we haven't
// found the specified path.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
struct PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, AccumScalar,
                           DstScalar> {
  static void Search(Path, TrMulParams*) { RUY_DCHECK(false); }
};

template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) {
  RUY_TRACE_SCOPE;
  return PathSearchCountdown<CompiledPaths, 8 * sizeof(Path) - 1, LhsScalar,
                             RhsScalar, AccumScalar,
                             DstScalar>::Search(the_path, params);
}
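
// Illustrative trace (hypothetical CompiledPaths value): with
// CompiledPaths == (Path::kStandardCpp | Path::kNeon) and
// the_path == Path::kNeon, the countdown starts at the highest bit; bits not
// set in CompiledPaths hit the `false` PathSearchOnlyCompiledPaths
// specialization and simply recurse, and the step whose bit equals
// Path::kNeon matches the_path and calls
// PopulateTrMulParams<Path::kNeon, ...>(params).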

template <typename AccumScalar, typename DstScalar>
void AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized(
    const MulParams<AccumScalar, DstScalar>& mul_params, int user_size,
    int user_capacity) {
#if RUY_DCHECK_IS_ENABLED
  if (mul_params.bias()) {
    for (int i = user_size; i < user_capacity; i++) {
      RUY_DCHECK_EQ(mul_params.bias()[i], 0);
    }
  }
  if (mul_params.multiplier_fixedpoint_perchannel()) {
    for (int i = user_size; i < user_capacity; i++) {
      RUY_DCHECK_EQ(mul_params.multiplier_fixedpoint_perchannel()[i], 0);
    }
  }
  if (mul_params.multiplier_exponent_perchannel()) {
    for (int i = user_size; i < user_capacity; i++) {
      RUY_DCHECK_EQ(mul_params.multiplier_exponent_perchannel()[i], 0);
    }
  }
#else
  (void)mul_params;
  (void)user_size;
  (void)user_capacity;
#endif
}

template <typename AccumScalar, typename DstScalar,
          bool HaveQuantizedMultipliers =
              std::is_same<AccumScalar, std::int32_t>::value &&
              !std::is_same<DstScalar, std::int32_t>::value>
struct EnsurePerChannelBuffersLargeEnoughImpl {
  static void Run(const TrMulParams& params, Allocator* allocator,
                  MulParams<AccumScalar, DstScalar>* mul_params) {
    const Side channel_side =
        mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
                                                                  : Side::kRhs;
    const int required_capacity =
        params.packed_matrix[channel_side].layout.cols;
    const int user_size = params.src[channel_side].layout.cols;
    const int user_capacity = round_up_pot(
        user_size, mul_params->perchannel_buffers_capacity_rounding());
    // We should have already checked earlier for the case where
    // user_capacity >= required_capacity.
    RUY_DCHECK_GT(required_capacity, user_capacity);
    if (mul_params->bias()) {
      AccumScalar* new_data =
          allocator->Allocate<AccumScalar>(required_capacity);
      std::memcpy(new_data, mul_params->bias(),
                  user_size * sizeof(AccumScalar));
      std::memset(new_data + user_size, 0,
                  (required_capacity - user_size) * sizeof(AccumScalar));
      mul_params->set_bias(new_data);
    }
    if (mul_params->multiplier_fixedpoint_perchannel()) {
      AccumScalar* new_data =
          allocator->Allocate<AccumScalar>(required_capacity);
      std::memcpy(new_data, mul_params->multiplier_fixedpoint_perchannel(),
                  user_size * sizeof(AccumScalar));
      std::memset(new_data + user_size, 0,
                  (required_capacity - user_size) * sizeof(AccumScalar));
      mul_params->set_multiplier_fixedpoint_perchannel(new_data);
    }
    if (mul_params->multiplier_exponent_perchannel()) {
      int* new_data = allocator->Allocate<int>(required_capacity);
      std::memcpy(new_data, mul_params->multiplier_exponent_perchannel(),
                  user_size * sizeof(int));
      std::memset(new_data + user_size, 0,
                  (required_capacity - user_size) * sizeof(int));
      mul_params->set_multiplier_exponent_perchannel(new_data);
    }
  }
};

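// Specialization for the case without quantized per-channel multipliers, i.e.
// when AccumScalar is not std::int32_t (the float path) or when DstScalar is
// std::int32_t (raw accumulator output). Only the bias buffer may need to be
// grown in that case.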
template <typename AccumScalar, typename DstScalar>
struct EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar, false> {
  static void Run(const TrMulParams& params, Allocator* allocator,
                  MulParams<AccumScalar, DstScalar>* mul_params) {
    const Side channel_side =
        mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
                                                                  : Side::kRhs;
    const int required_capacity =
        params.packed_matrix[channel_side].layout.cols;
    const int user_size = params.src[channel_side].layout.cols;
    const int user_capacity = round_up_pot(
        user_size, mul_params->perchannel_buffers_capacity_rounding());
    // We should have already checked earlier for the case where
    // user_capacity >= required_capacity.
    RUY_DCHECK_GT(required_capacity, user_capacity);
    if (mul_params->bias()) {
      AccumScalar* new_data =
          allocator->Allocate<AccumScalar>(required_capacity);
      std::memcpy(new_data, mul_params->bias(),
                  user_size * sizeof(AccumScalar));
      std::memset(new_data + user_size, 0,
                  (required_capacity - user_size) * sizeof(AccumScalar));
      mul_params->set_bias(new_data);
    }
  }
};

template <typename AccumScalar, typename DstScalar>
void EnsurePerChannelBuffersLargeEnough(
    const TrMulParams& params, Ctx* ctx,
    MulParams<AccumScalar, DstScalar>* mul_params) {
  // Early exit in the common case where the packed matrix size matches the
  // number of channels (as opposed to having been rounded up to a slightly
  // larger value).
  const Side channel_side =
      mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
                                                                : Side::kRhs;
  const int required_capacity = params.packed_matrix[channel_side].layout.cols;
  const int user_size = params.src[channel_side].layout.cols;
  const int user_capacity = round_up_pot(
      user_size, mul_params->perchannel_buffers_capacity_rounding());
  AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized(
      *mul_params, user_size, user_capacity);
  if (required_capacity <= user_capacity) {
    return;
  }
  ctx->set_performance_advisory(
      PerformanceAdvisory::kReallocatedPerChannelBuffer);
  EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar>::Run(
      params, ctx->GetMainAllocator(), mul_params);
}
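
// Worked example (illustrative numbers only): with 10 channels,
// perchannel_buffers_capacity_rounding() == 1, and a packed channel-side
// dimension rounded up to 16, we get user_size == user_capacity == 10 and
// required_capacity == 16, so each provided per-channel buffer is copied into
// a new 16-element allocation whose last 6 elements are zero-filled.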

// Ensures that `params->mul_params_bytes` contains MulParams data that's ready
// to be consumed by the kernel. As a first-order approximation, that is simply
// copying the user-provided `mul_params`; however, there are a few changes.
//
//   1. The specified `channel_dimension` value overrides the channel_dimension
//      member in `mul_params`. The reason why `channel_dimension` is being
//      special-cased among MulParams members is that we will need to transpose
//      MulParams, and that consists simply of toggling channel_dimension.
//   2. Per-channel buffers may be reallocated, see
//      EnsurePerChannelBuffersLargeEnough.
template <typename AccumScalar, typename DstScalar>
void FinalizeMulParams(const MulParams<AccumScalar, DstScalar>& mul_params,
                       ChannelDimension channel_dimension, Ctx* ctx,
                       TrMulParams* params) {
  using MulParamsType = MulParams<AccumScalar, DstScalar>;
  static_assert(alignof(MulParamsType) <= kMaxMulParamsAlignment, "");
  static_assert(sizeof(MulParamsType) <= kMaxMulParamsSize, "");
  static_assert(std::is_trivially_copyable<MulParamsType>::value, "");
  auto* dst_mul_params =
      reinterpret_cast<MulParamsType*>(params->mul_params_bytes);
  std::memcpy(dst_mul_params, &mul_params, sizeof(MulParamsType));
  dst_mul_params->set_channel_dimension(channel_dimension);
  EnsurePerChannelBuffersLargeEnough(*params, ctx, dst_mul_params);
}

// In this function, the `channel_dimension` parameter overrides the value
// of the channel_dimension member in the `mul_params` parameter. See the
// FinalizeMulParams comment.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void CreateTrMulParamsAssumingColMajorDst(
    const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
    const Mat<DstScalar>& dst,
    const MulParams<AccumScalar, DstScalar>& mul_params,
    ChannelDimension channel_dimension, Ctx* ctx, TrMulParams* params) {
  RUY_TRACE_SCOPE;
  RUY_DCHECK(IsColMajor(dst.layout));

  // Fill in the fields we already know.
  params->src[Side::kLhs] = EraseType(lhs);
  params->src[Side::kRhs] = EraseType(rhs);
  params->dst = EraseType(dst);

  // Determine which exact Path we're going to take in this Mul call.
  // This is cheap because it's cached in `ctx`. In user scenarios this always
  // evaluates to the same value on a given machine with given `CompiledPaths`,
  // but could be invalidated by a call to Ctx::SetRuntimeEnabledPaths(), which
  // might be exposed publicly in Context in the future.
  const Path the_path = ctx->SelectPath(CompiledPaths);

  RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST);

  // If we ever need to fall back to Path::kStandardCpp again, this is a good
  // place to do it -- just pass Path::kStandardCpp as both the template and
  // runtime parameters in this function call.
  // In the past we did that here (as version control history remembers).
  // A typical reason why we might need to resurrect that is if we implement
  // a new Path (i.e. port to a new ISA) and need to subdivide that work into
  // a series of incremental changes.
  PopulateTrMulParamsAllCompiledPaths<CompiledPaths, LhsScalar, RhsScalar,
                                      AccumScalar, DstScalar>(the_path, params);

  // This must be done last, as it depends on the specific choice of kernel.
  // Specifically, the EnsurePerChannelBuffersLargeEnough part of this will read
  // the packed matrix layouts that are written to `params` by the above
  // PopulateTrMulParams* call.
  FinalizeMulParams(mul_params, channel_dimension, ctx, params);
}

}  // namespace detail

inline ChannelDimension Transpose(ChannelDimension channel_dimension) {
  return channel_dimension == ChannelDimension::kCol ? ChannelDimension::kRow
                                                     : ChannelDimension::kCol;
}

// CreateTrMulParams's output is a TrMulParams object that encodes
// all of the input information required by the middle-end, that is,
// the TrMul function.
//
// CreateTrMulParams performs the following tasks:
//   1. Reduce to the case of a column-major destination, by transposing the
//      whole problem as needed.
//   2. Select the single code path to be taken, out of the set of paths
//      described by the `CompiledPaths` template parameter, based on the
//      path selected at runtime via `ctx`.
//   3. Perform type-erasure, converting templatized typed input parameters
//      to the un-typed data stored in TrMulParams.
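//
// Task 1 relies on the identity (A*B)^T == B^T * A^T: swapping the two
// operands, reinterpreting the destination as its transpose, and toggling the
// channel dimension yields an equivalent problem whose destination is
// column-major.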
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void CreateTrMulParams(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
                       const Mat<DstScalar>& dst,
                       const MulParams<AccumScalar, DstScalar>& mul_params,
                       Ctx* ctx, TrMulParams* params) {
  RUY_TRACE_SCOPE;
  ChannelDimension channel_dimension = mul_params.channel_dimension();
  if (IsColMajor(dst.layout)) {
    detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>(
        lhs, rhs, dst, mul_params, channel_dimension, ctx, params);
  } else {
    RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_TRANSPOSING);
    detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>(
        rhs, lhs, Transpose(dst), mul_params, Transpose(channel_dimension), ctx,
        params);
  }
}

}  // namespace ruy

#endif  // RUY_RUY_CREATE_TRMUL_PARAMS_H_