# WALS Factorization

$$
% commands
\newcommand\bracket[2]{\left\langle #1, #2 \right\rangle}
\newcommand\trace{\text{trace}}
\newcommand\Rbb{\mathbb{R}}
$$

### Problem formulation

WALS (Weighted Alternating Least Squares) is an algorithm for factorizing a
sparse matrix $$A \in \Rbb^{n \times m}$$ into low-rank factors $$U \in \Rbb^{n
\times k}$$ and $$V \in \Rbb^{m \times k}$$, such that the product $$UV^T$$ is a
"good" approximation of the full matrix.

![wals](wals.png)

WALS typically minimizes the following regularized, weighted squared loss:

$$ \min_{U,V}
(\|\sqrt{W} \odot (A- UV^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2)),
$$

where

-   $$\lambda$$ is a regularization parameter,
-   $$\odot$$ denotes the component-wise product, and $$\sqrt{W}$$ is taken
    entry-wise,
-   $$W$$ is a weight matrix of the form $$W_{ij} = w_0 + 1_{A_{ij} \neq 0}R_i
    C_j$$, where $$R \in \Rbb^n$$ and $$C \in \Rbb^m$$ are the row and column
    weights respectively. Unobserved entries (those with $$A_{ij} = 0$$) get
    weight $$w_0$$, and observed entries get weight $$w_0 + R_i C_j$$. This
    form of the weight matrix lends itself to an efficient implementation.

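To make the objective concrete, here is a minimal NumPy sketch that evaluates it directly by materializing $$W$$ and $$UV^T$$. This is only practical for small matrices. The function names (`wals_weights`, `wals_loss_dense`) and the example data are hypothetical and chosen for illustration.

```python
import numpy as np

def wals_weights(A, w0, R, C):
    """W_ij = w0 + 1[A_ij != 0] * R_i * C_j, built densely (small examples only)."""
    observed = (A != 0).astype(float)
    return w0 + observed * np.outer(R, C)

def wals_loss_dense(A, U, V, w0, R, C, lam):
    """Evaluate the regularized weighted loss by materializing U V^T."""
    W = wals_weights(A, w0, R, C)
    residual = A - U @ V.T
    fit = np.sum(W * residual ** 2)  # squared Frobenius norm of sqrt(W) * residual
    reg = lam * (np.sum(U ** 2) + np.sum(V ** 2))
    return fit + reg

# Tiny made-up example: a 3 x 4 matrix with rank-2 factors and unit row/column weights.
A = np.array([[5.0, 0.0, 0.0, 1.0],
              [0.0, 3.0, 0.0, 0.0],
              [4.0, 0.0, 2.0, 0.0]])
R = np.ones(3)
C = np.ones(4)
rng = np.random.default_rng(0)
U = rng.normal(size=(3, 2))
V = rng.normal(size=(4, 2))
loss = wals_loss_dense(A, U, V, w0=0.1, R=R, C=C, lam=0.01)
```

The dense version costs $$O(nmk)$$ time and $$O(nm)$$ memory, which is why it is useful mainly as a reference for checking a sparse implementation.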
### Solution method

The WALS algorithm proceeds in phases, or "sweeps", where each sweep involves
fixing $$U$$ and solving for $$V$$, and then fixing $$V$$ and solving for $$U$$.
With one factor fixed, the loss is quadratic in the other, so each subproblem
is an unconstrained quadratic minimization problem and can be solved exactly.
In practice, convergence is typically fast, often within 10-20 sweeps.

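As a sketch of what "solved exactly" means: with $$V$$ fixed, the loss separates over the rows of $$U$$, and setting the gradient with respect to row $$u_i$$ to zero gives the normal equations

$$
(V^T D_i V + \lambda I) u_i = V^T D_i a_i,
$$

where $$D_i$$ is the diagonal matrix holding row $$i$$ of $$W$$ and $$a_i$$ is row $$i$$ of $$A$$. The NumPy code below solves these with a dense $$W$$ for clarity. The helper names (`solve_rows`, `wals_sweep`, `loss`) and the random data are assumptions for illustration, not an existing API.

```python
import numpy as np

def solve_rows(A, W, V, lam):
    """With V fixed, solve the regularized weighted least-squares problem for each row of U."""
    n, k = A.shape[0], V.shape[1]
    U = np.empty((n, k))
    for i in range(n):
        VtD = V.T * W[i]                  # k x m: columns of V^T scaled by row i's weights
        lhs = VtD @ V + lam * np.eye(k)   # V^T D_i V + lambda I
        rhs = VtD @ A[i]                  # V^T D_i a_i
        U[i] = np.linalg.solve(lhs, rhs)
    return U

def wals_sweep(A, W, U, V, lam):
    """One sweep: fix U and solve for V, then fix V and solve for U."""
    V = solve_rows(A.T, W.T, U, lam)
    U = solve_rows(A, W, V, lam)
    return U, V

def loss(A, W, U, V, lam):
    return np.sum(W * (A - U @ V.T) ** 2) + lam * (np.sum(U ** 2) + np.sum(V ** 2))

rng = np.random.default_rng(0)
A = rng.random((6, 5)) * (rng.random((6, 5)) < 0.4)  # made-up sparse-ish data
W = 0.1 + (A != 0) * 1.0                             # w0 = 0.1, R = C = 1
U = rng.normal(size=(6, 2))
V = rng.normal(size=(5, 2))
lam = 0.05

history = [loss(A, W, U, V, lam)]
for _ in range(10):
    U, V = wals_sweep(A, W, U, V, lam)
    history.append(loss(A, W, U, V, lam))
```

Because each half-sweep minimizes the loss exactly over one factor, the values in `history` should never increase. A practical implementation would avoid the dense $$W$$: the structure of the weights gives $$V^T D_i V = w_0 V^TV + R_i \sum_{j : A_{ij} \neq 0} C_j v_j v_j^T$$, so only the non-zero entries of row $$i$$ need to be visited.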
### Loss computation

The product $$UV^T$$ is dense and can be too large to compute explicitly, so we
reformulate the loss in a way that can be evaluated efficiently. First, we
split the norm into two terms according to the sparsity pattern of $$A$$. Let
$$S = \{(i, j) : A_{ij} \neq 0\}$$, and let $$u_i$$ and $$v_j$$ denote row
$$i$$ of $$U$$ and row $$j$$ of $$V$$.

$$
\begin{align}
\|\sqrt W \odot (A - UV^T)\|_F^2
&= \sum_{(i, j) \in S} (w_0 + R_i C_j) (A_{ij} - \bracket{u_i}{v_j})^2 +
\sum_{(i, j) \not\in S} w_0 \bracket{u_i}{v_j}^2 \\
&= \sum_{(i, j) \in S} \left( (w_0 + R_i C_j) (A_{ij} - \bracket{u_i}{v_j})^2 -
w_0 \bracket{u_i}{v_j}^2 \right) + w_0\|UV^T\|_F^2
\end{align}
$$

The second equality adds and subtracts $$w_0 \sum_{(i, j) \in S}
\bracket{u_i}{v_j}^2$$ and uses the fact that the sum of
$$\bracket{u_i}{v_j}^2$$ over all $$(i, j)$$ is $$\|UV^T\|_F^2$$. The remaining
sum runs only over the non-zero entries of $$A$$.

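The identity above can be checked numerically on a small random instance. The sketch below is illustrative only: the data is made up, and the $$w_0\|UV^T\|_F^2$$ term is still computed densely here.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, k, w0 = 5, 4, 2, 0.1
A = rng.random((n, m)) * (rng.random((n, m)) < 0.5)  # made-up sparse-ish data
R, C = rng.random(n) + 0.5, rng.random(m) + 0.5
U, V = rng.normal(size=(n, k)), rng.normal(size=(m, k))

# Left-hand side: weighted norm over all n * m entries.
W = w0 + (A != 0) * np.outer(R, C)
P = U @ V.T
direct = np.sum(W * (A - P) ** 2)

# Right-hand side: a sum over S plus w0 * ||U V^T||_F^2.
rows, cols = np.nonzero(A)                      # the index set S
pred = np.einsum("ij,ij->i", U[rows], V[cols])  # <u_i, v_j> for (i, j) in S
sparse_part = np.sum((w0 + R[rows] * C[cols]) * (A[rows, cols] - pred) ** 2
                     - w0 * pred ** 2)
split = sparse_part + w0 * np.sum(P ** 2)
```

Up to floating-point error, `direct` and `split` should agree.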
The last term can be computed efficiently by observing that, by the cyclic
property of the trace,

$$
\|UV^T\|_F^2 = \trace(UV^TVU^T) = \trace(U^TUV^TV)
$$

Each of the Gramian matrices $$U^TU$$ and $$V^TV$$ is $$k\times k$$ and is
therefore cheap to store. Additionally, $$\|U\|_F^2 = \trace(U^TU)$$ and
similarly for $$V$$, so we can use the trace of the individual Gramians to
compute the norms.

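Putting the pieces together, the full loss can be evaluated from the non-zero entries of $$A$$ and the two Gramians, without ever forming $$UV^T$$. The following is a minimal NumPy sketch under the assumption that $$A$$ is given in coordinate form; `wals_loss_sparse` and its argument names are hypothetical.

```python
import numpy as np

def wals_loss_sparse(rows, cols, vals, U, V, w0, R, C, lam):
    """Loss from the non-zero entries (rows, cols, vals) of A and the k x k Gramians."""
    gram_u = U.T @ U
    gram_v = V.T @ V
    pred = np.einsum("ij,ij->i", U[rows], V[cols])  # <u_i, v_j> for (i, j) in S
    sparse_part = np.sum((w0 + R[rows] * C[cols]) * (vals - pred) ** 2
                         - w0 * pred ** 2)
    dense_part = w0 * np.trace(gram_u @ gram_v)     # w0 * ||U V^T||_F^2
    reg = lam * (np.trace(gram_u) + np.trace(gram_v))
    return sparse_part + dense_part + reg

# Compare against a dense reference on made-up data.
rng = np.random.default_rng(2)
n, m, k, w0, lam = 6, 5, 3, 0.2, 0.01
A = rng.random((n, m)) * (rng.random((n, m)) < 0.3)
R, C = rng.random(n) + 0.5, rng.random(m) + 0.5
U, V = rng.normal(size=(n, k)), rng.normal(size=(m, k))

rows, cols = np.nonzero(A)
efficient = wals_loss_sparse(rows, cols, A[rows, cols], U, V, w0, R, C, lam)

W = w0 + (A != 0) * np.outer(R, C)
reference = (np.sum(W * (A - U @ V.T) ** 2)
             + lam * (np.sum(U ** 2) + np.sum(V ** 2)))
```

In this form the cost is $$O(|S| k)$$ for the sparse part plus $$O((n + m) k^2)$$ for the Gramians, compared with $$O(nmk)$$ for the dense reference.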