1 /* Copyright 2019 Google LLC. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef RUY_RUY_KERNEL_ARM_H_ 17 #define RUY_RUY_KERNEL_ARM_H_ 18 19 #include <cstddef> 20 #include <cstdint> 21 22 #include "ruy/asm_helpers.h" 23 #include "ruy/kernel_common.h" 24 #include "ruy/mat.h" 25 #include "ruy/mul_params.h" 26 #include "ruy/opt_set.h" 27 #include "ruy/path.h" 28 #include "ruy/platform.h" 29 #include "ruy/profiler/instrumentation.h" 30 #include "ruy/side_pair.h" 31 #include "ruy/size_util.h" 32 #include "ruy/tune.h" 33 34 namespace ruy { 35 36 #if RUY_PLATFORM_NEON && RUY_OPT(ASM) 37 38 RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon) 39 RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod) 40 41 #if RUY_PLATFORM_NEON_64 42 void Kernel8bitNeon(const KernelParams8bit<4, 4>& params); 43 void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params); 44 #elif RUY_PLATFORM_NEON_32 45 void Kernel8bitNeon(const KernelParams8bit<4, 2>& params); 46 void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params); 47 #endif 48 void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params); 49 void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params); 50 void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params); 51 void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params); 52 53 #if RUY_PLATFORM_NEON_64 54 template <typename DstScalar> 55 struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> { 56 static constexpr Path kPath = Path::kNeon; 57 using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; 58 using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; 59 Tuning tuning = Tuning::kAuto; 60 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 61 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, 62 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, 63 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { 64 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; 65 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, 66 end_col, dst, ¶ms); 67 if (dst->layout.cols == 1 && 68 mul_params.channel_dimension() == ChannelDimension::kRow) { 69 Kernel8bitNeon1Col(params); 70 return; 71 } 72 if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 73 Kernel8bitNeonA55ish(params); 74 } else { 75 Kernel8bitNeon(params); 76 } 77 } 78 }; 79 #endif 80 81 #if RUY_PLATFORM_NEON_32 82 template <typename DstScalar> 83 struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> { 84 static constexpr Path kPath = Path::kNeon; 85 using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; 86 using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>; 87 Tuning tuning = Tuning::kAuto; 88 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 89 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, 90 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, 91 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { 92 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; 93 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, 94 end_col, dst, ¶ms); 95 if (dst->layout.cols == 1 && 96 mul_params.channel_dimension() == ChannelDimension::kRow) { 97 Kernel8bitNeon1Col(params); 98 return; 99 } 100 Kernel8bitNeon(params); 101 } 102 }; 103 #endif 104 105 #if RUY_PLATFORM_NEON_64 106 template <typename DstScalar> 107 struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, DstScalar> { 108 static constexpr Path kPath = Path::kNeonDotprod; 109 Tuning tuning = Tuning::kAuto; 110 using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; 111 using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; 112 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 113 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, 114 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, 115 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { 116 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; 117 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, 118 end_col, dst, ¶ms); 119 if (dst->layout.cols == 1 && 120 mul_params.channel_dimension() == ChannelDimension::kRow) { 121 Kernel8bitNeonDotprod1Col(params); 122 } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 123 Kernel8bitNeonDotprodA55ish(params); 124 } else { 125 Kernel8bitNeonDotprod(params); 126 } 127 } 128 }; 129 #endif 130 131 void KernelFloatNeon(const KernelParamsFloat<8, 8>& params); 132 void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params); 133 void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params); 134 void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params); 135 136 #if RUY_PLATFORM_NEON_64 137 // A Float kernel for ARM64 Neon. 138 template <> 139 struct Kernel<Path::kNeon, float, float, float, float> { 140 static constexpr Path kPath = Path::kNeon; 141 Tuning tuning = Tuning::kAuto; 142 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 143 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 144 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 145 void Run(const PMat<float>& lhs, const PMat<float>& rhs, 146 const MulParams<float, float>& mul_params, int start_row, 147 int start_col, int end_row, int end_col, Mat<float>* dst) const { 148 KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; 149 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, 150 end_col, dst, ¶ms); 151 if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 152 KernelFloatNeonA55ish(params); 153 } else { 154 KernelFloatNeon(params); 155 } 156 } 157 }; 158 #endif 159 160 #if RUY_PLATFORM_NEON_32 161 // A Float kernel for ARM32 Neon. 162 template <> 163 struct Kernel<Path::kNeon, float, float, float, float> { 164 static constexpr Path kPath = Path::kNeon; 165 Tuning tuning = Tuning::kAuto; 166 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 167 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>; 168 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 169 void Run(const PMat<float>& lhs, const PMat<float>& rhs, 170 const MulParams<float, float>& mul_params, int start_row, 171 int start_col, int end_row, int end_col, Mat<float>* dst) const { 172 KernelParamsFloat<8, 4> params; 173 174 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, 175 end_col, dst, ¶ms); 176 177 KernelFloat32Neon(params); 178 } 179 }; 180 #endif 181 182 // While the dotprod NEON extension does not concern floating-point arithmetic, 183 // its presence allows us to distinguish, in the in-order tuning case, between 184 // A53 and A55r1. TODO: should this be folded into tuning? 185 template <> 186 struct Kernel<Path::kNeonDotprod, float, float, float, float> { 187 static constexpr Path kPath = Path::kNeonDotprod; 188 Tuning tuning = Tuning::kAuto; 189 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 190 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 191 using Base = 192 Kernel<Path::kNeon, float, float, float, float>; 193 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 194 void Run(const PMat<float>& lhs, const PMat<float>& rhs, 195 const MulParams<float, float>& mul_params, int start_row, 196 int start_col, int end_row, int end_col, Mat<float>* dst) const { 197 KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; 198 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, 199 end_col, dst, ¶ms); 200 if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 201 KernelFloatNeonDotprodA55ish(params); 202 } else { 203 KernelFloatNeon(params); 204 } 205 } 206 }; 207 208 #endif // RUY_PLATFORM_NEON && RUY_OPT(ASM) 209 210 } // namespace ruy 211 212 #endif // RUY_RUY_KERNEL_ARM_H_ 213