1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for WALSMatrixFactorization."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22import json
23import numpy as np
24
25from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils
26from tensorflow.contrib.factorization.python.ops import wals as wals_lib
27from tensorflow.contrib.learn.python.learn import run_config
28from tensorflow.contrib.learn.python.learn.estimators import model_fn
29from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import embedding_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import sparse_ops
38from tensorflow.python.ops import state_ops
39from tensorflow.python.ops import variables
40from tensorflow.python.platform import test
41from tensorflow.python.training import input as input_lib
42from tensorflow.python.training import monitored_session
43
44
45class WALSMatrixFactorizationTest(test.TestCase):
46  INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX
47
48  def np_array_to_sparse(self, np_array):
49    """Transforms an np.array to a tf.SparseTensor."""
50    return factorization_ops_test_utils.np_matrix_to_tf_sparse(np_array)
51
52  def calculate_loss(self):
53    """Calculates the loss of the current (trained) model."""
54    current_rows = embedding_ops.embedding_lookup(
55        self._model.get_row_factors(), math_ops.range(self._num_rows),
56        partition_strategy='div')
57    current_cols = embedding_ops.embedding_lookup(
58        self._model.get_col_factors(), math_ops.range(self._num_cols),
59        partition_strategy='div')
60    row_wts = embedding_ops.embedding_lookup(
61        self._row_weights, math_ops.range(self._num_rows),
62        partition_strategy='div')
63    col_wts = embedding_ops.embedding_lookup(
64        self._col_weights, math_ops.range(self._num_cols),
65        partition_strategy='div')
66    sp_inputs = self.np_array_to_sparse(self.INPUT_MATRIX)
67    return factorization_ops_test_utils.calculate_loss(
68        sp_inputs, current_rows, current_cols, self._regularization_coeff,
69        self._unobserved_weight, row_wts, col_wts)
70
71  # TODO(walidk): Replace with input_reader_utils functions once open sourced.
72  def remap_sparse_tensor_rows(self, sp_x, row_ids, shape):
73    """Remaps the row ids of a tf.SparseTensor."""
74    old_row_ids, old_col_ids = array_ops.split(
75        value=sp_x.indices, num_or_size_splits=2, axis=1)
76    new_row_ids = array_ops.gather(row_ids, old_row_ids)
77    new_indices = array_ops.concat([new_row_ids, old_col_ids], 1)
78    return sparse_tensor.SparseTensor(
79        indices=new_indices, values=sp_x.values, dense_shape=shape)
80
81  # TODO(walidk): Add an option to shuffle inputs.
82  def input_fn(self, np_matrix, batch_size, mode,
83               project_row=None, projection_weights=None,
84               remove_empty_rows_columns=False):
85    """Returns an input_fn that selects row and col batches from np_matrix.
86
87    This simple utility creates an input function from a numpy_array. The
88    following transformations are performed:
89    * The empty rows and columns in np_matrix are removed (if
90      remove_empty_rows_columns is true)
91    * np_matrix is converted to a SparseTensor.
92    * The rows of the sparse matrix (and the rows of its transpose) are batched.
93    * A features dictionary is created, which contains the row / column batches.
94
95    In TRAIN mode, one only needs to specify the np_matrix and the batch_size.
96    In INFER and EVAL modes, one must also provide project_row, a boolean which
97    specifies whether we are projecting rows or columns.
98
99    Args:
100      np_matrix: A numpy array. The input matrix to use.
101      batch_size: Integer.
102      mode: Can be one of model_fn.ModeKeys.{TRAIN, INFER, EVAL}.
103      project_row: A boolean. Used in INFER and EVAL modes. Specifies whether
104        to project rows or columns.
105      projection_weights: A float numpy array. Used in INFER mode. Specifies
106        the weights to use in the projection (the weights are optional, and
107        default to 1.).
108      remove_empty_rows_columns: A boolean. When true, this will remove empty
109        rows and columns in the np_matrix. Note that this will result in
110        modifying the indices of the input matrix. The mapping from new indices
111        to old indices is returned in the form of two numpy arrays.
112
113    Returns:
114      A tuple consisting of:
115      _fn: A callable. Calling _fn returns a features dict.
116      nz_row_ids: A numpy array of the ids of non-empty rows, such that
117        nz_row_ids[i] is the old row index corresponding to new index i.
118      nz_col_ids: A numpy array of the ids of non-empty columns, such that
119        nz_col_ids[j] is the old column index corresponding to new index j.
120    """
121    if remove_empty_rows_columns:
122      np_matrix, nz_row_ids, nz_col_ids = (
123          factorization_ops_test_utils.remove_empty_rows_columns(np_matrix))
124    else:
125      nz_row_ids = np.arange(np.shape(np_matrix)[0])
126      nz_col_ids = np.arange(np.shape(np_matrix)[1])
127
128    def extract_features(row_batch, col_batch, num_rows, num_cols):
129      row_ids = row_batch[0]
130      col_ids = col_batch[0]
131      rows = self.remap_sparse_tensor_rows(
132          row_batch[1], row_ids, shape=[num_rows, num_cols])
133      cols = self.remap_sparse_tensor_rows(
134          col_batch[1], col_ids, shape=[num_cols, num_rows])
135      features = {
136          wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows,
137          wals_lib.WALSMatrixFactorization.INPUT_COLS: cols,
138      }
139      return features
140
141    def _fn():
142      num_rows = np.shape(np_matrix)[0]
143      num_cols = np.shape(np_matrix)[1]
144      row_ids = math_ops.range(num_rows, dtype=dtypes.int64)
145      col_ids = math_ops.range(num_cols, dtype=dtypes.int64)
146      sp_mat = self.np_array_to_sparse(np_matrix)
147      sp_mat_t = sparse_ops.sparse_transpose(sp_mat)
148      row_batch = input_lib.batch(
149          [row_ids, sp_mat],
150          batch_size=min(batch_size, num_rows),
151          capacity=10,
152          enqueue_many=True)
153      col_batch = input_lib.batch(
154          [col_ids, sp_mat_t],
155          batch_size=min(batch_size, num_cols),
156          capacity=10,
157          enqueue_many=True)
158
159      features = extract_features(row_batch, col_batch, num_rows, num_cols)
160
161      if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL:
162        self.assertTrue(
163            project_row is not None,
164            msg='project_row must be specified in INFER or EVAL mode.')
165        features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = (
166            constant_op.constant(project_row))
167
168      if mode == model_fn.ModeKeys.INFER and projection_weights is not None:
169        weights_batch = input_lib.batch(
170            projection_weights,
171            batch_size=batch_size,
172            capacity=10,
173            enqueue_many=True)
174        features[wals_lib.WALSMatrixFactorization.PROJECTION_WEIGHTS] = (
175            weights_batch)
176
177      labels = None
178      return features, labels
179
180    return _fn, nz_row_ids, nz_col_ids
181
  @property
  def input_matrix(self):
    """The matrix the tests train and evaluate on (INPUT_MATRIX here)."""
    return self.INPUT_MATRIX
