1 #include "ruy/context.h"
2 #include "ruy/gtest_wrapper.h"
3 #include "ruy/kernel.h"
4 #include "ruy/matrix.h"
5 #include "ruy/path.h"
6 #include "ruy/performance_advisory.h"
7 #include "ruy/ruy.h"
8
9 namespace ruy {
10 namespace {
11
12 constexpr Path kPath = Path::kInternalStandardCppVariant3;
13 constexpr int kBufferSize = 64;
14
15 template <typename AccumScalar, typename DstScalar,
16 bool HaveQuantizedMultipliers =
17 std::is_same<AccumScalar, std::int32_t>::value &&
18 !std::is_same<DstScalar, std::int32_t>::value>
19 struct PopulatePerChannelBuffersImpl {
Runruy::__anoncf059a0a0111::PopulatePerChannelBuffersImpl20 static void Run(MulParams<AccumScalar, DstScalar>* mul_params) {
21 static const AccumScalar bias_buf[kBufferSize] = {0};
22 static const AccumScalar multiplier_fixedpoint_buf[kBufferSize] = {0};
23 static const int multiplier_exponent_buf[kBufferSize] = {0};
24 mul_params->set_bias(bias_buf);
25 mul_params->set_multiplier_fixedpoint_perchannel(multiplier_fixedpoint_buf);
26 mul_params->set_multiplier_exponent_perchannel(multiplier_exponent_buf);
27 }
28 };
29
30 template <typename AccumScalar, typename DstScalar>
31 struct PopulatePerChannelBuffersImpl<AccumScalar, DstScalar, false> {
Runruy::__anoncf059a0a0111::PopulatePerChannelBuffersImpl32 static void Run(MulParams<AccumScalar, DstScalar>* mul_params) {
33 static const AccumScalar bias_buf[kBufferSize] = {0};
34 mul_params->set_bias(bias_buf);
35 }
36 };
37
38 template <typename AccumScalar, typename DstScalar>
PopulatePerChannelBuffers(MulParams<AccumScalar,DstScalar> * mul_params)39 void PopulatePerChannelBuffers(MulParams<AccumScalar, DstScalar>* mul_params) {
40 PopulatePerChannelBuffersImpl<AccumScalar, DstScalar>::Run(mul_params);
41 }
42
43 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
44 typename DstScalar>
TestPerChannelBuffersReallocation()45 void TestPerChannelBuffersReallocation() {
46 using KernelType = Kernel<kPath, float, float, float, float>;
47
48 MulParams<AccumScalar, DstScalar> mul_params;
49 PopulatePerChannelBuffers(&mul_params);
50
51 const int kMatrixSize = 3;
52 ruy::Matrix<LhsScalar> lhs;
53 ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kRowMajor,
54 lhs.mutable_layout());
55 const LhsScalar lhs_data[kMatrixSize * kMatrixSize] = {0};
56 lhs.set_data(lhs_data);
57 ruy::Matrix<RhsScalar> rhs;
58 ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kColMajor,
59 rhs.mutable_layout());
60 const RhsScalar rhs_data[kMatrixSize * kMatrixSize] = {0};
61 rhs.set_data(rhs_data);
62 DstScalar dst_data[kMatrixSize * kMatrixSize] = {0};
63 ruy::Matrix<DstScalar> dst;
64 ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kColMajor,
65 dst.mutable_layout());
66 dst.set_data(dst_data);
67
68 ruy::Context context;
69
70 auto test_advisory = [&](bool expect_advisory,
71 ChannelDimension channel_dimension,
72 int capacity_rounding) {
73 mul_params.set_channel_dimension(channel_dimension);
74 mul_params.set_perchannel_buffers_capacity_rounding(capacity_rounding);
75 ruy::Mul<kPath>(lhs, rhs, mul_params, &context, &dst);
76 EXPECT_EQ(context.performance_advisory(
77 PerformanceAdvisory::kReallocatedPerChannelBuffer),
78 expect_advisory);
79 };
80
81 static_assert(KernelType::LhsLayout::kCols == 16, "");
82 test_advisory(true, ChannelDimension::kRow, 1);
83 test_advisory(true, ChannelDimension::kRow, 2);
84 test_advisory(true, ChannelDimension::kRow, 4);
85 test_advisory(true, ChannelDimension::kRow, 8);
86 test_advisory(false, ChannelDimension::kRow, 16);
87 test_advisory(false, ChannelDimension::kRow, 32);
88 test_advisory(false, ChannelDimension::kRow, 64);
89
90 static_assert(KernelType::RhsLayout::kCols == 8, "");
91 test_advisory(true, ChannelDimension::kCol, 1);
92 test_advisory(true, ChannelDimension::kCol, 2);
93 test_advisory(true, ChannelDimension::kCol, 4);
94 test_advisory(false, ChannelDimension::kCol, 8);
95 test_advisory(false, ChannelDimension::kCol, 16);
96 test_advisory(false, ChannelDimension::kCol, 32);
97 test_advisory(false, ChannelDimension::kCol, 64);
98 }
99
// All-float instantiation: LHS, RHS, accumulator and destination are float.
TEST(PerChannelBuffersReallocationTest, Float) {
  using Scalar = float;
  TestPerChannelBuffersReallocation<Scalar, Scalar, Scalar, Scalar>();
}
103
// Quantized instantiation: int8 operands and destination with int32
// accumulators. This is the combination for which
// PopulatePerChannelBuffersImpl also sets the per-channel multiplier buffers.
TEST(PerChannelBuffersReallocationTest, Quantized) {
  using Operand = std::int8_t;
  using Accum = std::int32_t;
  TestPerChannelBuffersReallocation<Operand, Operand, Accum, Operand>();
}
108
// Raw-accumulator instantiation: int8 operands with int32 accumulators and an
// int32 destination, so only the bias buffer is populated.
TEST(PerChannelBuffersReallocationTest, RawInt32) {
  using Operand = std::int8_t;
  using Accum = std::int32_t;
  TestPerChannelBuffersReallocation<Operand, Operand, Accum, Accum>();
}
113
114 } // namespace
115 } // namespace ruy
116
main(int argc,char ** argv)117 int main(int argc, char** argv) {
118 ::testing::InitGoogleTest(&argc, argv);
119 return RUN_ALL_TESTS();
120 }
121