1 #include "ruy/context.h"
2 #include "ruy/gtest_wrapper.h"
3 #include "ruy/kernel.h"
4 #include "ruy/matrix.h"
5 #include "ruy/path.h"
6 #include "ruy/performance_advisory.h"
7 #include "ruy/ruy.h"
8 
9 namespace ruy {
10 namespace {
11 
12 constexpr Path kPath = Path::kInternalStandardCppVariant3;
13 constexpr int kBufferSize = 64;
14 
15 template <typename AccumScalar, typename DstScalar,
16           bool HaveQuantizedMultipliers =
17               std::is_same<AccumScalar, std::int32_t>::value &&
18               !std::is_same<DstScalar, std::int32_t>::value>
19 struct PopulatePerChannelBuffersImpl {
Runruy::__anoncf059a0a0111::PopulatePerChannelBuffersImpl20   static void Run(MulParams<AccumScalar, DstScalar>* mul_params) {
21     static const AccumScalar bias_buf[kBufferSize] = {0};
22     static const AccumScalar multiplier_fixedpoint_buf[kBufferSize] = {0};
23     static const int multiplier_exponent_buf[kBufferSize] = {0};
24     mul_params->set_bias(bias_buf);
25     mul_params->set_multiplier_fixedpoint_perchannel(multiplier_fixedpoint_buf);
26     mul_params->set_multiplier_exponent_perchannel(multiplier_exponent_buf);
27   }
28 };
29 
30 template <typename AccumScalar, typename DstScalar>
31 struct PopulatePerChannelBuffersImpl<AccumScalar, DstScalar, false> {
Runruy::__anoncf059a0a0111::PopulatePerChannelBuffersImpl32   static void Run(MulParams<AccumScalar, DstScalar>* mul_params) {
33     static const AccumScalar bias_buf[kBufferSize] = {0};
34     mul_params->set_bias(bias_buf);
35   }
36 };
37 
38 template <typename AccumScalar, typename DstScalar>
PopulatePerChannelBuffers(MulParams<AccumScalar,DstScalar> * mul_params)39 void PopulatePerChannelBuffers(MulParams<AccumScalar, DstScalar>* mul_params) {
40   PopulatePerChannelBuffersImpl<AccumScalar, DstScalar>::Run(mul_params);
41 }
42 
43 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
44           typename DstScalar>
TestPerChannelBuffersReallocation()45 void TestPerChannelBuffersReallocation() {
46   using KernelType = Kernel<kPath, float, float, float, float>;
47 
48   MulParams<AccumScalar, DstScalar> mul_params;
49   PopulatePerChannelBuffers(&mul_params);
50 
51   const int kMatrixSize = 3;
52   ruy::Matrix<LhsScalar> lhs;
53   ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kRowMajor,
54                         lhs.mutable_layout());
55   const LhsScalar lhs_data[kMatrixSize * kMatrixSize] = {0};
56   lhs.set_data(lhs_data);
57   ruy::Matrix<RhsScalar> rhs;
58   ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kColMajor,
59                         rhs.mutable_layout());
60   const RhsScalar rhs_data[kMatrixSize * kMatrixSize] = {0};
61   rhs.set_data(rhs_data);
62   DstScalar dst_data[kMatrixSize * kMatrixSize] = {0};
63   ruy::Matrix<DstScalar> dst;
64   ruy::MakeSimpleLayout(kMatrixSize, kMatrixSize, ruy::Order::kColMajor,
65                         dst.mutable_layout());
66   dst.set_data(dst_data);
67 
68   ruy::Context context;
69 
70   auto test_advisory = [&](bool expect_advisory,
71                            ChannelDimension channel_dimension,
72                            int capacity_rounding) {
73     mul_params.set_channel_dimension(channel_dimension);
74     mul_params.set_perchannel_buffers_capacity_rounding(capacity_rounding);
75     ruy::Mul<kPath>(lhs, rhs, mul_params, &context, &dst);
76     EXPECT_EQ(context.performance_advisory(
77                   PerformanceAdvisory::kReallocatedPerChannelBuffer),
78               expect_advisory);
79   };
80 
81   static_assert(KernelType::LhsLayout::kCols == 16, "");
82   test_advisory(true, ChannelDimension::kRow, 1);
83   test_advisory(true, ChannelDimension::kRow, 2);
84   test_advisory(true, ChannelDimension::kRow, 4);
85   test_advisory(true, ChannelDimension::kRow, 8);
86   test_advisory(false, ChannelDimension::kRow, 16);
87   test_advisory(false, ChannelDimension::kRow, 32);
88   test_advisory(false, ChannelDimension::kRow, 64);
89 
90   static_assert(KernelType::RhsLayout::kCols == 8, "");
91   test_advisory(true, ChannelDimension::kCol, 1);
92   test_advisory(true, ChannelDimension::kCol, 2);
93   test_advisory(true, ChannelDimension::kCol, 4);
94   test_advisory(false, ChannelDimension::kCol, 8);
95   test_advisory(false, ChannelDimension::kCol, 16);
96   test_advisory(false, ChannelDimension::kCol, 32);
97   test_advisory(false, ChannelDimension::kCol, 64);
98 }
99 
// All-float instantiation: LHS, RHS, accumulator and destination are float.
TEST(PerChannelBuffersReallocationTest, Float) {
  using Scalar = float;
  TestPerChannelBuffersReallocation<Scalar, Scalar, Scalar, Scalar>();
}
103 
// Quantized instantiation: int8 operands and destination with int32
// accumulators.
TEST(PerChannelBuffersReallocationTest, Quantized) {
  using Operand = std::int8_t;
  using Accum = std::int32_t;
  TestPerChannelBuffersReallocation<Operand, Operand, Accum, Operand>();
}
108 
// Raw-accumulator instantiation: int8 operands with int32 used for both the
// accumulator and the destination.
TEST(PerChannelBuffersReallocationTest, RawInt32) {
  using Operand = std::int8_t;
  using Accum = std::int32_t;
  TestPerChannelBuffersReallocation<Operand, Operand, Accum, Accum>();
}
113 
114 }  // namespace
115 }  // namespace ruy
116 
main(int argc,char ** argv)117 int main(int argc, char** argv) {
118   ::testing::InitGoogleTest(&argc, argv);
119   return RUN_ALL_TESTS();
120 }
121