185
186  @property
187  def row_steps(self):
188    return np.ceil(self._num_rows / self.batch_size)
189
190  @property
191  def col_steps(self):
192    return np.ceil(self._num_cols / self.batch_size)
193
  @property
  def batch_size(self):
    """Batch size passed to the training input functions of the tests."""
    return 5
197
  @property
  def use_cache(self):
    """Whether setUp enables the training caches of the model.

    The value is passed as both use_factors_weights_cache_for_training and
    use_gramian_cache_for_training.
    """
    return False
201
  @property
  def max_sweeps(self):
    """Value passed as max_sweeps when setUp builds the model (None here)."""
    return None
205
  def setUp(self):
    """Sets the model hyper-parameters, expected factors and model under test.

    The model is a WALSMatrixFactorization with 5 rows in 2 shards and 7
    columns in 3 shards, built from the fixed column initialization and the
    row/column weights defined below. max_sweeps and use_cache are read from
    the corresponding properties.
    """
    self._num_rows = 5
    self._num_cols = 7
    self._embedding_dimension = 3
    self._unobserved_weight = 0.1
    self._num_row_shards = 2
    self._num_col_shards = 3
    self._regularization_coeff = 0.01
    # Initial column factors: one list per column shard (3 + 2 + 2 columns),
    # each column factor having _embedding_dimension entries.
    self._col_init = [
        # Shard 0.
        [[-0.36444709, -0.39077035, -0.32528427],
         [1.19056475, 0.07231052, 2.11834812],
         [0.93468881, -0.71099287, 1.91826844]],
        # Shard 1.
        [[1.18160152, 1.52490723, -0.50015002],
         [1.82574749, -0.57515913, -1.32810032]],
        # Shard 2.
        [[-0.15515432, -0.84675711, 0.13097958],
         [-0.9246484, 0.69117504, 1.2036494]],
    ]
    # Row and column weights, one list per shard.
    self._row_weights = [[0.1, 0.2, 0.3], [0.4, 0.5]]
    self._col_weights = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]

    # Values of row and column factors after running one iteration of factor
    # updates. The numeric suffix is the shard index.
    self._row_factors_0 = [[0.097689, -0.219293, -0.020780],
                           [0.50842, 0.64626, 0.22364],
                           [0.401159, -0.046558, -0.192854]]
    self._row_factors_1 = [[1.20597, -0.48025, 0.35582],
                           [1.5564, 1.2528, 1.0528]]
    self._col_factors_0 = [[2.4725, -1.2950, -1.9980],
                           [0.44625, 1.50771, 1.27118],
                           [1.39801, -2.10134, 0.73572]]
    self._col_factors_1 = [[3.36509, -0.66595, -3.51208],
                           [0.57191, 1.59407, 1.33020]]
    self._col_factors_2 = [[3.3459, -1.3341, -3.3008],
                           [0.57366, 1.83729, 1.26798]]
    # The model under test.
    self._model = wals_lib.WALSMatrixFactorization(
        self._num_rows,
        self._num_cols,
        self._embedding_dimension,
        self._unobserved_weight,
        col_init=self._col_init,
        regularization_coeff=self._regularization_coeff,
        num_row_shards=self._num_row_shards,
        num_col_shards=self._num_col_shards,
        row_weights=self._row_weights,
        col_weights=self._col_weights,
        max_sweeps=self.max_sweeps,
        use_factors_weights_cache_for_training=self.use_cache,
        use_gramian_cache_for_training=self.use_cache)
