1 /* Copyright 2019 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implementation of MulFrontEnd, the front-end part of ruy.
17 // This is what the ruy::Mul entry point calls, and this ends in a call to
18 // TrMul, at which point we enter the middle-end.
19 // The front-end work includes parameter validation (Validate), detemplatization
20 // and resolution of the specific code path to take (CreateTrMulParams), and
21 // any additional logic best done upfront before entering the middle-end
22 // (e.g. HandlePrepackedCaching).
23 // The call to CreateTrMulParams is an important watershed in this code's
24 // structure: code before it needs to be templatized like the ruy::Mul entry
25 // point, code after it is un-templatized.
26
27 #ifndef RUY_RUY_FRONTEND_H_
28 #define RUY_RUY_FRONTEND_H_
29
30 #include "ruy/create_trmul_params.h"
31 #include "ruy/ctx.h"
32 #include "ruy/profiler/instrumentation.h"
33 #include "ruy/trace.h"
34 #include "ruy/trmul_params.h"
35 #include "ruy/validate.h"
36
37 namespace ruy {
38
39 // The first half of front-end work, up to the point where we have TrMulParams.
40 // In other words, this is the part of the front-end work that needs to be
41 // templatized like the entry point, and that performs the initial work that
42 // requires this templatization, and the de-templatization. The output of this
43 // function is the TrMulParams, which contain enough information to allow the
44 // un-templatized code to take over from there.
45 template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
46 typename AccumScalar, typename DstScalar>
MulFrontEndUpToCreateTrMulParams(const Mat<LhsScalar> & lhs,const Mat<RhsScalar> & rhs,const Mat<DstScalar> & dst,const MulParams<AccumScalar,DstScalar> & mul_params,Ctx * ctx,TrMulParams * params)47 void MulFrontEndUpToCreateTrMulParams(
48 const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
49 const Mat<DstScalar>& dst,
50 const MulParams<AccumScalar, DstScalar>& mul_params, Ctx* ctx,
51 TrMulParams* params) {
52 RUY_TRACE_SCOPE;
53 static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path");
54 static_assert(
55 (CompiledPaths & ~kAllPathsIncludingInternalVariants) == Path::kNone,
56 "CompiledPaths must be a subset of "
57 "ruy::kAllPathsIncludingInternalVariants");
58
59 // Perform validation of parameters early so that failures are easier to map
60 // to user errors. In particular, perform this validation before the
61 // transposition.
62 Validate(lhs, rhs, dst);
63
64 // De-templatize this Mul call by creating a TrMulParams structure.
65 // This is also where the specific kernel and pack code paths corresponding to
66 // `the_path` are selected, among all the code paths in `CompiledPaths`, and
67 // recorded as function pointers in the TrMulParams.
68 // The Transpose(lhs) here is where we switch from 'Mul' to 'TrMul'.
69 CreateTrMulParams<CompiledPaths>(Transpose(lhs), rhs, dst, mul_params, ctx,
70 params);
71 }
72
// The second part of the front-end work, starting from where we have freshly
// created TrMulParams, performing any remaining front-end work and entering the
// middle-end.
//
// Unlike the first half, this is a plain non-template function: it is only
// declared in this header, and its definition is not visible here. MulFrontEnd
// calls it with the same `ctx` and the TrMulParams that
// MulFrontEndUpToCreateTrMulParams has just filled in.
void MulFrontEndFromTrMulParams(Ctx* ctx, TrMulParams* params);
77
78 // Top-level function orchestrating the two halves of front-end work:
79 // before and after we have detemplatized the call by creating TrMulParams.
80 template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
81 typename AccumScalar, typename DstScalar>
MulFrontEnd(const Mat<LhsScalar> & lhs,const Mat<RhsScalar> & rhs,const MulParams<AccumScalar,DstScalar> & mul_params,Ctx * ctx,Mat<DstScalar> * dst)82 void MulFrontEnd(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
83 const MulParams<AccumScalar, DstScalar>& mul_params, Ctx* ctx,
84 Mat<DstScalar>* dst) {
85 RUY_TRACE_SCOPE;
86 profiler::ScopeLabel mul_label("Mul");
87 profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d",
88 lhs.layout.rows, lhs.layout.cols,
89 rhs.layout.cols);
90 ctx->clear_performance_advisories();
91 TrMulParams params;
92 MulFrontEndUpToCreateTrMulParams<CompiledPaths>(lhs, rhs, *dst, mul_params,
93 ctx, ¶ms);
94 MulFrontEndFromTrMulParams(ctx, ¶ms);
95 }
96
97 } // namespace ruy
98
99 #endif // RUY_RUY_FRONTEND_H_
100