1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Weighted Alternating Least Squares (WALS) on the tf.learn API."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.factorization.python.ops import factorization_ops
22from tensorflow.contrib.learn.python.learn.estimators import estimator
23from tensorflow.contrib.learn.python.learn.estimators import model_fn
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import state_ops
30from tensorflow.python.ops import variable_scope
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.summary import summary
33from tensorflow.python.training import session_run_hook
34from tensorflow.python.training import training_util
35
36
class _SweepHook(session_run_hook.SessionRunHook):
  """Session hook that alternates training between row and column sweeps.

  Before every step it reads the sweep state from the graph, runs the one-time
  init op and the per-sweep prep ops when they are due, and then requests the
  train op of the axis currently being swept.
  """

  def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op,
               row_prep_ops, col_prep_ops, row_train_op, col_train_op,
               switch_op):
    """Stores the variables and ops that drive the sweeps.

    Args:
      is_row_sweep_var: A Boolean tf.Variable. Its value selects whether the
        row or the column ops are used for the upcoming step.
      is_sweep_done_var: A Boolean tf.Variable. When it evaluates to True at
        the start of a step, `switch_op` and the prep ops are run.
      init_op: Op run once, on the first call to `before_run` (typically a
        local initialization op such as cache initialization).
      row_prep_ops: A list of TensorFlow ops run, in the given order, before a
        row sweep starts (and on the first step if it is a row sweep).
      col_prep_ops: A list of TensorFlow ops run, in the given order, before a
        column sweep starts (and on the first step if it is a column sweep).
      row_train_op: A TensorFlow op fetched at every step of a row sweep.
      col_train_op: A TensorFlow op fetched at every step of a column sweep.
      switch_op: A TensorFlow op run whenever the previous sweep is done,
        before the next sweep starts.
    """
    # Python-side flag: becomes True once init_op has been run by this hook.
    self._is_initialized = False
    self._init_op = init_op
    self._switch_op = switch_op
    self._is_row_sweep_var = is_row_sweep_var
    self._is_sweep_done_var = is_sweep_done_var
    self._row_prep_ops = row_prep_ops
    self._row_train_op = row_train_op
    self._col_prep_ops = col_prep_ops
    self._col_train_op = col_train_op

  def before_run(self, run_context):
    """Runs any pending init/switch/prep ops and picks the train op to fetch.

    Args:
      run_context: The run context supplied by the monitored session; only its
        `session` attribute is used.

    Returns:
      A `SessionRunArgs` whose single fetch is the row or column train op.
    """
    session = run_context.session
    # The sweep-done flag is read first, before anything can modify state.
    sweep_finished = session.run(self._is_sweep_done_var)
    first_step = not self._is_initialized
    if first_step:
      logging.info("SweepHook running init op.")
      session.run(self._init_op)
    if sweep_finished:
      logging.info("SweepHook starting the next sweep.")
      session.run(self._switch_op)

    # Read the axis only after the (possible) switch so it reflects the sweep
    # that is about to run.
    on_rows = session.run(self._is_row_sweep_var)
    if on_rows:
      axis_label = "row"
      pending_prep_ops = self._row_prep_ops
      train_op = self._row_train_op
    else:
      axis_label = "col"
      pending_prep_ops = self._col_prep_ops
      train_op = self._col_train_op

    if sweep_finished or first_step:
      logging.info("SweepHook running prep ops for the {} sweep.".format(
          axis_label))
      # Each prep op gets its own run call so the given order is respected.
      for op in pending_prep_ops:
        session.run(op)

    self._is_initialized = True
    logging.info("Next fit step starting.")
    return session_run_hook.SessionRunArgs(fetches=[train_op])