257
258  def test_fit(self):
259    # Row sweep.
260    input_fn = self.input_fn(np_matrix=self.input_matrix,
261                             batch_size=self.batch_size,
262                             mode=model_fn.ModeKeys.TRAIN,
263                             remove_empty_rows_columns=True)[0]
264    self._model.fit(input_fn=input_fn, steps=self.row_steps)
265    row_factors = self._model.get_row_factors()
266    self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
267    self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)
268
269    # Col sweep.
270    # Running fit a second time will resume training from the checkpoint.
271    input_fn = self.input_fn(np_matrix=self.input_matrix,
272                             batch_size=self.batch_size,
273                             mode=model_fn.ModeKeys.TRAIN,
274                             remove_empty_rows_columns=True)[0]
275    self._model.fit(input_fn=input_fn, steps=self.col_steps)
276    col_factors = self._model.get_col_factors()
277    self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
278    self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
279    self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)
280
281  def test_predict(self):
282    input_fn = self.input_fn(np_matrix=self.input_matrix,
283                             batch_size=self.batch_size,
284                             mode=model_fn.ModeKeys.TRAIN,
285                             remove_empty_rows_columns=True,
286                            )[0]
287    # Project rows 1 and 4 from the input matrix.
288    proj_input_fn = self.input_fn(
289        np_matrix=self.INPUT_MATRIX[[1, 4], :],
290        batch_size=2,
291        mode=model_fn.ModeKeys.INFER,
292        project_row=True,
293        projection_weights=[[0.2, 0.5]])[0]
294
295    self._model.fit(input_fn=input_fn, steps=self.row_steps)
296    projections = self._model.get_projections(proj_input_fn)
297    projected_rows = list(itertools.islice(projections, 2))
298
299    self.assertAllClose(
300        projected_rows,
301        [self._row_factors_0[1], self._row_factors_1[1]],
302        atol=1e-3)
303
304    # Project columns 5, 3, 1 from the input matrix.
305    proj_input_fn = self.input_fn(
306        np_matrix=self.INPUT_MATRIX[:, [5, 3, 1]],
307        batch_size=3,
308        mode=model_fn.ModeKeys.INFER,
309        project_row=False,
310        projection_weights=[[0.6, 0.4, 0.2]])[0]
311
312    self._model.fit(input_fn=input_fn, steps=self.col_steps)
313    projections = self._model.get_projections(proj_input_fn)
314    projected_cols = list(itertools.islice(projections, 3))
315    self.assertAllClose(
316        projected_cols,
317        [self._col_factors_2[0], self._col_factors_1[0],
318         self._col_factors_0[1]],
319        atol=1e-3)
320
321  def test_eval(self):
322    # Do a row sweep then evaluate the model on row inputs.
323    # The evaluate function returns the loss of the projected rows, but since
324    # projection is idempotent, the eval loss must match the model loss.
325    input_fn = self.input_fn(np_matrix=self.input_matrix,
326                             batch_size=self.batch_size,
327                             mode=model_fn.ModeKeys.TRAIN,
328                             remove_empty_rows_columns=True,
329                            )[0]
330    self._model.fit(input_fn=input_fn, steps=self.row_steps)
331    eval_input_fn_row = self.input_fn(np_matrix=self.input_matrix,
332                                      batch_size=1,
333                                      mode=model_fn.ModeKeys.EVAL,
334                                      project_row=True,
335                                      remove_empty_rows_columns=True)[0]
336    loss = self._model.evaluate(
337        input_fn=eval_input_fn_row, steps=self._num_rows)['loss']
338
339    with self.cached_session():
340      true_loss = self.calculate_loss()
341
342    self.assertNear(
343        loss, true_loss, err=.001,
344        msg="""After row update, eval loss = {}, does not match the true
345        loss = {}.""".format(loss, true_loss))
346
347    # Do a col sweep then evaluate the model on col inputs.
348    self._model.fit(input_fn=input_fn, steps=self.col_steps)
349    eval_input_fn_col = self.input_fn(np_matrix=self.input_matrix,
350                                      batch_size=1,
351                                      mode=model_fn.ModeKeys.EVAL,
352                                      project_row=False,
353                                      remove_empty_rows_columns=True)[0]
354    loss = self._model.evaluate(
355        input_fn=eval_input_fn_col, steps=self._num_cols)['loss']
356
357    with self.cached_session():
358      true_loss = self.calculate_loss()
359
360    self.assertNear(
361        loss, true_loss, err=.001,
362        msg="""After col update, eval loss = {}, does not match the true
363        loss = {}.""".format(loss, true_loss))
364
365
class WALSMatrixFactorizationTestSweeps(WALSMatrixFactorizationTest):
  """Runs the inherited tests with a model whose max_sweeps is set to 2."""

  # We set the column steps to None so that we rely only on max_sweeps to stop
  # training.
  @property
  def col_steps(self):
    """No step limit is given for the column training."""
    return None

  @property
  def max_sweeps(self):
    """Value passed as max_sweeps to the model under test."""
    return 2
