1 /* Copyright 2020 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Front-end validation code, see the Validate function.
17
18 #ifndef RUY_RUY_VALIDATE_H_
19 #define RUY_RUY_VALIDATE_H_
20
21 #include <cstdint>
22 #include <limits>
23 #include <type_traits>
24
25 #include "ruy/check_macros.h"
26 #include "ruy/mat.h"
27 #include "ruy/mul_params.h"
28 #include "ruy/side_pair.h"
29
30 namespace ruy {
31 namespace detail {
32
// Debug-checks a single zero point in isolation.
//
// For floating-point Scalar types the zero point is required to be exactly
// zero. For any other Scalar type no constraint is enforced here; combined
// constraints across operands live in ValidateZeroPoints.
template <typename Scalar>
void CheckZeroPoint(Scalar zero_point) {
  constexpr bool kIsFloatingPoint = std::is_floating_point<Scalar>::value;
  if (!kIsFloatingPoint) {
    // Integer zero points are unconstrained at this level.
    return;
  }
  RUY_DCHECK(!zero_point);
}
39
// Debug-checks the zero points of the LHS, RHS and destination matrices.
//
// Each zero point is first checked on its own (see CheckZeroPoint). Then,
// unless LHS and RHS are both uint8, the combination in which the LHS and RHS
// zero points are BOTH equal to the lowest representable value of their
// respective types is rejected.
//
// Why: padding is done with zero_point values, so that combination produces
// the bad case for the fast int8 kernels on NEON (pre-dotprod), which
// multiply-accumulate two pairs of int8 values into an int16. This is safe
// except for -128*-128 + -128*-128, whose sum (32768) does not fit in int16.
// See b/131609283. Only the kNeon path is affected, but the combination is
// banned on all paths so that ruy has the same supported parameter space
// everywhere.
//
// The uint8 x uint8 exemption is kept for now for backwards compatibility
// with gemmlowp. The underlying issue is still relevant there, because uint8
// operands are converted to int8 for the backend kernels.
template <typename LhsScalar, typename RhsScalar, typename DstScalar>
void ValidateZeroPoints(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
                        DstScalar dst_zero_point) {
  CheckZeroPoint(lhs_zero_point);
  CheckZeroPoint(rhs_zero_point);
  CheckZeroPoint(dst_zero_point);

  constexpr bool kBothUint8 = std::is_same<LhsScalar, uint8_t>::value &&
                              std::is_same<RhsScalar, uint8_t>::value;
  if (kBothUint8) {
    // Joint check intentionally skipped for gemmlowp compatibility.
    return;
  }
  // The comparison stays inside the macro so that it costs nothing when
  // debug checks are compiled out.
  RUY_DCHECK(!(lhs_zero_point == std::numeric_limits<LhsScalar>::lowest() &&
               rhs_zero_point == std::numeric_limits<RhsScalar>::lowest()));
}
64
65 } // namespace detail
66
67 template <typename LhsScalar, typename RhsScalar, typename DstScalar>
Validate(const Mat<LhsScalar> & lhs,const Mat<RhsScalar> & rhs,const Mat<DstScalar> & dst)68 void Validate(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
69 const Mat<DstScalar>& dst) {
70 detail::ValidateZeroPoints(lhs.zero_point, rhs.zero_point, dst.zero_point);
71 }
72
73 } // namespace ruy
74
75 #endif // RUY_RUY_VALIDATE_H_
76