93
94
class _IncrementGlobalStepHook(session_run_hook.SessionRunHook):
  """Hook that increments the global step before every training step.

  If no global step exists in the graph when the hook is constructed, the hook
  does nothing.
  """

  def __init__(self):
    """Creates the increment op for the global step, if there is one."""
    global_step = training_util.get_global_step()
    # Compare against None explicitly: the global step is a graph object
    # (Variable or Tensor), and evaluating such an object as a Python bool is
    # either an error or a test of its value rather than of its presence.
    if global_step is not None:
      self._global_step_incr_op = state_ops.assign_add(
          global_step, 1, name="global_step_incr").op
    else:
      self._global_step_incr_op = None

  def before_run(self, run_context):
    """Runs the increment op in its own session call, when it exists.

    Args:
      run_context: The run context supplied by the monitored session; only its
        `session` attribute is used.
    """
    if self._global_step_incr_op is not None:
      run_context.session.run(self._global_step_incr_op)
109
110
class _StopAtSweepHook(session_run_hook.SessionRunHook):
  """Hook that asks the session to stop once enough sweeps are completed."""

  def __init__(self, last_sweep):
    """Records the sweep count at which training should stop.

    The hook reads the tensor named after
    `WALSMatrixFactorization.COMPLETED_SWEEPS` from the default graph, so that
    counter has to exist by the time `begin` is called.

    Args:
      last_sweep: Integer, number of the last sweep to run.
    """
    self._last_sweep = last_sweep

  def begin(self):
    """Looks up the completed-sweeps counter in the default graph.

    Raises:
      RuntimeError: If the counter tensor cannot be found in the graph.
    """
    counter_name = WALSMatrixFactorization.COMPLETED_SWEEPS
    try:
      graph = ops.get_default_graph()
      self._completed_sweeps_var = graph.get_tensor_by_name(
          counter_name + ":0")
    except KeyError:
      raise RuntimeError(counter_name +
                         " counter should be created to use StopAtSweepHook.")

  def before_run(self, run_context):
    """Requests the value of the completed-sweeps counter with each step."""
    counter = self._completed_sweeps_var
    return session_run_hook.SessionRunArgs(counter)

  def after_run(self, run_context, run_values):
    """Requests a stop when the fetched counter has reached `last_sweep`."""
    sweeps_so_far = run_values.results
    limit_reached = sweeps_so_far >= self._last_sweep
    if limit_reached:
      run_context.request_stop()