377
378
class WALSMatrixFactorizationTestCached(WALSMatrixFactorizationTest):
  """Runs the inherited tests with use_cache overridden to True."""

  @property
  def use_cache(self):
    """Overrides the value of use_cache read when the model is built."""
    return True
384
385
class WALSMatrixFactorizaiontTestPaddedInput(WALSMatrixFactorizationTest):
  """Runs the inherited tests on a padded version of INPUT_MATRIX."""

  # INPUT_MATRIX with one constant-padded row added before its first row and
  # one constant-padded column added before its first column; nothing is added
  # after the last row or column.
  PADDED_INPUT_MATRIX = np.pad(
      WALSMatrixFactorizationTest.INPUT_MATRIX,
      pad_width=((1, 0), (1, 0)),
      mode='constant')

  @property
  def input_matrix(self):
    """The padded matrix, used in place of INPUT_MATRIX."""
    return self.PADDED_INPUT_MATRIX
394
395
class WALSMatrixFactorizationUnsupportedTest(test.TestCase):
  """Checks that WALSMatrixFactorization rejects a distributed RunConfig."""

  def setUp(self):
    """Sets up nothing: the test builds everything it needs itself."""

  def testDistributedWALSUnsupported(self):
    """Building the model with a two-worker RunConfig raises ValueError."""
    cluster = {
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    }
    task = {
        'type': run_config_lib.TaskType.WORKER,
        'index': 1
    }
    serialized_tf_config = json.dumps({'cluster': cluster, 'task': task})
    # TF_CONFIG is only patched while the RunConfig is being created.
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': serialized_tf_config}):
      config = run_config.RunConfig()
    self.assertEqual(config.num_worker_replicas, 2)
    with self.assertRaises(ValueError):
      self._model = wals_lib.WALSMatrixFactorization(1, 1, 1, config=config)
