1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Clustering Operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import random_seed as random_seed_ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import gen_clustering_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import nn_impl
31from tensorflow.python.ops import random_ops
32from tensorflow.python.ops import state_ops
33from tensorflow.python.ops import variable_scope
34from tensorflow.python.ops.embedding_ops import embedding_lookup
35# go/tf-wildcard-import
36# pylint: disable=wildcard-import
37from tensorflow.python.ops.gen_clustering_ops import *
38# pylint: enable=wildcard-import
39
# Supported values for the `distance_metric` argument of KMeans.
#
# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\),
# the square root of the sum of the squared element-wise differences. The
# metric selected by this constant is its square, \\(||U - V||_F^2\\), i.e. no
# square root is taken.
SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
# Cosine distance between vectors U and V is defined as
# \\(1 - (U \cdot V) / (||U||_F ||V||_F)\\).
COSINE_DISTANCE = 'cosine'

# Supported string values for the `initial_clusters` argument of KMeans:
# random selection from the inputs, kmeans++ and k-MC2 respectively.
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'
54
55
56class KMeans(object):
57  """Creates the graph for k-means clustering."""
58
59  def __init__(self,
60               inputs,
61               num_clusters,
62               initial_clusters=RANDOM_INIT,
63               distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
64               use_mini_batch=False,
65               mini_batch_steps_per_iteration=1,
66               random_seed=0,
67               kmeans_plus_plus_num_retries=2,
68               kmc2_chain_length=200):
69    """Creates an object for generating KMeans clustering graph.
70
71    This class implements the following variants of K-means algorithm:
72
73    If use_mini_batch is False, it runs standard full batch K-means. Each step
74    runs a single iteration of K-Means. This step can be run sharded across
75    multiple workers by passing a list of sharded inputs to this class. Note
76    however that a single step needs to process the full input at once.
77
78    If use_mini_batch is True, it runs a generalization of the mini-batch
79    K-means algorithm. It runs multiple iterations, where each iteration is
80    composed of mini_batch_steps_per_iteration steps. Two copies of cluster
81    centers are maintained: one that is updated at the end of each iteration,
82    and one that is updated every step. The first copy is used to compute
83    cluster allocations for each step, and for inference, while the second copy
84    is the one updated each step using the mini-batch update rule. After each
85    iteration is complete, this second copy is copied back the first copy.
86
87    Note that for use_mini_batch=True, when mini_batch_steps_per_iteration=1,
88    the algorithm reduces to the standard mini-batch algorithm. Also by setting
89    mini_batch_steps_per_iteration = num_inputs / batch_size, the algorithm
90    becomes an asynchronous version of the full-batch algorithm. Note however
91    that there is no guarantee by this implementation that each input is seen
92    exactly once per iteration. Also, different updates are applied
93    asynchronously without locking. So this asynchronous version may not behave
94    exactly like a full-batch version.
95
96    Args:
97      inputs: An input tensor or list of input tensors. It is assumed that the
98        data points have been previously randomly permuted.
99      num_clusters: An integer tensor specifying the number of clusters. This
100        argument is ignored if initial_clusters is a tensor or numpy array.
101      initial_clusters: Specifies the clusters used during initialization. One
102        of the following:
103        - a tensor or numpy array with the initial cluster centers.
104        - a function f(inputs, k) that returns up to k centers from `inputs`.
105        - "random": Choose centers randomly from `inputs`.
106        - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
107        - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
108        In the last three cases, one batch of `inputs` may not yield
109        `num_clusters` centers, in which case initialization will require
110        multiple batches until enough centers are chosen. In the case of
111        "random" or "kmeans_plus_plus", if the input size is <= `num_clusters`
112        then the entire batch is chosen to be cluster centers.
113      distance_metric: Distance metric used for clustering. Supported options:
114        "squared_euclidean", "cosine".
115      use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
116        full batch.
117      mini_batch_steps_per_iteration: Number of steps after which the updated
118        cluster centers are synced back to a master copy.
119      random_seed: Seed for PRNG used to initialize seeds.
120      kmeans_plus_plus_num_retries: For each point that is sampled during
121        kmeans++ initialization, this parameter specifies the number of
122        additional points to draw from the current distribution before selecting
123        the best. If a negative value is specified, a heuristic is used to
124        sample O(log(num_to_sample)) additional points.
125      kmc2_chain_length: Determines how many candidate points are used by the
126        k-MC2 algorithm to produce one new cluster centers. If a (mini-)batch
127        contains less points, one new cluster center is generated from the
128        (mini-)batch.
129
130    Raises:
131      ValueError: An invalid argument was passed to initial_clusters or
132        distance_metric.
133    """
134    if isinstance(initial_clusters, str) and initial_clusters not in [
135        RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
136    ]:
137      raise ValueError(
138          "Unsupported initialization algorithm '%s'" % initial_clusters)
139    if distance_metric not in [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE]:
140      raise ValueError("Unsupported distance metric '%s'" % distance_metric)
141    self._inputs = inputs if isinstance(inputs, list) else [inputs]
142    self._num_clusters = num_clusters
143    self._initial_clusters = initial_clusters
144    self._distance_metric = distance_metric
145    self._use_mini_batch = use_mini_batch
146    self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
147    self._seed = random_seed_ops.get_seed(random_seed)[0]
148    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
149    self._kmc2_chain_length = kmc2_chain_length
150
151  @classmethod
152  def _distance_graph(cls, inputs, clusters, distance_metric):
153    """Computes distance between each input and each cluster center.
154
155    Args:
156      inputs: list of input Tensors.
157      clusters: cluster Tensor.
158      distance_metric: distance metric used for clustering
159
160    Returns:
161      list of Tensors, where each element corresponds to each element in inputs.
162      The value is the distance of each row to all the cluster centers.
163      Currently only Euclidean distance and cosine distance are supported.
164    """
165    assert isinstance(inputs, list)
166    if distance_metric == SQUARED_EUCLIDEAN_DISTANCE:
167      return cls._compute_euclidean_distance(inputs, clusters)
168    elif distance_metric == COSINE_DISTANCE:
169      return cls._compute_cosine_distance(
170          inputs, clusters, inputs_normalized=True)
171    else:
172      assert False, str(distance_metric)
173
174  @classmethod
175  def _compute_euclidean_distance(cls, inputs, clusters):
176    """Computes Euclidean distance between each input and each cluster center.
177
178    Args:
179      inputs: list of input Tensors.
180      clusters: cluster Tensor.
181
182    Returns:
183      list of Tensors, where each element corresponds to each element in inputs.
184      The value is the distance of each row to all the cluster centers.
185    """
186    output = []
187    for inp in inputs:
188      with ops.colocate_with(inp, ignore_existing=True):
189        # Computes Euclidean distance. Note the first and third terms are
190        # broadcast additions.
191        squared_distance = (
192            math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
193            2 * math_ops.matmul(inp, clusters, transpose_b=True) +
194            array_ops.transpose(
195                math_ops.reduce_sum(
196                    math_ops.square(clusters), 1, keepdims=True)))
197        output.append(squared_distance)
198
199    return output
200
201  @classmethod
202  def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True):
203    """Computes cosine distance between each input and each cluster center.
204
205    Args:
206      inputs: list of input Tensor.
207      clusters: cluster Tensor
208      inputs_normalized: if True, it assumes that inp and clusters are
209      normalized and computes the dot product which is equivalent to the cosine
210      distance. Else it L2 normalizes the inputs first.
211
212    Returns:
213      list of Tensors, where each element corresponds to each element in inp.
214      The value is the distance of each row to all the cluster centers.
215    """
216    output = []
217    if not inputs_normalized:
218      with ops.colocate_with(clusters, ignore_existing=True):
219        clusters = nn_impl.l2_normalize(clusters, dim=1)
220    for inp in inputs:
221      with ops.colocate_with(inp, ignore_existing=True):
222        if not inputs_normalized:
223          inp = nn_impl.l2_normalize(inp, dim=1)
224        output.append(1 - math_ops.matmul(inp, clusters, transpose_b=True))
225    return output
226
227  def _infer_graph(self, inputs, clusters):
228    """Maps input to closest cluster and the score.
229
230    Args:
231      inputs: list of input Tensors.
232      clusters: Tensor of cluster centers.
233
234    Returns:
235      List of tuple, where each value in tuple corresponds to a value in inp.
236      The tuple has following three elements:
237      all_scores: distance of each input to each cluster center.
238      score: distance of each input to closest cluster center.
239      cluster_idx: index of cluster center closest to the corresponding input.
240    """
241    assert isinstance(inputs, list)
242    # Pairwise distances are used only by transform(). In all other cases, this
243    # sub-graph is not evaluated.
244    scores = self._distance_graph(inputs, clusters, self._distance_metric)
245    output = []
246    if (self._distance_metric == COSINE_DISTANCE and
247        not self._clusters_l2_normalized()):
248      # The cosine distance between normalized vectors x and y is the same as
249      # 2 * squared_euclidean_distance. We are using this fact and reusing the
250      # nearest_neighbors op.
251      # TODO(ands): Support COSINE distance in nearest_neighbors and remove
252      # this.
253      with ops.colocate_with(clusters, ignore_existing=True):
254        clusters = nn_impl.l2_normalize(clusters, dim=1)
255    for inp, score in zip(inputs, scores):
256      with ops.colocate_with(inp, ignore_existing=True):
257        (indices, distances) = gen_clustering_ops.nearest_neighbors(
258            inp, clusters, 1)
259        if self._distance_metric == COSINE_DISTANCE:
260          distances *= 0.5
261        output.append((score, array_ops.squeeze(distances, [-1]),
262                       array_ops.squeeze(indices, [-1])))
263    return zip(*output)
264
265  def _clusters_l2_normalized(self):
266    """Returns True if clusters centers are kept normalized."""
267    return (self._distance_metric == COSINE_DISTANCE and
268            (not self._use_mini_batch or
269             self._mini_batch_steps_per_iteration > 1))
270
271  def _create_variables(self, num_clusters):
272    """Creates variables.
273
274    Args:
275      num_clusters: an integer Tensor providing the number of clusters.
276
277    Returns:
278      Tuple with following elements:
279      - cluster_centers: a Tensor for storing cluster centers
280      - cluster_centers_initialized: bool Variable indicating whether clusters
281            are initialized.
282      - cluster_counts: a Tensor for storing counts of points assigned to this
283            cluster. This is used by mini-batch training.
284      - cluster_centers_updated: Tensor representing copy of cluster centers
285            that are updated every step.
286      - update_in_steps: numbers of steps left before we sync
287            cluster_centers_updated back to cluster_centers.
288    """
289    init_value = array_ops.placeholder_with_default([], shape=None)
290    cluster_centers = variable_scope.variable(
291        init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
292    cluster_centers_initialized = variable_scope.variable(
293        False, dtype=dtypes.bool, name='initialized')
294
295    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
296      # Copy of cluster centers actively updated each step according to
297      # mini-batch update rule.
298      cluster_centers_updated = variable_scope.variable(
299          init_value, name='clusters_updated', validate_shape=False)
300      # How many steps till we copy the updated clusters to cluster_centers.
301      update_in_steps = variable_scope.variable(
302          self._mini_batch_steps_per_iteration,
303          dtype=dtypes.int64,
304          name='update_in_steps')
305      # Count of points assigned to cluster_centers_updated.
306      cluster_counts = variable_scope.variable(
307          array_ops.zeros([num_clusters], dtype=dtypes.int64))
308    else:
309      cluster_centers_updated = cluster_centers
310      update_in_steps = None
311      cluster_counts = (
312          variable_scope.variable(
313              array_ops.ones([num_clusters], dtype=dtypes.int64))
314          if self._use_mini_batch else None)
315    return (cluster_centers, cluster_centers_initialized, cluster_counts,
316            cluster_centers_updated, update_in_steps)
317
318  @classmethod
319  def _l2_normalize_data(cls, inputs):
320    """Normalized the input data."""
321    output = []
322    for inp in inputs:
323      with ops.colocate_with(inp, ignore_existing=True):
324        output.append(nn_impl.l2_normalize(inp, dim=1))
325    return output
326
327  def training_graph(self):
328    """Generate a training graph for kmeans algorithm.
329
330    This returns, among other things, an op that chooses initial centers
331    (init_op), a boolean variable that is set to True when the initial centers
332    are chosen (cluster_centers_initialized), and an op to perform either an
333    entire Lloyd iteration or a mini-batch of a Lloyd iteration (training_op).
334    The caller should use these components as follows. A single worker should
335    execute init_op multiple times until cluster_centers_initialized becomes
336    True. Then multiple workers may execute training_op any number of times.
337
338    Returns:
339      A tuple consisting of:
340      all_scores: A matrix (or list of matrices) of dimensions (num_input,
341        num_clusters) where the value is the distance of an input vector and a
342        cluster center.
343      cluster_idx: A vector (or list of vectors). Each element in the vector
344        corresponds to an input row in 'inp' and specifies the cluster id
345        corresponding to the input.
346      scores: Similar to cluster_idx but specifies the distance to the
347        assigned cluster instead.
348      cluster_centers_initialized: scalar indicating whether clusters have been
349        initialized.
350      init_op: an op to initialize the clusters.
351      training_op: an op that runs an iteration of training.
352    """
353    # Implementation of kmeans.
354    if (isinstance(self._initial_clusters, str) or
355        callable(self._initial_clusters)):
356      initial_clusters = self._initial_clusters
357      num_clusters = ops.convert_to_tensor(self._num_clusters)
358    else:
359      initial_clusters = ops.convert_to_tensor(self._initial_clusters)
360      num_clusters = array_ops.shape(initial_clusters)[0]
361
362    inputs = self._inputs
363    (cluster_centers_var, cluster_centers_initialized, total_counts,
364     cluster_centers_updated,
365     update_in_steps) = self._create_variables(num_clusters)
366    init_op = _InitializeClustersOpFactory(
367        self._inputs, num_clusters, initial_clusters, self._distance_metric,
368        self._seed, self._kmeans_plus_plus_num_retries, self._kmc2_chain_length,
369        cluster_centers_var, cluster_centers_updated,
370        cluster_centers_initialized).op()
371    cluster_centers = cluster_centers_var
372
373    if self._distance_metric == COSINE_DISTANCE:
374      inputs = self._l2_normalize_data(inputs)
375      if not self._clusters_l2_normalized():
376        cluster_centers = nn_impl.l2_normalize(cluster_centers, dim=1)
377
378    all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
379    if self._use_mini_batch:
380      sync_updates_op = self._mini_batch_sync_updates_op(
381          update_in_steps, cluster_centers_var, cluster_centers_updated,
382          total_counts)
383      assert sync_updates_op is not None
384      with ops.control_dependencies([sync_updates_op]):
385        training_op = self._mini_batch_training_op(
386            inputs, cluster_idx, cluster_centers_updated, total_counts)
387    else:
388      assert cluster_centers == cluster_centers_var
389      training_op = self._full_batch_training_op(
390          inputs, num_clusters, cluster_idx, cluster_centers_var)
391
392    return (all_scores, cluster_idx, scores, cluster_centers_initialized,
393            init_op, training_op)
394
  def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
                                  cluster_centers_updated, total_counts):
    """Creates the op that periodically syncs the updated centers back.

    Only active for mini-batch training with
    mini_batch_steps_per_iteration > 1. Each execution either decrements
    `update_in_steps`, or, once it has reached zero or below, resets it and
    copies `cluster_centers_updated` into `cluster_centers_var` (L2-normalized
    first when using cosine distance) and zeroes `total_counts`.

    Args:
      update_in_steps: variable counting the steps left until the next sync.
        Must not be None when syncing is active.
      cluster_centers_var: variable holding the master copy of the centers.
      cluster_centers_updated: variable holding the centers updated each step.
      total_counts: variable holding the per-cluster counts.

    Returns:
      The output of the conditional described above, or a no-op when syncing
      is not active.
    """
    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      assert update_in_steps is not None
      with ops.colocate_with(update_in_steps, ignore_existing=True):

        def _f():
          """Resets the counter, syncs the centers, then zeroes the counts."""
          # Note that there is a race condition here, so we do a best effort
          # updates here. We reset update_in_steps first so that other workers
          # don't duplicate the updates. Also we update cluster_centers_var
          # before resetting total_counts to avoid large updates to
          # cluster_centers_updated based on partially updated
          # cluster_centers_var.
          with ops.control_dependencies([
              state_ops.assign(update_in_steps,
                               self._mini_batch_steps_per_iteration - 1)
          ]):
            with ops.colocate_with(
                cluster_centers_updated, ignore_existing=True):
              if self._distance_metric == COSINE_DISTANCE:
                cluster_centers = nn_impl.l2_normalize(
                    cluster_centers_updated, dim=1)
              else:
                cluster_centers = cluster_centers_updated
            with ops.colocate_with(cluster_centers_var, ignore_existing=True):
              with ops.control_dependencies(
                  [state_ops.assign(cluster_centers_var, cluster_centers)]):
                with ops.colocate_with(None, ignore_existing=True):
                  with ops.control_dependencies([
                      state_ops.assign(total_counts,
                                       array_ops.zeros_like(total_counts))
                  ]):
                    return array_ops.identity(update_in_steps)

        # Sync when the counter has run out; otherwise just count down.
        return control_flow_ops.cond(
            update_in_steps <= 0, _f,
            lambda: state_ops.assign_sub(update_in_steps, 1))
    else:
      return control_flow_ops.no_op()
