/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_H_
#define RUY_RUY_KERNEL_H_

#include "ruy/kernel_common.h"
#include "ruy/mul_params.h"
#include "ruy/platform.h"
#include "ruy/trace.h"

// IWYU pragma: begin_exports
#if RUY_PLATFORM_NEON
#include "ruy/kernel_arm.h"
#elif RUY_PLATFORM_X86
#include "ruy/kernel_x86.h"
#endif
// IWYU pragma: end_exports

namespace ruy {

// KernelArgs is a helper to access the template parameter values from a Kernel
// template instantiation.
template <typename KernelType>
struct KernelArgs {};

template <Path tPath, typename tLhsScalar, typename tRhsScalar,
          typename tAccumScalar, typename tDstScalar>
struct KernelArgs<
    Kernel<tPath, tLhsScalar, tRhsScalar, tAccumScalar, tDstScalar>> {
  static constexpr Path kPath = tPath;
  using LhsScalar = tLhsScalar;
  using RhsScalar = tRhsScalar;
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;
};
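
// A minimal illustrative sketch (not part of the library itself): KernelArgs
// recovers the template parameter values of a Kernel instantiation. For
// example, for the generic float kernel at Path::kStandardCpp one would have
//
//   using K = Kernel<Path::kStandardCpp, float, float, float, float>;
//   static_assert(KernelArgs<K>::kPath == Path::kStandardCpp, "");
//   static_assert(std::is_same<typename KernelArgs<K>::AccumScalar,
//                              float>::value, "");
//
// RunKernel below uses KernelArgs in exactly this way to recover LhsScalar,
// RhsScalar, AccumScalar and DstScalar from its KernelType parameter.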

// RunKernel::Run() is the only place that directly invokes Kernel::Run().
// It performs the type un-erasure, and factoring all Kernel::Run() calls
// through this function also gives a single place in which to conditionally
// implement RUY_OPT(FAT_KERNEL). This should be a function but is a class to
// hide and share some boilerplate (see the member types, and the RunTyped
// method also using them).
template <typename KernelType>
class RunKernel final {
 public:
  static void Run(Tuning tuning, const SidePair<PEMat>& src,
                  const void* mul_params, const SidePair<int>& start,
                  const SidePair<int>& end, EMat* dst) {
    RUY_TRACE_SCOPE_NAME("RunKernel");
    const auto& unerased_lhs = UneraseType<LhsScalar>(src[Side::kLhs]);
    const auto& unerased_rhs = UneraseType<RhsScalar>(src[Side::kRhs]);
    auto unerased_dst = UneraseType<DstScalar>(*dst);
    RUY_TRACE_INFO(RUN_KERNEL);
    RunTyped(tuning, unerased_lhs, unerased_rhs,
             *static_cast<const MulParamsType*>(mul_params), start, end,
             &unerased_dst);
  }

 private:
  using Args = KernelArgs<KernelType>;
  using LhsScalar = typename Args::LhsScalar;
  using RhsScalar = typename Args::RhsScalar;
  using AccumScalar = typename Args::AccumScalar;
  using DstScalar = typename Args::DstScalar;
  using MulParamsType = MulParams<AccumScalar, DstScalar>;
  static void RunTyped(Tuning tuning, const PMat<LhsScalar>& lhs,
                       const PMat<RhsScalar>& rhs,
                       const MulParamsType& mul_params,
                       const SidePair<int>& start, const SidePair<int>& end,
                       Mat<DstScalar>* dst) {
    const int start_row = start[Side::kLhs];
    const int start_col = start[Side::kRhs];
    const int end_row = end[Side::kLhs];
    const int end_col = end[Side::kRhs];
    KernelType kernel(tuning);
    using LhsLayout = typename KernelType::LhsLayout;
    using RhsLayout = typename KernelType::RhsLayout;
    // This is a good place to validate kernel layouts. The Kernel class
    // template itself isn't a good place to do that because it has
    // specializations.
    // The kRows of both sides have to match: in TrMul, kRows is the depth
    // dimension, on which LHS and RHS have to agree for the matrix
    // multiplication to be defined at all, so requiring the corresponding
    // dimension of the kernel layouts to also match is reasonable. If it
    // didn't match, then the packed matrices could have mismatching depth
    // dimensions even with the source matrices agreeing.
    static_assert(LhsLayout::kRows == RhsLayout::kRows, "");
    // The kernel layout dimensions have to be powers of two. This simplifies
    // BlockMap logic considerably. It also avoids leaking fine performance
    // optimization details up the stack. For instance, if one of the
    // dimensions were 6, then users might notice that optimal performance is
    // achieved with matrix dimensions that are multiples of 6, and might
    // start contorting their own application code to match that requirement,
    // in a way that would not be future-proof.
    static_assert(is_pot(LhsLayout::kRows), "");
    static_assert(is_pot(LhsLayout::kCols), "");
    static_assert(is_pot(RhsLayout::kRows), "");
    static_assert(is_pot(RhsLayout::kCols), "");
    // end_row and end_col may be larger than the dst dimensions. That is
    // because kernels write directly to the destination matrix, whose
    // dimensions may not be a multiple of the kernel dimensions, and we try
    // to keep this annoyance localized as an implementation detail in
    // kernels, by allowing rounded-up values to be passed down as far as
    // possible. These assertions encode the contract.
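    // As a worked example with illustrative values (not taken from any
    // particular kernel): if LhsLayout::kCols == 8 and dst->layout.rows == 13,
    // TrMul may pass end_row == 16 (13 rounded up to a multiple of 8). That
    // satisfies end_row < dst->layout.rows + LhsLayout::kCols (16 < 21) and
    // (end_row - start_row) % LhsLayout::kCols == 0 for start_row == 0, and
    // it is then up to the kernel to clamp against the actual
    // dst->layout.rows.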
    RUY_DCHECK_LE(0, start_row);
    RUY_DCHECK_LE(start_row, end_row);
    RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
    RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
    RUY_DCHECK_LE(0, start_col);
    RUY_DCHECK_LE(start_col, end_col);
    RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
    RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
#if RUY_OPT(FAT_KERNEL)
    kernel.Run(lhs, rhs, mul_params, start_row, start_col, end_row, end_col,
               dst);
#else
    for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
      int block_end_col = std::min(col + RhsLayout::kCols, end_col);
      for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
        int block_end_row = std::min(row + LhsLayout::kCols, end_row);
        kernel.Run(lhs, rhs, mul_params, row, col, block_end_row,
                   block_end_col, dst);
      }
    }
#endif
  }
};

template <Path ThePath>
struct StandardCppKernelLayout {};

template <>
struct StandardCppKernelLayout<Path::kStandardCpp> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
};

// A variant exercising RowMajor square blocks.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant1> {
  using Lhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
  using Rhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
};

// A variant with a rectangular layout: 4x8.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant2> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 4>;
  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 8>;
};

// A variant with different block orders in LHS vs RHS.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant3> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 2, 16>;
  using Rhs = FixedKernelLayout<Order::kRowMajor, 2, 8>;
};

// General implementation of the Kernel template, overridden by template
// specializations for specific SIMD code paths. This general implementation
// covers Path::kStandardCpp and its internal test-only variants.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
struct Kernel {
  // Each Kernel specialization defines kPath as the ground-truth path that it
  // implements. This is used in assertions. As we support fallbacks between
  // paths (see RUY_INHERIT_KERNEL), unless a specialization for a specific
  // set of template parameters was defined, it is normal for template
  // instantiations of the form Kernel<SomePath, ...> to have
  // kPath != SomePath. Assertions that kPath == SomePath are used in places
  // where we know that we should be using a template specialization for a
  // specific path rather than a fallback.
  static constexpr Path kPath = ThePath;
  using MulParamsType = MulParams<AccumScalar, DstScalar>;
  using LhsLayout = typename StandardCppKernelLayout<ThePath>::Lhs;
  using RhsLayout = typename StandardCppKernelLayout<ThePath>::Rhs;
  explicit Kernel(Tuning) {}
  void Run(const PMat<LhsScalar>& lhs, const PMat<RhsScalar>& rhs,
           const MulParamsType& mul_params, int start_row, int start_col,
           int end_row, int end_col, Mat<DstScalar>* dst) const {
    // See the comment in RunKernel::RunTyped: end_row may be larger than
    // dst->layout.rows.
    // It's the responsibility of the kernel to avoid overrunning dst
    // boundaries, which we do here by computing clamped_end_row and
    // clamped_end_col.
    int clamped_end_row = std::min(end_row, dst->layout.rows);
    int clamped_end_col = std::min(end_col, dst->layout.cols);
    RUY_DCHECK_LE(0, start_row);
    RUY_DCHECK_LE(start_row, clamped_end_row);
    RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
    RUY_DCHECK_LE(clamped_end_row, end_row);
    RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
    RUY_DCHECK_LE(0, start_col);
    RUY_DCHECK_LE(start_col, clamped_end_col);
    RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
    RUY_DCHECK_LE(clamped_end_col, end_col);
    RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
    profiler::ScopeLabel label("Kernel (Standard Cpp)");
    const int depth = lhs.layout.rows;
    for (int i = start_row; i < clamped_end_row; i++) {
      for (int j = start_col; j < clamped_end_col; j++) {
        AccumScalar accum = 0;
        for (int k = 0; k < depth; k++) {
          AccumScalar lhs_val = Element(lhs, k, i);
          AccumScalar rhs_val = Element(rhs, k, j);
          accum += lhs_val * rhs_val;
        }
        int channel =
            mul_params.channel_dimension() == ChannelDimension::kRow ? i : j;
        if (mul_params.bias()) {
          accum += mul_params.bias()[channel];
        }
        // Zero-point corrections, using the precomputed row/column sums:
        //   sum_k (lhs_k - lhs_zp) * (rhs_k - rhs_zp)
        //     = sum_k lhs_k * rhs_k - lhs_zp * sum_k rhs_k
        //       - rhs_zp * sum_k lhs_k + lhs_zp * rhs_zp * depth.
        if (lhs.zero_point) {
          accum -= lhs.zero_point * rhs.sums[j];
        }
        if (rhs.zero_point) {
          accum -= rhs.zero_point * lhs.sums[i];
        }
        if (lhs.zero_point && rhs.zero_point) {
          accum += lhs.zero_point * rhs.zero_point * depth;
        }
        ApplyMultiplier(mul_params, channel, &accum);
        accum += dst->zero_point;
        accum = std::min<AccumScalar>(accum, mul_params.clamp_max());
        accum = std::max<AccumScalar>(accum, mul_params.clamp_min());
        *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
      }
    }
  }
};

}  // namespace ruy

#endif  // RUY_RUY_KERNEL_H_