# WALS Factorization

$$
% commands
\newcommand\bracket[2]{\left\langle #1, #2 \right\rangle}
\newcommand\trace{\text{trace}}
\newcommand\Rbb{\mathbb{R}}
$$

### Problem formulation

WALS (Weighted Alternating Least Squares) is an algorithm for factorizing a
sparse matrix $$A \in \Rbb^{n \times m}$$ into low-rank factors $$U \in \Rbb^{n
\times k}$$ and $$V \in \Rbb^{m \times k}$$ such that the product $$UV^T$$ is a
"good" approximation of the full matrix.

![wals](wals.png)

It typically minimizes the following loss function:

$$ \min_{U,V}
\left(\|\sqrt{W} \odot (A - UV^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2)\right),
$$

where

- $$\lambda$$ is a regularization parameter,
- $$\odot$$ denotes the component-wise product,
- $$W$$ is a weight matrix of the form $$W_{i, j} = w_0 + 1_{A_{ij} \neq 0}R_i
  C_j$$, where $$w_0$$ is the weight of unobserved entries, and $$R \in
  \Rbb^n$$ and $$C \in \Rbb^m$$ are the row and column weights, respectively.
  This form of the weight matrix makes the problem amenable to an efficient
  implementation.

### Solution method

The WALS algorithm proceeds in phases, or "sweeps", where each sweep involves
fixing $$U$$ and solving for $$V$$, and then fixing $$V$$ and solving for $$U$$.
Each subproblem is an unconstrained quadratic minimization problem and can be
solved exactly. Because the loss separates over the rows of the factor being
solved for, this amounts to one $$k \times k$$ linear system per row.
Convergence is typically fast (10-20 sweeps).

### Loss computation

The product $$UV^T$$ is dense and can be too large to compute, so we
reformulate the loss so that it can be evaluated efficiently. First, we split
the norm into two terms according to the sparsity pattern of $$A$$. Let $$S =
\{(i, j) : A_{i, j} \neq 0\}$$.
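
$$
\begin{align}
\|\sqrt W \odot (A - UV^T)\|_F^2
&= \sum_{(i, j) \in S} (w_0 + R_i C_j) (A_{ij} - \bracket{u_i}{v_j})^2 +
\sum_{(i, j) \not\in S} w_0 (\bracket{u_i}{v_j})^2 \\
&= \sum_{(i, j) \in S} \left( (w_0 + R_i C_j) (A_{ij} - \bracket{u_i}{v_j})^2 -
w_0 \bracket{u_i}{v_j}^2 \right) + w_0\|UV^T\|_F^2
\end{align}
$$

where $$u_i$$ and $$v_j$$ denote the $$i$$-th row of $$U$$ and the $$j$$-th row
of $$V$$. The second equality adds and subtracts $$w_0 \bracket{u_i}{v_j}^2$$
for every $$(i, j) \in S$$. This turns the sum over unobserved entries into a sum
over all entries, $$w_0\|UV^T\|_F^2$$.

The last term can be computed efficiently by observing that, by the cyclic
property of the trace,

$$
\|UV^T\|_F^2 = \trace(UV^TVU^T) = \trace(U^TUV^TV)
$$

Each of the Gramian matrices $$U^TU$$ and $$V^TV$$ is $$k\times k$$ and cheap to
store. As a result, this term costs $$O((n + m)k^2)$$ rather than the $$O(nmk)$$
needed to form $$UV^T$$. Additionally, $$\|U\|_F^2 = \trace(U^TU)$$, and
similarly for $$V$$, so the regularization term can also be computed from the
traces of the individual Gramians.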
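Putting the pieces together, the following is a minimal NumPy sketch. It is an illustration under stated assumptions, not a reference implementation. It assumes $$A$$ is supplied as COO triplets `rows`, `cols`, `vals` of its nonzero entries with no duplicates. The function names `wals_loss`, `solve_rows` and `wals_sweep` are hypothetical.

`wals_loss` evaluates the loss using the decomposition over $$S$$ and the Gramian traces above, so it never forms $$UV^T$$. `solve_rows` performs one exact half-sweep via the per-row normal equations $$(V^T D_i V + \lambda I)u_i = V^T D_i a_i$$ with $$D_i = \text{diag}(W_{i,:})$$. Reusing the shared term $$w_0 V^TV$$ across rows is our assumption about how to exploit the structure of $$W$$, mirroring the loss trick.

```python
import numpy as np

# Illustrative sketch; the function names below are hypothetical, not a library API.
# A is passed as COO triplets (rows, cols, vals) of its nonzeros, without duplicates.


def wals_loss(rows, cols, vals, U, V, R, C, w0, lam):
    """Loss from the observed entries plus k x k Gramians; never forms U V^T."""
    pred = np.einsum("ij,ij->i", U[rows], V[cols])  # <u_i, v_j> for (i, j) in S
    weight = w0 + R[rows] * C[cols]                 # W_ij = w0 + R_i C_j on S
    sparse_term = np.sum(weight * (vals - pred) ** 2 - w0 * pred ** 2)
    gram_u, gram_v = U.T @ U, V.T @ V
    dense_term = w0 * np.trace(gram_u @ gram_v)     # w0 * ||U V^T||_F^2
    reg = lam * (np.trace(gram_u) + np.trace(gram_v))  # ||U||_F^2 = trace(U^T U)
    return sparse_term + dense_term + reg


def solve_rows(rows, cols, vals, n_rows, V, R, C, w0, lam):
    """Minimize the loss over U with V fixed: one k x k system per row.

    Row i solves (V^T D_i V + lam I) u_i = V^T D_i a_i, D_i = diag(W[i, :]).
    V^T D_i V = w0 V^T V + sum over observed j of R_i C_j v_j v_j^T, so the
    Gramian part is computed once and shared by all rows.
    """
    k = V.shape[1]
    base = w0 * (V.T @ V) + lam * np.eye(k)
    U = np.empty((n_rows, k))
    for i in range(n_rows):                   # plain loop for clarity, not speed
        mask = rows == i
        Vi = V[cols[mask]]                    # v_j for observed (i, j)
        extra = R[i] * C[cols[mask]]          # R_i C_j
        lhs = base + (Vi * extra[:, None]).T @ Vi
        rhs = Vi.T @ ((w0 + extra) * vals[mask])  # A_ij = 0 outside S
        U[i] = np.linalg.solve(lhs, rhs)
    return U


def wals_sweep(rows, cols, vals, shape, U, V, R, C, w0, lam):
    """One sweep: fix U and solve for V, then fix V and solve for U."""
    n, m = shape
    V = solve_rows(cols, rows, vals, m, U, C, R, w0, lam)  # roles of rows/cols swapped
    U = solve_rows(rows, cols, vals, n, V, R, C, w0, lam)
    return U, V


# Usage on a small random problem.
rng = np.random.default_rng(0)
n, m, k = 30, 20, 3
A = rng.normal(size=(n, m)) * (rng.random((n, m)) < 0.2)  # roughly 20% nonzero
rows, cols = np.nonzero(A)
vals = A[rows, cols]
R, C = rng.random(n) + 0.5, rng.random(m) + 0.5
w0, lam = 0.1, 0.5
U, V = rng.normal(size=(n, k)), rng.normal(size=(m, k))
for sweep in range(10):
    U, V = wals_sweep(rows, cols, vals, (n, m), U, V, R, C, w0, lam)
    print(sweep, wals_loss(rows, cols, vals, U, V, R, C, w0, lam))
```

For $$\lambda > 0$$, each half-sweep exactly minimizes a strictly convex quadratic, so the printed loss should not increase from one sweep to the next. On a small matrix, comparing `wals_loss` against a dense evaluation of the loss is a quick sanity check of the decomposition.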