140
141
def _wals_factorization_model_function(features, labels, mode, params):
  """Builds the graph of the WALSMatrixFactorization estimator for `mode`.

  Args:
    features: Dictionary of features. See WALSMatrixFactorization for the keys
      expected in each mode.
    labels: Must be None.
    mode: A model_fn.ModeKeys object.
    params: Dictionary of parameters containing arguments passed to the
      WALSMatrixFactorization constructor.

  Returns:
    A ModelFnOps object.

  Raises:
    ValueError: If `mode` is not recognized.
  """
  assert labels is None
  mode_keys = model_fn.ModeKeys
  # Caching is a training-only feature; it is switched off in other modes.
  cache_factors_weights = (
      params["use_factors_weights_cache_for_training"] and
      mode == mode_keys.TRAIN)
  cache_gramian = (
      params["use_gramian_cache_for_training"] and mode == mode_keys.TRAIN)
  max_sweeps = params["max_sweeps"]
  wals_model = factorization_ops.WALSModel(
      params["num_rows"],
      params["num_cols"],
      params["embedding_dimension"],
      unobserved_weight=params["unobserved_weight"],
      regularization=params["regularization_coeff"],
      row_init=params["row_init"],
      col_init=params["col_init"],
      num_row_shards=params["num_row_shards"],
      num_col_shards=params["num_col_shards"],
      row_weights=params["row_weights"],
      col_weights=params["col_weights"],
      use_factors_weights_cache=cache_factors_weights,
      use_gramian_cache=cache_gramian)

  # Both batches are always read; which one is actually used for an update is
  # decided at run time (by the sweep hook in TRAIN mode, or by the
  # PROJECT_ROW feature in INFER and EVAL modes).
  input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
  input_cols = features[WALSMatrixFactorization.INPUT_COLS]

  # ---------------------------------------------------------------- TRAIN
  if mode == mode_keys.TRAIN:
    # Training is driven by a _SweepHook, which runs:
    #   before a row sweep: row_update_prep_gramian_op,
    #                       initialize_row_update_op
    #   during a row sweep: the row train op (update_row_factors)
    #   before a col sweep: col_update_prep_gramian_op,
    #                       initialize_col_update_op
    #   during a col sweep: the col train op (update_col_factors)

    def _state_variable(initial_value, name):
      """Creates a non-trainable global variable used for book-keeping."""
      return variable_scope.variable(
          initial_value,
          name=name,
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    sweeping_rows_var = _state_variable(True, "is_row_sweep")
    sweep_done_var = _state_variable(False, "is_sweep_done")
    sweep_counter_var = _state_variable(
        0, WALSMatrixFactorization.COMPLETED_SWEEPS)
    loss_var = _state_variable(0., WALSMatrixFactorization.LOSS)
    # The root weighted squared error =
    #   \\(\sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )\\)
    rwse_var = _state_variable(0., WALSMatrixFactorization.RWSE)

    summary.scalar("loss", loss_var)
    summary.scalar("root_weighted_squared_error", rwse_var)
    summary.scalar("completed_sweeps", sweep_counter_var)

    def _build_axis_ops(batch, axis_size, run_update, axis_label):
      """Creates the book-keeping and training ops of one axis.

      Args:
        batch: A SparseTensor corresponding to the row or column batch.
        axis_size: An integer, the total number of items of this axis.
        run_update: A function that takes one argument (the batch), and that
          returns a tuple of
            * new_factors: A float Tensor of the factor values after update.
            * update_op: a TensorFlow op which updates the factors.
            * loss: A float Tensor, the unregularized loss.
            * reg_loss: A float Tensor, the regularization loss.
            * sum_weights: A float Tensor, the sum of factor weights.
        axis_label: A string that specifies the name of the axis.

      Returns:
        A tuple consisting of:
          * the initializer of the "processed items" variable, which marks all
            items of the axis as not-processed when run.
          * the train op of the axis, to be run during this axis' sweeps.
      """
      # One Boolean per item of the axis: has it been seen in this sweep?
      initial_flags = array_ops.fill(dims=[axis_size], value=False)
      with ops.colocate_with(initial_flags):
        seen_flags = variable_scope.variable(
            initial_flags,
            name="processed_" + axis_label,
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
      _, factor_update, unreg_loss, reg_loss, weight_sum = run_update(batch)
      batch_ids = batch.indices[:, 0]
      store_loss = state_ops.assign(loss_var, unreg_loss + reg_loss)
      store_rwse = state_ops.assign(
          rwse_var, math_ops.sqrt(unreg_loss / weight_sum))
      # Items are marked as processed only after the factor update and the
      # metric assignments of this batch.
      with ops.control_dependencies([factor_update, store_loss, store_rwse]):
        with ops.colocate_with(seen_flags):
          mark_seen = state_ops.scatter_update(
              seen_flags,
              batch_ids,
              array_ops.ones_like(batch_ids, dtype=dtypes.bool),
              name="update_processed_{}_indices".format(axis_label))
        # The sweep-done test must observe the flags written just above.
        with ops.control_dependencies([mark_seen]):
          all_seen = math_ops.reduce_all(seen_flags)
          record_done = state_ops.assign(sweep_done_var, all_seen)
          count_sweep = state_ops.assign_add(
              sweep_counter_var, math_ops.cast(all_seen, dtypes.int32))
          train_op = control_flow_ops.group(
              record_done,
              count_sweep,
              name="{}_sweep_train_op".format(axis_label))
      return seen_flags.initializer, train_op

    def _update_rows(batch):
      return wals_model.update_row_factors(
          sp_input=batch, transpose_input=False)

    def _update_cols(batch):
      return wals_model.update_col_factors(
          sp_input=batch, transpose_input=True)

    reset_rows_op, row_train_op = _build_axis_ops(
        input_rows, params["num_rows"], _update_rows, "rows")
    reset_cols_op, col_train_op = _build_axis_ops(
        input_cols, params["num_cols"], _update_cols, "cols")

    # Switching sweeps flips the axis flag and clears both processed-item sets.
    flip_axis = state_ops.assign(
        sweeping_rows_var, math_ops.logical_not(sweeping_rows_var))
    switch_op = control_flow_ops.group(
        flip_axis, reset_rows_op, reset_cols_op, name="sweep_switch_op")

    row_prep_ops = [
        wals_model.row_update_prep_gramian_op,
        wals_model.initialize_row_update_op,
    ]
    col_prep_ops = [
        wals_model.col_update_prep_gramian_op,
        wals_model.initialize_col_update_op,
    ]
    init_op = wals_model.worker_init

    sweep_hook = _SweepHook(
        sweeping_rows_var, sweep_done_var, init_op,
        row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op)
    global_step_hook = _IncrementGlobalStepHook()
    hooks = [sweep_hook, global_step_hook]
    if max_sweeps is not None:
      hooks.append(_StopAtSweepHook(max_sweeps))

    # The actual updates are fetched by the sweep hook, hence the no-op here.
    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.TRAIN,
        predictions={},
        loss=loss_var,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=hooks)

  # ---------------------------------------------------------------- INFER
  if mode == mode_keys.INFER:
    # Optional feature: None when the caller did not provide weights.
    weights = features.get(WALSMatrixFactorization.PROJECTION_WEIGHTS)

    def _project_rows():
      return wals_model.project_row_factors(
          sp_input=input_rows,
          projection_weights=weights,
          transpose_input=False)

    def _project_cols():
      return wals_model.project_col_factors(
          sp_input=input_cols,
          projection_weights=weights,
          transpose_input=True)

    projection = control_flow_ops.cond(
        features[WALSMatrixFactorization.PROJECT_ROW],
        _project_rows,
        _project_cols)

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.INFER,
        predictions={WALSMatrixFactorization.PROJECTION_RESULT: projection},
        loss=None,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=[])

  # ----------------------------------------------------------------- EVAL
  if mode == mode_keys.EVAL:

    def _row_loss():
      _, _, unreg_loss, reg_loss, _ = wals_model.update_row_factors(
          sp_input=input_rows, transpose_input=False)
      return unreg_loss + reg_loss

    def _col_loss():
      _, _, unreg_loss, reg_loss, _ = wals_model.update_col_factors(
          sp_input=input_cols, transpose_input=True)
      return unreg_loss + reg_loss

    total_loss = control_flow_ops.cond(
        features[WALSMatrixFactorization.PROJECT_ROW], _row_loss, _col_loss)

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.EVAL,
        predictions={},
        loss=total_loss,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=[])

  raise ValueError("mode=%s is not recognized." % str(mode))
