1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef RUY_RUY_MATRIX_H_
17 #define RUY_RUY_MATRIX_H_
18 
19 #include <cstddef>
20 #include <cstdint>  // IWYU pragma: keep
21 #include <type_traits>
22 
23 #include "ruy/check_macros.h"
24 
25 namespace ruy {
26 
// Storage order of a matrix layout. Here and elsewhere, 'col' abbreviates
// 'column'.
//  - kColMajor: the elements of each column are contiguous in memory.
//  - kRowMajor: the elements of each row are contiguous in memory.
enum class Order : std::uint8_t {
  kColMajor,
  kRowMajor,
};

// Shape (rows x cols) and storage layout (order + stride) of a matrix.
// A default-constructed Layout describes a 0x0 column-major matrix with
// stride 0.
class Layout final {
 public:
  // Getters.
  int rows() const { return rows_; }
  int cols() const { return cols_; }
  // Distance, in elements, between two adjacent matrix elements along the
  // non-contiguous direction.
  int stride() const { return stride_; }
  Order order() const { return order_; }

  // Setters. These store the value as-is; no consistency between shape,
  // order and stride is enforced here.
  void set_rows(int val) { rows_ = val; }
  void set_cols(int val) { cols_ = val; }
  void set_stride(int val) { stride_ = val; }
  void set_order(Order val) { order_ = val; }

 private:
  int rows_{0};
  int cols_{0};
  int stride_{0};
  Order order_{Order::kColMajor};
};
51 
52 namespace detail {
53 
// Thin wrapper around a pointer with a constness model that works for the
// purposes of the Matrix class.
//
// A typical conundrum for any C++ container class is which type should encode,
// at compile time, the constancy of the contained data:
// `Matrix<const T>` or `const Matrix<T>`?
// With either approach it is very difficult to achieve perfect
// const-correctness; that can only be done with some combination of an
// inconvenient interface and C++ complexity/abstraction.
//
// Here we opt for something that's entirely tailored to the needs of the Ruy
// interface. The only purpose of the Matrix class is to pass matrix data
// pointers to ruy. There is an asymmetry here: the caller of ruy::Mul only
// needs to `set` the data; ruy itself only needs to `get` the data. In the
// caller's code, it's convenient to be able to just deal with `Matrix<T>`
// without having to sprinkle `const` keywords in the right places, so we want
// to track whether the data is constant in a way that's decoupled from the
// constness of `this`, and we never want to have Matrix<const T>. Inside ruy
// code, as input matrices are passed by const-reference and output matrices are
// passed by pointer (to non-const), the constness of `this` tells whether
// the data is constant. See the `get` and `set` methods below and the comment
// explaining the core logic that they encapsulate.
76 template <typename T>
77 class ConstCheckingPtr final {
78  public:
79   using element_type = T;
80 
81   // Core accessors. These encapsulate the main logic:
82   // - for `set`, the constness of the argument determines whether internal
83   // pointer should be tracked as const/mutable.
84   // - for `get`, the constness of `this` determines whether the call
85   // counts as a const or mutable use of the internal pointer.
set(T * ptr)86   void set(T* ptr) {
87     ptr_ = ptr;
88     set_mutable(true);
89   }
set(const T * ptr)90   void set(const T* ptr) {
91     ptr_ = ptr;
92     set_mutable(false);
93   }
set(std::nullptr_t)94   void set(std::nullptr_t) { ptr_ = nullptr; }
get()95   T* get() /* NOT const */ {
96     assert_mutable();
97     return const_cast<T*>(ptr_);
98   }
get()99   const T* get() const { return ptr_; }
100 
101  private:
102   // There's never a need for Matrix<const T>.
103   static_assert(!std::is_const<T>::value, "");
104   const T* ptr_ = nullptr;
105 #ifndef NDEBUG
106   bool is_mutable_ = true;
set_mutable(bool val)107   void set_mutable(bool val) { is_mutable_ = val; }
assert_mutable()108   void assert_mutable() { RUY_DCHECK(is_mutable_); }
109 #else
set_mutable(bool)110   void set_mutable(bool) {}
assert_mutable()111   void assert_mutable() {}
112 #endif
113 };
114 
115 }  // namespace detail
116 
// Whether ruy may cache work derived from a matrix's data. Matrix defaults to
// kNeverCache; the other values are only appropriate when the pointed-to data
// is constant. The exact thresholds behind the "IfLargeSpeedup" and
// "IfSignificantSpeedup" heuristics are defined elsewhere in ruy (not visible
// here).
enum class CachePolicy : std::uint8_t {
  kNeverCache = 0,
  kCacheIfLargeSpeedup = 1,
  kCacheIfSignificantSpeedup = 2,
  kAlwaysCache = 3,
};
123 
124 // A Matrix merely wraps existing data as a matrix. It doesn't own any buffer.
125 // The purpose of Matrix is only to be used in ruy's interface -- it's just
126 // a structured way for the user to pass to ruy::Mul the matrix data pointers
127 // together with other matrix parameters.
128 // Scalar may be any floating-point or integral type. When integral, it may be
129 // signed or unsigned. It's never const: use Matrix<T> for both input and output
130 // matrices, never use Matrix<const T>.
// See the comments on detail::ConstCheckingPtr.
132 template <typename Scalar>
133 class Matrix final {
134  public:
135   static_assert(!std::is_const<Scalar>::value,
136                 "Never use Matrix<const T>. Just use Matrix<T>. Constness of "
137                 "the data is guarded by debug-only runtime assertions. See "
138                 "detail::ConstCheckingPtr.");
139 
data()140   Scalar* data() { return data_.get(); }
data()141   const Scalar* data() const { return data_.get(); }
set_data(Scalar * ptr)142   void set_data(Scalar* ptr) { data_.set(ptr); }
set_data(const Scalar * ptr)143   void set_data(const Scalar* ptr) { data_.set(ptr); }
set_data(std::nullptr_t)144   void set_data(std::nullptr_t) { data_.set(nullptr); }
layout()145   const Layout& layout() const { return layout_; }
mutable_layout()146   Layout* mutable_layout() { return &layout_; }
zero_point()147   Scalar zero_point() const { return zero_point_; }
set_zero_point(Scalar value)148   void set_zero_point(Scalar value) { zero_point_ = value; }
cache_policy()149   CachePolicy cache_policy() const { return cache_policy_; }
set_cache_policy(CachePolicy value)150   void set_cache_policy(CachePolicy value) { cache_policy_ = value; }
151 
152  private:
153   // The underlying buffer wrapped by this matrix.
154   detail::ConstCheckingPtr<Scalar> data_;
155   // The shape and data layout of this matrix.
156   Layout layout_;
157   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
158   // When Scalar is floating-point, this must be 0.
159   Scalar zero_point_ = 0;
160   // When the data pointed to by this matrix is constant data, so that it is
161   // valid to assume that equality of pointers implies equality of data,
162   // a CachePolicy may be used instead of the default kNeverCache,
163   // which will enable ruy to take advantage of this constancy of the data to
164   // cache the packing work, which can be a large speedup in matrix*vector
165   // and other narrow shapes.
166   CachePolicy cache_policy_ = CachePolicy::kNeverCache;
167 };
168 
MakeSimpleLayout(int rows,int cols,Order order,Layout * layout)169 inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
170   layout->set_rows(rows);
171   layout->set_cols(cols);
172   layout->set_order(order);
173   layout->set_stride(order == Order::kColMajor ? rows : cols);
174 }
175 
176 template <typename StreamType, typename Scalar>
177 StreamType& operator<<(StreamType& stream, const Matrix<Scalar>& mat) {
178   for (int row = 0; row < mat.layout().rows(); row++) {
179     for (int col = 0; col < mat.layout().cols(); col++) {
180       stream << static_cast<double>(Element(mat, row, col)) << " ";
181     }
182     stream << "\n";
183   }
184   return stream;
185 }
186 
187 // TODO(b/130417400) add a unit test
Offset(const Layout & layout,int row,int col)188 inline int Offset(const Layout& layout, int row, int col) {
189   // TODO(benoitjacob)  - should check this but this make the _slow tests take
190   // 5x longer.  Find a mitigation like in Eigen with an 'internal' variant
191   // bypassing the check?
192   // RUY_DCHECK_GE(row, 0);
193   // RUY_DCHECK_GE(col, 0);
194   // RUY_DCHECK_LT(row, layout.rows());
195   // RUY_DCHECK_LT(col, layout.cols());
196   int row_stride = layout.order() == Order::kColMajor ? 1 : layout.stride();
197   int col_stride = layout.order() == Order::kRowMajor ? 1 : layout.stride();
198   return row * row_stride + col * col_stride;
199 }
200 
201 template <typename Scalar>
ElementPtr(const Matrix<Scalar> & mat,int row,int col)202 const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) {
203   return mat.data() + Offset(mat.layout(), row, col);
204 }
205 
206 template <typename Scalar>
ElementPtr(Matrix<Scalar> * mat,int row,int col)207 Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) {
208   return mat->data() + Offset(mat->layout(), row, col);
209 }
210 
211 template <typename Scalar>
Element(const Matrix<Scalar> & mat,int row,int col)212 Scalar Element(const Matrix<Scalar>& mat, int row, int col) {
213   return *ElementPtr(mat, row, col);
214 }
215 
216 }  // namespace ruy
217 
218 #endif  // RUY_RUY_MATRIX_H_
219