418
419
class SweepHookTest(test.TestCase):
  """Tests for wals_lib._SweepHook."""

  def test_sweeps(self):
    """Checks the ops run by the hook over a row, a col and a row sweep."""
    is_row_sweep_var = variables.VariableV1(True)
    is_sweep_done_var = variables.VariableV1(False)
    # One boolean flag per op handed to the hook. Each op sets its flag to
    # True, which makes it possible to tell afterwards that the op was run.
    (init_done, row_prep_done, col_prep_done, row_train_done,
     col_train_done) = [variables.VariableV1(False) for _ in range(5)]

    def raise_flag(flag):
      return state_ops.assign(flag, True)

    init_op = raise_flag(init_done)
    row_prep_op = raise_flag(row_prep_done)
    col_prep_op = raise_flag(col_prep_done)
    row_train_op = raise_flag(row_train_done)
    col_train_op = raise_flag(col_train_done)
    train_op = control_flow_ops.no_op()
    # The switch op clears the "sweep done" marker and flips the sweep kind.
    clear_sweep_done = state_ops.assign(is_sweep_done_var, False)
    flip_sweep_kind = state_ops.assign(
        is_row_sweep_var, math_ops.logical_not(is_row_sweep_var))
    switch_op = control_flow_ops.group(clear_sweep_done, flip_sweep_kind)
    mark_sweep_done = state_ops.assign(is_sweep_done_var, True)

    with self.cached_session() as sess:
      sweep_hook = wals_lib._SweepHook(
          is_row_sweep_var, is_sweep_done_var, init_op, [row_prep_op],
          [col_prep_op], row_train_op, col_train_op, switch_op)
      mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
      sess.run([variables.global_variables_initializer()])

      def finish_sweep_then_train():
        # Marks the current sweep as done, then runs one more training step.
        mon_sess.run(mark_sweep_done)
        mon_sess.run(train_op)

      # Row sweep.
      mon_sess.run(train_op)
      expected_after_row_step = (
          (init_done, 'init op not run by the Sweephook'),
          (row_prep_done, 'row_prep_op not run by the SweepHook'),
          (row_train_done, 'row_train_op not run by the SweepHook'),
          (is_row_sweep_var,
           'Row sweep is not complete but is_row_sweep_var is False.'),
      )
      for flag, message in expected_after_row_step:
        self.assertTrue(sess.run(flag), msg=message)

      # Col sweep.
      finish_sweep_then_train()
      expected_after_col_step = (
          (col_prep_done, 'col_prep_op not run by the SweepHook'),
          (col_train_done, 'col_train_op not run by the SweepHook'),
      )
      for flag, message in expected_after_col_step:
        self.assertTrue(sess.run(flag), msg=message)
      self.assertFalse(
          sess.run(is_row_sweep_var),
          msg='Col sweep is not complete but is_row_sweep_var is True.')

      # Row sweep.
      finish_sweep_then_train()
      self.assertTrue(
          sess.run(is_row_sweep_var),
          msg='Col sweep is complete but is_row_sweep_var is False.')
483
484
class StopAtSweepHookTest(test.TestCase):
  """Tests for wals_lib._StopAtSweepHook."""

  def test_stop(self):
    """Checks that a stop is requested once 10 sweeps are completed."""
    stop_hook = wals_lib._StopAtSweepHook(last_sweep=10)
    # The counter of completed sweeps starts at 8 and is incremented by one
    # each time the training op runs.
    sweep_counter = variables.VariableV1(
        8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS)
    increment_sweeps = state_ops.assign_add(sweep_counter, 1)
    stop_hook.begin()

    with self.cached_session() as sess:
      sess.run([variables.global_variables_initializer()])
      mon_sess = monitored_session._HookedSession(sess, [stop_hook])

      # The counter goes from 8 to 9: no stop is expected yet.
      mon_sess.run(increment_sweeps)
      self.assertFalse(mon_sess.should_stop())

      # The counter reaches 10, the value of last_sweep: a stop is expected.
      mon_sess.run(increment_sweeps)
      self.assertTrue(mon_sess.should_stop())
503
504
# Runs the tests of this file when it is executed as a script.
if __name__ == '__main__':
  test.main()
507