434
  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                              total_counts):
    """Creates an op for training for mini batch case.

    For every input tensor the rows are aggregated per assigned cluster, and
    the resulting increments are scatter-added to `total_counts` and
    `cluster_centers`.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: list of vectors, one per input Tensor. Each element in
        a vector corresponds to an input row in 'inp' and specifies the cluster
        id corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.
      total_counts: Tensor Ref of cluster counts. Must not be None.

    Returns:
      An op for doing an update of mini-batch k-means; it groups the count and
      center updates of all inputs.
    """
    update_ops = []
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        assert total_counts is not None
        cluster_idx = array_ops.reshape(cluster_idx, [-1])
        # Dedupe the unique ids of cluster_centers being updated so that updates
        # can be locally aggregated.
        unique_ids, unique_idx = array_ops.unique(cluster_idx)
        num_unique_cluster_idx = array_ops.size(unique_ids)
        # Fetch the old values of counts and cluster_centers.
        with ops.colocate_with(total_counts, ignore_existing=True):
          old_counts = array_ops.gather(total_counts, unique_ids)
        # TODO(agarwal): This colocation seems to run into problems. Fix it.
        with ops.colocate_with(cluster_centers, ignore_existing=True):
          old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
        # Locally aggregate the increment to counts: the number of rows in this
        # batch assigned to each touched cluster (k in the formula below).
        count_updates = math_ops.unsorted_segment_sum(
            array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
            unique_idx, num_unique_cluster_idx)
        # Locally compute the sum of inputs mapped to each id.
        # For a cluster with old cluster value x, old count n, and with data
        # d_1,...d_k newly assigned to it, we recompute the new value as
        # \\(x += (sum_i(d_i) - k * x) / (n + k)\\).
        # Compute \\(sum_i(d_i)\\), see comment above.
        cluster_center_updates = math_ops.unsorted_segment_sum(
            inp, unique_idx, num_unique_cluster_idx)
        # Shape to enable broadcasting count_updates and learning_rate to inp.
        # It extends the shape with 1's to match the rank of inp.
        broadcast_shape = array_ops.concat([
            array_ops.reshape(num_unique_cluster_idx, [1]),
            array_ops.ones(
                array_ops.reshape(array_ops.rank(inp) - 1, [1]),
                dtype=dtypes.int32)
        ], 0)
        # Subtract k * x, see comment above.
        cluster_center_updates -= math_ops.cast(
            array_ops.reshape(count_updates, broadcast_shape),
            inp.dtype) * old_cluster_centers
        learning_rate = math_ops.reciprocal(
            math_ops.cast(old_counts + count_updates, inp.dtype))
        learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
        # scale by 1 / (n + k), see comment above.
        cluster_center_updates *= learning_rate
      # Apply the updates. These scatter ops are created outside the
      # colocation scope of `inp` above.
      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                            count_updates)
      update_cluster_centers = state_ops.scatter_add(
          cluster_centers, unique_ids, cluster_center_updates)
      update_ops.extend([update_counts, update_cluster_centers])
    return control_flow_ops.group(*update_ops)