373
374
class WALSMatrixFactorization(estimator.Estimator):
  """An Estimator for Weighted Matrix Factorization, using the WALS method.

  WALS (Weighted Alternating Least Squares) computes a low-rank approximation
  of a sparse (n x m) matrix `A` as a product `U * V^T`, where `U` is a (n x k)
  matrix (the row factors), `V` is a (m x k) matrix (the column factors), and k
  is the rank of the approximation, also called the embedding dimension.
  See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem
  formulation.

  Training proceeds in sweeps. During a row sweep `V` is fixed and `U` is
  solved for; during a column sweep `U` is fixed and `V` is solved for. Each of
  these is an unconstrained quadratic minimization problem that can be solved
  exactly, and also in mini-batches, because the solution decouples across the
  rows of each matrix.
  The alternation is implemented with training hooks: one keeps track of the
  sweeps and runs the preparation ops at the beginning of each sweep, another
  increments the global_step variable, which counts the batches processed since
  the beginning of training.
  The current implementation assumes single-machine training: the constructor
  raises if `config.num_worker_replicas` is greater than one.
  Training is done by calling `self.fit(input_fn=input_fn)`, where `input_fn`
  provides two tensors: one for rows of the input matrix, and one for rows of
  the transposed input matrix (i.e. columns of the original matrix). During a
  row sweep only the row batches are processed (column batches are ignored),
  and vice-versa.
  A sweep is considered complete only once every row (respectively every
  column) of the input matrix has been processed at least once. In particular,
  training will not make progress if some rows are never generated by the
  `input_fn`.

  For prediction, given a new set of input rows `A'`, a corresponding set of
  row factors `U'` is computed such that `U' * V^T` is a good approximation of
  `A'`. This operation is called a row projection; a similar operation is
  defined for columns. Projection is done by calling
  `self.get_projections(input_fn=input_fn)`, where `input_fn` satisfies the
  constraints given below.

  Calling an `input_fn` must return a tuple `(features, labels)` where `labels`
  is None and `features` is a dict containing the following keys:

  TRAIN:
    * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
      Rows of the input matrix to process (or to project).
    * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
      Columns of the input matrix to process (or to project), transposed.

  INFER:
    * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
      Rows to project.
    * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
      Columns to project.
    * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
      the rows or columns.
    * `WALSMatrixFactorization.PROJECTION_WEIGHTS` (Optional): float32 Tensor
      (vector). The weights to use in the projection.

  EVAL:
    * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
      Rows to project.
    * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
      Columns to project.
    * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
      the rows or columns.
  """
  # Keys used by the model_fn.
  # Feature keys.
  INPUT_ROWS = "input_rows"
  INPUT_COLS = "input_cols"
  PROJECT_ROW = "project_row"
  PROJECTION_WEIGHTS = "projection_weights"
  # Prediction key.
  PROJECTION_RESULT = "projection"
  # Name of the completed_sweeps variable.
  COMPLETED_SWEEPS = "completed_sweeps"
  # Name of the loss variable.
  LOSS = "WALS_loss"
  # Name of the Root Weighted Squared Error variable.
  RWSE = "WALS_RWSE"

  def __init__(self,
               num_rows,
               num_cols,
               embedding_dimension,
               unobserved_weight=0.1,
               regularization_coeff=None,
               row_init="random",
               col_init="random",
               num_row_shards=1,
               num_col_shards=1,
               row_weights=1,
               col_weights=1,
               use_factors_weights_cache_for_training=True,
               use_gramian_cache_for_training=True,
               max_sweeps=None,
               model_dir=None,
               config=None):
    r"""Creates a model for matrix factorization using the WALS method.

    Args:
      num_rows: Total number of rows for input matrix.
      num_cols: Total number of cols for input matrix.
      embedding_dimension: Dimension to use for the factors.
      unobserved_weight: Weight of the unobserved entries of matrix.
      regularization_coeff: Weight of the L2 regularization term. Defaults to
        None, in which case the problem is not regularized.
      row_init: Initializer for row factor. Must be either:
        - A tensor: The row factor matrix is initialized to this tensor,
        - A numpy constant,
        - "random": The rows are initialized using a normal distribution.
      col_init: Initializer for column factor. See row_init.
      num_row_shards: Number of shards to use for the row factors.
      num_col_shards: Number of shards to use for the column factors.
      row_weights: Must be in one of the following three formats:
        - None: In this case, the weight of every entry is the unobserved_weight
          and the problem simplifies to ALS. Note that, in this case,
          col_weights must also be set to "None".
        - List of lists of non-negative scalars, of the form
          \\([[w_0, w_1, ...], [w_k, ... ], [...]]\\),
          where the number of inner lists is equal to the number of row factor
          shards and the elements in each inner list are the weights for the
          rows of that shard. In this case,
          \\(w_ij = unobserved_weight + row_weights[i] * col_weights[j]\\).
        - A non-negative scalar: This value is used for all row weights.
          Note that it is allowed to have row_weights as a list and col_weights
          as a scalar, or vice-versa.
      col_weights: See row_weights.
      use_factors_weights_cache_for_training: Boolean, whether the factors and
        weights will be cached on the workers before the updates start, during
        training. Defaults to True.
        Note that caching is disabled during prediction.
      use_gramian_cache_for_training: Boolean, whether the Gramians will be
        cached on the workers before the updates start, during training.
        Defaults to True. Note that caching is disabled during prediction.
      max_sweeps: integer, optional. Specifies the number of sweeps for which
        to train the model, where a sweep is defined as a full update of all the
        row factors (resp. column factors).
        If `steps` or `max_steps` is also specified in model.fit(), training
        stops when either of the steps condition or sweeps condition is met.
      model_dir: The directory to save the model results and log files.
      config: A Configuration object. See Estimator.

    Raises:
      ValueError: If config.num_worker_replicas is strictly greater than one.
        The current implementation only supports running on a single worker.
    """
    # TODO(walidk): Support power-law based weight computation.
    # TODO(walidk): Add factor lookup by indices, with caching.
    # TODO(walidk): Support caching during prediction.
    # TODO(walidk): Provide input pipelines that handle missing rows.

    # Everything the model_fn needs is forwarded through `params`.
    model_params = {
        "num_rows": num_rows,
        "num_cols": num_cols,
        "embedding_dimension": embedding_dimension,
        "unobserved_weight": unobserved_weight,
        "regularization_coeff": regularization_coeff,
        "row_init": row_init,
        "col_init": col_init,
        "num_row_shards": num_row_shards,
        "num_col_shards": num_col_shards,
        "row_weights": row_weights,
        "col_weights": col_weights,
        "max_sweeps": max_sweeps,
        "use_factors_weights_cache_for_training":
            use_factors_weights_cache_for_training,
        "use_gramian_cache_for_training": use_gramian_cache_for_training,
    }
    # Names of the per-shard factor variables, used to read them back from
    # the checkpoint.
    self._row_factors_names = [
        "row_factors_shard_%d" % shard for shard in range(num_row_shards)
    ]
    self._col_factors_names = [
        "col_factors_shard_%d" % shard for shard in range(num_col_shards)
    ]

    super(WALSMatrixFactorization, self).__init__(
        model_fn=_wals_factorization_model_function,
        params=model_params,
        model_dir=model_dir,
        config=config)

    # The check uses the config stored by the base class constructor.
    run_config = self._config
    if run_config is not None and run_config.num_worker_replicas > 1:
      raise ValueError("WALSMatrixFactorization must be run on a single worker "
                       "replica.")

  def _load_variables(self, names):
    """Returns the values of the variables called `names`, in order."""
    return [self.get_variable_value(variable_name) for variable_name in names]

  def get_row_factors(self):
    """Returns the row factors of the model, loading them from checkpoint.

    Should only be run after training.

    Returns:
      A list of the row factors of the model, one entry per row shard.
    """
    return self._load_variables(self._row_factors_names)

  def get_col_factors(self):
    """Returns the column factors of the model, loading them from checkpoint.

    Should only be run after training.

    Returns:
      A list of the column factors of the model, one entry per column shard.
    """
    return self._load_variables(self._col_factors_names)

  def get_projections(self, input_fn):
    """Computes the projections of the rows or columns given in input_fn.

    Runs predict() with the given input_fn, and extracts the projection from
    each result. Should only be run after training.

    Args:
      input_fn: Input function which specifies the rows or columns to project.
    Returns:
      A generator of the projected factors.
    """
    result_key = WALSMatrixFactorization.PROJECTION_RESULT
    predictions = self.predict(input_fn=input_fn)
    return (prediction[result_key] for prediction in predictions)
610