1 /*
2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
12 #define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
13 
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <string>
#include <vector>

#include "webrtc/base/checks.h"
#include "webrtc/base/constructormagic.h"
#include "webrtc/base/scoped_ptr.h"
22 
23 namespace {
24 
25 // Wrappers to get around the compiler warning resulting from the fact that
26 // there's no std::sqrt overload for ints. We cast all non-complex types to
27 // a double for the sqrt method.
// Square root for real (non-complex) element types.
//
// The argument is widened to double so that integral types select a single,
// unambiguous std::sqrt overload. The result is then converted back to T,
// which truncates toward zero when T is an integral type.
template <typename T>
T sqrt_wrapper(T x) {
  return static_cast<T>(std::sqrt(static_cast<double>(x)));
}
32 
// Square root for complex element types. Forwards to the std::complex
// overload of std::sqrt and returns its result unchanged.
template <typename S>
std::complex<S> sqrt_wrapper(std::complex<S> x) {
  return std::sqrt(x);
}
37 } // namespace
38 
39 namespace webrtc {
40 
41 // Matrix is a class for doing standard matrix operations on 2 dimensional
42 // matrices of any size. Results of matrix operations are stored in the
43 // calling object. Function overloads exist for both in-place (the calling
44 // object is used as both an operand and the result) and out-of-place (all
45 // operands are passed in as parameters) operations. If operand dimensions
46 // mismatch, the program crashes. Out-of-place operations change the size of
47 // the calling object, if necessary, before operating.
48 //
49 // 'In-place' operations that inherently change the size of the matrix (eg.
50 // Transpose, Multiply on different-sized matrices) must make temporary copies
51 // (|scratch_elements_| and |scratch_data_|) of existing data to complete the
52 // operations.
53 //
54 // The data is stored contiguously. Data can be accessed internally as a flat
55 // array, |data_|, or as an array of row pointers, |elements_|, but is
56 // available to users only as an array of row pointers through |elements()|.
57 // Memory for storage is allocated when a matrix is resized only if the new
58 // size overflows capacity. Memory needed temporarily for any operations is
59 // similarly resized only if the new size overflows capacity.
60 //
61 // If you pass in storage through the ctor, that storage is copied into the
62 // matrix. TODO(claguna): albeit tricky, allow for data to be referenced
63 // instead of copied, and owned by the user.
64 template <typename T>
65 class Matrix {
66  public:
Matrix()67   Matrix() : num_rows_(0), num_columns_(0) {}
68 
69   // Allocates space for the elements and initializes all values to zero.
Matrix(size_t num_rows,size_t num_columns)70   Matrix(size_t num_rows, size_t num_columns)
71       : num_rows_(num_rows), num_columns_(num_columns) {
72     Resize();
73     scratch_data_.resize(num_rows_ * num_columns_);
74     scratch_elements_.resize(num_rows_);
75   }
76 
77   // Copies |data| into the new Matrix.
Matrix(const T * data,size_t num_rows,size_t num_columns)78   Matrix(const T* data, size_t num_rows, size_t num_columns)
79       : num_rows_(0), num_columns_(0) {
80     CopyFrom(data, num_rows, num_columns);
81     scratch_data_.resize(num_rows_ * num_columns_);
82     scratch_elements_.resize(num_rows_);
83   }
84 
~Matrix()85   virtual ~Matrix() {}
86 
87   // Deep copy an existing matrix.
CopyFrom(const Matrix & other)88   void CopyFrom(const Matrix& other) {
89     CopyFrom(&other.data_[0], other.num_rows_, other.num_columns_);
90   }
91 
92   // Copy |data| into the Matrix. The current data is lost.
CopyFrom(const T * const data,size_t num_rows,size_t num_columns)93   void CopyFrom(const T* const data, size_t num_rows, size_t num_columns) {
94     Resize(num_rows, num_columns);
95     memcpy(&data_[0], data, num_rows_ * num_columns_ * sizeof(data_[0]));
96   }
97 
CopyFromColumn(const T * const * src,size_t column_index,size_t num_rows)98   Matrix& CopyFromColumn(const T* const* src,
99                          size_t column_index,
100                          size_t num_rows) {
101     Resize(1, num_rows);
102     for (size_t i = 0; i < num_columns_; ++i) {
103       data_[i] = src[i][column_index];
104     }
105 
106     return *this;
107   }
108 
Resize(size_t num_rows,size_t num_columns)109   void Resize(size_t num_rows, size_t num_columns) {
110     if (num_rows != num_rows_ || num_columns != num_columns_) {
111       num_rows_ = num_rows;
112       num_columns_ = num_columns;
113       Resize();
114     }
115   }
116 
117   // Accessors and mutators.
num_rows()118   size_t num_rows() const { return num_rows_; }
num_columns()119   size_t num_columns() const { return num_columns_; }
elements()120   T* const* elements() { return &elements_[0]; }
elements()121   const T* const* elements() const { return &elements_[0]; }
122 
Trace()123   T Trace() {
124     RTC_CHECK_EQ(num_rows_, num_columns_);
125 
126     T trace = 0;
127     for (size_t i = 0; i < num_rows_; ++i) {
128       trace += elements_[i][i];
129     }
130     return trace;
131   }
132 
133   // Matrix Operations. Returns *this to support method chaining.
Transpose()134   Matrix& Transpose() {
135     CopyDataToScratch();
136     Resize(num_columns_, num_rows_);
137     return Transpose(scratch_elements());
138   }
139 
Transpose(const Matrix & operand)140   Matrix& Transpose(const Matrix& operand) {
141     RTC_CHECK_EQ(operand.num_rows_, num_columns_);
142     RTC_CHECK_EQ(operand.num_columns_, num_rows_);
143 
144     return Transpose(operand.elements());
145   }
146 
147   template <typename S>
Scale(const S & scalar)148   Matrix& Scale(const S& scalar) {
149     for (size_t i = 0; i < data_.size(); ++i) {
150       data_[i] *= scalar;
151     }
152 
153     return *this;
154   }
155 
156   template <typename S>
Scale(const Matrix & operand,const S & scalar)157   Matrix& Scale(const Matrix& operand, const S& scalar) {
158     CopyFrom(operand);
159     return Scale(scalar);
160   }
161 
Add(const Matrix & operand)162   Matrix& Add(const Matrix& operand) {
163     RTC_CHECK_EQ(num_rows_, operand.num_rows_);
164     RTC_CHECK_EQ(num_columns_, operand.num_columns_);
165 
166     for (size_t i = 0; i < data_.size(); ++i) {
167       data_[i] += operand.data_[i];
168     }
169 
170     return *this;
171   }
172 
Add(const Matrix & lhs,const Matrix & rhs)173   Matrix& Add(const Matrix& lhs, const Matrix& rhs) {
174     CopyFrom(lhs);
175     return Add(rhs);
176   }
177 
Subtract(const Matrix & operand)178   Matrix& Subtract(const Matrix& operand) {
179     RTC_CHECK_EQ(num_rows_, operand.num_rows_);
180     RTC_CHECK_EQ(num_columns_, operand.num_columns_);
181 
182     for (size_t i = 0; i < data_.size(); ++i) {
183       data_[i] -= operand.data_[i];
184     }
185 
186     return *this;
187   }
188 
Subtract(const Matrix & lhs,const Matrix & rhs)189   Matrix& Subtract(const Matrix& lhs, const Matrix& rhs) {
190     CopyFrom(lhs);
191     return Subtract(rhs);
192   }
193 
PointwiseMultiply(const Matrix & operand)194   Matrix& PointwiseMultiply(const Matrix& operand) {
195     RTC_CHECK_EQ(num_rows_, operand.num_rows_);
196     RTC_CHECK_EQ(num_columns_, operand.num_columns_);
197 
198     for (size_t i = 0; i < data_.size(); ++i) {
199       data_[i] *= operand.data_[i];
200     }
201 
202     return *this;
203   }
204 
PointwiseMultiply(const Matrix & lhs,const Matrix & rhs)205   Matrix& PointwiseMultiply(const Matrix& lhs, const Matrix& rhs) {
206     CopyFrom(lhs);
207     return PointwiseMultiply(rhs);
208   }
209 
PointwiseDivide(const Matrix & operand)210   Matrix& PointwiseDivide(const Matrix& operand) {
211     RTC_CHECK_EQ(num_rows_, operand.num_rows_);
212     RTC_CHECK_EQ(num_columns_, operand.num_columns_);
213 
214     for (size_t i = 0; i < data_.size(); ++i) {
215       data_[i] /= operand.data_[i];
216     }
217 
218     return *this;
219   }
220 
PointwiseDivide(const Matrix & lhs,const Matrix & rhs)221   Matrix& PointwiseDivide(const Matrix& lhs, const Matrix& rhs) {
222     CopyFrom(lhs);
223     return PointwiseDivide(rhs);
224   }
225 
PointwiseSquareRoot()226   Matrix& PointwiseSquareRoot() {
227     for (size_t i = 0; i < data_.size(); ++i) {
228       data_[i] = sqrt_wrapper(data_[i]);
229     }
230 
231     return *this;
232   }
233 
PointwiseSquareRoot(const Matrix & operand)234   Matrix& PointwiseSquareRoot(const Matrix& operand) {
235     CopyFrom(operand);
236     return PointwiseSquareRoot();
237   }
238 
PointwiseAbsoluteValue()239   Matrix& PointwiseAbsoluteValue() {
240     for (size_t i = 0; i < data_.size(); ++i) {
241       data_[i] = abs(data_[i]);
242     }
243 
244     return *this;
245   }
246 
PointwiseAbsoluteValue(const Matrix & operand)247   Matrix& PointwiseAbsoluteValue(const Matrix& operand) {
248     CopyFrom(operand);
249     return PointwiseAbsoluteValue();
250   }
251 
PointwiseSquare()252   Matrix& PointwiseSquare() {
253     for (size_t i = 0; i < data_.size(); ++i) {
254       data_[i] *= data_[i];
255     }
256 
257     return *this;
258   }
259 
PointwiseSquare(const Matrix & operand)260   Matrix& PointwiseSquare(const Matrix& operand) {
261     CopyFrom(operand);
262     return PointwiseSquare();
263   }
264 
Multiply(const Matrix & lhs,const Matrix & rhs)265   Matrix& Multiply(const Matrix& lhs, const Matrix& rhs) {
266     RTC_CHECK_EQ(lhs.num_columns_, rhs.num_rows_);
267     RTC_CHECK_EQ(num_rows_, lhs.num_rows_);
268     RTC_CHECK_EQ(num_columns_, rhs.num_columns_);
269 
270     return Multiply(lhs.elements(), rhs.num_rows_, rhs.elements());
271   }
272 
Multiply(const Matrix & rhs)273   Matrix& Multiply(const Matrix& rhs) {
274     RTC_CHECK_EQ(num_columns_, rhs.num_rows_);
275 
276     CopyDataToScratch();
277     Resize(num_rows_, rhs.num_columns_);
278     return Multiply(scratch_elements(), rhs.num_rows_, rhs.elements());
279   }
280 
ToString()281   std::string ToString() const {
282     std::ostringstream ss;
283     ss << std::endl << "Matrix" << std::endl;
284 
285     for (size_t i = 0; i < num_rows_; ++i) {
286       for (size_t j = 0; j < num_columns_; ++j) {
287         ss << elements_[i][j] << " ";
288       }
289       ss << std::endl;
290     }
291     ss << std::endl;
292 
293     return ss.str();
294   }
295 
296  protected:
SetNumRows(const size_t num_rows)297   void SetNumRows(const size_t num_rows) { num_rows_ = num_rows; }
SetNumColumns(const size_t num_columns)298   void SetNumColumns(const size_t num_columns) { num_columns_ = num_columns; }
data()299   T* data() { return &data_[0]; }
data()300   const T* data() const { return &data_[0]; }
scratch_elements()301   const T* const* scratch_elements() const { return &scratch_elements_[0]; }
302 
303   // Resize the matrix. If an increase in capacity is required, the current
304   // data is lost.
Resize()305   void Resize() {
306     size_t size = num_rows_ * num_columns_;
307     data_.resize(size);
308     elements_.resize(num_rows_);
309 
310     for (size_t i = 0; i < num_rows_; ++i) {
311       elements_[i] = &data_[i * num_columns_];
312     }
313   }
314 
315   // Copies data_ into scratch_data_ and updates scratch_elements_ accordingly.
CopyDataToScratch()316   void CopyDataToScratch() {
317     scratch_data_ = data_;
318     scratch_elements_.resize(num_rows_);
319 
320     for (size_t i = 0; i < num_rows_; ++i) {
321       scratch_elements_[i] = &scratch_data_[i * num_columns_];
322     }
323   }
324 
325  private:
326   size_t num_rows_;
327   size_t num_columns_;
328   std::vector<T> data_;
329   std::vector<T*> elements_;
330 
331   // Stores temporary copies of |data_| and |elements_| for in-place operations
332   // where referring to original data is necessary.
333   std::vector<T> scratch_data_;
334   std::vector<T*> scratch_elements_;
335 
336   // Helpers for Transpose and Multiply operations that unify in-place and
337   // out-of-place solutions.
Transpose(const T * const * src)338   Matrix& Transpose(const T* const* src) {
339     for (size_t i = 0; i < num_rows_; ++i) {
340       for (size_t j = 0; j < num_columns_; ++j) {
341         elements_[i][j] = src[j][i];
342       }
343     }
344 
345     return *this;
346   }
347 
Multiply(const T * const * lhs,size_t num_rows_rhs,const T * const * rhs)348   Matrix& Multiply(const T* const* lhs,
349                    size_t num_rows_rhs,
350                    const T* const* rhs) {
351     for (size_t row = 0; row < num_rows_; ++row) {
352       for (size_t col = 0; col < num_columns_; ++col) {
353         T cur_element = 0;
354         for (size_t i = 0; i < num_rows_rhs; ++i) {
355           cur_element += lhs[row][i] * rhs[i][col];
356         }
357 
358         elements_[row][col] = cur_element;
359       }
360     }
361 
362     return *this;
363   }
364 
365   RTC_DISALLOW_COPY_AND_ASSIGN(Matrix);
366 };
367 
368 }  // namespace webrtc
369 
370 #endif  // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
371