500
501  def _full_batch_training_op(self, inputs, num_clusters, cluster_idx_list,
502                              cluster_centers):
503    """Creates an op for training for full batch case.
504
505    Args:
506      inputs: list of input Tensors.
507      num_clusters: an integer Tensor providing the number of clusters.
508      cluster_idx_list: A vector (or list of vectors). Each element in the
509        vector corresponds to an input row in 'inp' and specifies the cluster id
510        corresponding to the input.
511      cluster_centers: Tensor Ref of cluster centers.
512
513    Returns:
514      An op for doing an update of mini-batch k-means.
515    """
516    cluster_sums = []
517    cluster_counts = []
518    epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
519    for inp, cluster_idx in zip(inputs, cluster_idx_list):
520      with ops.colocate_with(inp, ignore_existing=True):
521        cluster_sums.append(
522            math_ops.unsorted_segment_sum(inp, cluster_idx, num_clusters))
523        cluster_counts.append(
524            math_ops.unsorted_segment_sum(
525                array_ops.reshape(
526                    array_ops.ones(
527                        array_ops.reshape(array_ops.shape(inp)[0], [-1])),
528                    [-1, 1]), cluster_idx, num_clusters))
529    with ops.colocate_with(cluster_centers, ignore_existing=True):
530      new_clusters_centers = math_ops.add_n(cluster_sums) / (
531          math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
532          epsilon)
533      if self._clusters_l2_normalized():
534        new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
535    return state_ops.assign(cluster_centers, new_clusters_centers)
536
537
538class _InitializeClustersOpFactory(object):
539  """Internal class to create the op to initialize the clusters.
540
541    The op performs this algorithm (see constructor args):
542
543    num_remaining = num_clusters - length(cluster_centers)
544    if num_remaining == 0:
545      assert that cluster_centers_initialized is true
546    else:
547      assert that num_remaining > 0
548      new_centers = choose up to num_remaining initial centers
549      l2-normalize new_centers if using cosine distance
550      all_centers = concat(cluster_centers, new_centers)
551      cluster_centers := all_centers
552      if there is a cluster_centers_updated variable:
553        cluster_centers_updated := cluster_centers
554      num_now_remaining = num_clusters - length(cluster_centers)
555      if num_now_remaining == 0:
556        cluster_centers_initialized := true
557  """
558
559  # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.
560
561  def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
562               random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
563               cluster_centers, cluster_centers_updated,
564               cluster_centers_initialized):
565    """Creates an op factory.
566
567    Args:
568      inputs: See KMeans constructor.
569      num_clusters: An integer Tensor providing the number of clusters.
570      initial_clusters: See KMeans constructor.
571      distance_metric: See KMeans constructor.
572      random_seed: See KMeans constructor.
573      kmeans_plus_plus_num_retries: See KMeans constructor.
574      kmc2_chain_length: See KMeans constructor.
575      cluster_centers: The TF variable holding the initial centers. It may
576          already contain some centers when the op is executed.
577      cluster_centers_updated: A second TF variable to hold a copy of the
578          initial centers, used for full-batch mode. In mini-batch mode,
579          cluster_centers_updated is the same variable as cluster_centers.
580      cluster_centers_initialized: A boolean TF variable that will be set
581          to true when all the initial centers have been chosen.
582    """
583    # All of these instance variables are constants.
584    self._inputs = inputs
585    self._num_clusters = num_clusters
586    self._initial_clusters = initial_clusters
587    self._distance_metric = distance_metric
588    self._seed = random_seed
589    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
590    self._kmc2_chain_length = kmc2_chain_length
591    self._cluster_centers = cluster_centers
592    self._cluster_centers_updated = cluster_centers_updated
593    self._cluster_centers_initialized = cluster_centers_initialized
594
595    self._num_selected = array_ops.shape(self._cluster_centers)[0]
596    self._num_remaining = self._num_clusters - self._num_selected
597    self._num_data = math_ops.add_n(
598        [array_ops.shape(i)[0] for i in self._inputs])
599
600  def _random(self):
601    indices = random_ops.random_uniform(
602        array_ops.reshape(self._num_remaining, [-1]),
603        minval=0,
604        maxval=math_ops.cast(self._num_data, dtypes.int64),
605        seed=self._seed,
606        dtype=dtypes.int64)
607    return embedding_lookup(self._inputs, indices, partition_strategy='div')
608
609  def _kmeans_plus_plus(self):
610    # Points from only the first shard are used for initializing centers.
611    # TODO(ands): Use all points.
612    inp = self._inputs[0]
613    if self._distance_metric == COSINE_DISTANCE:
614      inp = nn_impl.l2_normalize(inp, dim=1)
615    return gen_clustering_ops.kmeans_plus_plus_initialization(
616        inp, math_ops.cast(self._num_remaining, dtypes.int64), self._seed,
617        self._kmeans_plus_plus_num_retries)
618
619  def _kmc2_multiple_centers(self):
620    """Adds new initial cluster centers using the k-MC2 algorithm.
621
622    In each call to the op, the provided batch is split into subsets based on
623    the specified `kmc2_chain_length`. On each subset, a single Markov chain of
624    the k-MC2 algorithm is used to add *one* new center cluster center. If there
625    are less than `kmc2_chain_length` points in the subset, a single center is
626    added using one Markov chain on the full input. It is assumed that the
627    provided batch has previously been randomly permuted. Otherwise, k-MC2 may
628    return suboptimal centers.
629
630    Returns:
631      An op that adds new cluster centers.
632    """
633    # The op only operates on the first shard of data.
634    first_shard = self._inputs[0]
635    # Number of points in the input that can be used.
636    batch_size = array_ops.shape(first_shard)[0]
637    # Maximum number of subsets such that the size of each subset is at least
638    # `kmc2_chain_length`. Final subsets may be larger.
639    max_to_sample = math_ops.cast(
640        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
641    # We sample at least one new center and at most all remaining centers.
642    num_to_sample = math_ops.maximum(
643        math_ops.minimum(self._num_remaining, max_to_sample), 1)
644
645    def _cond(i, _):
646      """Stopping condition for the while loop."""
647      return math_ops.less(i, num_to_sample)
648
649    def _body(i, _):
650      """Body that adds a single new center based on a subset."""
651
652      def _sample_random():
653        """Returns a random point as a cluster center."""
654        # By assumption the batch is reshuffled and _sample_random is always
655        # called for i=0. Hence, we simply return the first point.
656        new_center = array_ops.reshape(first_shard[0], [1, -1])
657        if self._distance_metric == COSINE_DISTANCE:
658          new_center = nn_impl.l2_normalize(new_center, dim=1)
659        return new_center
660
661      def _sample_kmc2_chain():
662        """Returns previous centers as well as a new center sampled using k-MC2.
663        """
664        # Extract the subset from the underlying batch.
665        start = i * self._kmc2_chain_length
666        end = start + self._kmc2_chain_length
667        subset = first_shard[start:end]
668        # Compute the distances from points in the subset to previous centers.
669        _, distances = gen_clustering_ops.nearest_neighbors(
670            subset, self._cluster_centers, 1)
671        # Sample index of new center using k-MC2 Markov chain.
672        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
673            array_ops.squeeze(distances), self._seed)
674        # Extract actual new center.
675        newly_sampled_center = array_ops.reshape(subset[new_center_index],
676                                                 [1, -1])
677        # Return concatenation with previously sampled centers.
678        if self._distance_metric == COSINE_DISTANCE:
679          newly_sampled_center = nn_impl.l2_normalize(
680              newly_sampled_center, dim=1)
681        return array_ops.concat([self._cluster_centers, newly_sampled_center],
682                                0)
683
684      # Obtain a random point if there are no previously sampled centers.
685      # Otherwise, construct a k-MC2 Markov chain.
686      new_centers = control_flow_ops.cond(
687          math_ops.equal(self._num_selected, 0), _sample_random,
688          _sample_kmc2_chain)
689      # Assign new cluster centers to underlying variable.
690      assigned_centers = state_ops.assign(
691          self._cluster_centers, new_centers, validate_shape=False)
692      if self._cluster_centers_updated is not self._cluster_centers:
693        assigned_centers = state_ops.assign(
694            self._cluster_centers_updated,
695            assigned_centers,
696            validate_shape=False)
697      return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]
698
699    # Add num_to_sample new data points.
700    _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
701    return num_remaining
702
703  def _greedy_batch_sampler(self, sampler):
704    # If the input dataset size is smaller than the number of centers
705    # remaining, choose the entire input dataset as centers. This can happen
706    # with mini-batch. Otherwise, sample the batch according to the provided
707    # sampler.
708    return control_flow_ops.cond(self._num_data <= self._num_remaining,
709                                 lambda: array_ops.concat(self._inputs, 0),
710                                 sampler)
711
712  def _single_batch_sampler(self, sampler):
713    # Enforce that there are at least as many data points as centers
714    # remaining. This gives the provided sampler the chance to select all
715    # remaining centers from a single batch.
716    with ops.control_dependencies(
717        [check_ops.assert_greater_equal(self._num_data, self._num_remaining)]):
718      return sampler()
719
720  def _choose_initial_centers(self):
721    if isinstance(self._initial_clusters, str):
722      if self._initial_clusters == RANDOM_INIT:
723        return self._greedy_batch_sampler(self._random)
724      else:  # self._initial_clusters == KMEANS_PLUS_PLUS_INIT
725        return self._single_batch_sampler(self._kmeans_plus_plus)
726    elif callable(self._initial_clusters):
727      return self._initial_clusters(self._inputs, self._num_remaining)
728    else:
729      with ops.control_dependencies([
730          check_ops.assert_equal(self._num_remaining,
731                                 array_ops.shape(self._initial_clusters)[0])
732      ]):
733        return self._initial_clusters
734
735  def _add_new_centers(self):
736    """Adds some centers and returns the number of centers remaining."""
737    new_centers = self._choose_initial_centers()
738    if self._distance_metric == COSINE_DISTANCE:
739      new_centers = nn_impl.l2_normalize(new_centers, dim=1)
740    # If cluster_centers is empty, it doesn't have the right shape for concat.
741    all_centers = control_flow_ops.cond(
742        math_ops.equal(self._num_selected, 0), lambda: new_centers,
743        lambda: array_ops.concat([self._cluster_centers, new_centers], 0))
744    # TODO(ccolby): De-dupe all_centers?
745    a = state_ops.assign(
746        self._cluster_centers, all_centers, validate_shape=False)
747    if self._cluster_centers_updated is not self._cluster_centers:
748      a = state_ops.assign(
749          self._cluster_centers_updated, a, validate_shape=False)
750    return self._num_clusters - array_ops.shape(a)[0]
751
752  def _initialize(self):
753    with ops.control_dependencies([
754        check_ops.assert_positive(self._num_remaining),
755    ]):
756      if self._initial_clusters == KMC2_INIT:
757        num_now_remaining = self._kmc2_multiple_centers()
758      else:
759        num_now_remaining = self._add_new_centers()
760      return control_flow_ops.cond(
761          math_ops.equal(num_now_remaining, 0),
762          lambda: state_ops.assign(self._cluster_centers_initialized, True),
763          control_flow_ops.no_op)
764
765  def op(self):
766    """Returns the cluster initializer op."""
767    return control_flow_ops.cond(
768        math_ops.equal(self._num_remaining, 0),
769        lambda: check_ops.assert_equal(self._cluster_centers_initialized, True),
770        self._initialize)
771