1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Clustering Operations.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.framework import constant_op 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import random_seed as random_seed_ops 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import check_ops 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops import gen_clustering_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops import nn_impl 31from tensorflow.python.ops import random_ops 32from tensorflow.python.ops import state_ops 33from tensorflow.python.ops import variable_scope 34from tensorflow.python.ops.embedding_ops import embedding_lookup 35# go/tf-wildcard-import 36# pylint: disable=wildcard-import 37from tensorflow.python.ops.gen_clustering_ops import * 38# pylint: enable=wildcard-import 39 40# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\) 41# which is the square root of the sum of the absolute squares of the elements 42# difference. 
# Squared Euclidean distance between vectors U and V is defined as
# \\(||U - V||_F^2\\), i.e. the sum of the squared element-wise differences
# (the square of the Euclidean distance \\(||U - V||_F\\)).
SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
# Cosine distance between vectors U and V is defined as
# \\(1 - (U \dot V) / (||U||_F ||V||_F)\\)
COSINE_DISTANCE = 'cosine'

# Supported string values for the `initial_clusters` argument of KMeans.
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'


class KMeans(object):
  """Creates the graph for k-means clustering."""

  def __init__(self,
               inputs,
               num_clusters,
               initial_clusters=RANDOM_INIT,
               distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
               use_mini_batch=False,
               mini_batch_steps_per_iteration=1,
               random_seed=0,
               kmeans_plus_plus_num_retries=2,
               kmc2_chain_length=200):
    """Creates an object for generating KMeans clustering graph.

    This class implements the following variants of K-means algorithm:

    If use_mini_batch is False, it runs standard full batch K-means. Each step
    runs a single iteration of K-Means. This step can be run sharded across
    multiple workers by passing a list of sharded inputs to this class. Note
    however that a single step needs to process the full input at once.

    If use_mini_batch is True, it runs a generalization of the mini-batch
    K-means algorithm. It runs multiple iterations, where each iteration is
    composed of mini_batch_steps_per_iteration steps. Two copies of cluster
    centers are maintained: one that is updated at the end of each iteration,
    and one that is updated every step. The first copy is used to compute
    cluster allocations for each step, and for inference, while the second copy
    is the one updated each step using the mini-batch update rule. After each
    iteration is complete, this second copy is copied back to the first copy.

    Note that for use_mini_batch=True, when mini_batch_steps_per_iteration=1,
    the algorithm reduces to the standard mini-batch algorithm. Also by setting
    mini_batch_steps_per_iteration = num_inputs / batch_size, the algorithm
    becomes an asynchronous version of the full-batch algorithm. Note however
    that there is no guarantee by this implementation that each input is seen
    exactly once per iteration. Also, different updates are applied
    asynchronously without locking. So this asynchronous version may not behave
    exactly like a full-batch version.

    Args:
      inputs: An input tensor or list of input tensors. It is assumed that the
        data points have been previously randomly permuted.
      num_clusters: An integer tensor specifying the number of clusters. This
        argument is ignored if initial_clusters is a tensor or numpy array.
      initial_clusters: Specifies the clusters used during initialization. One
        of the following:
        - a tensor or numpy array with the initial cluster centers.
        - a function f(inputs, k) that returns up to k centers from `inputs`.
        - "random": Choose centers randomly from `inputs`.
        - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
        - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
        In the last three cases, one batch of `inputs` may not yield
        `num_clusters` centers, in which case initialization will require
        multiple batches until enough centers are chosen. In the case of
        "random" or "kmeans_plus_plus", if the input size is <= `num_clusters`
        then the entire batch is chosen to be cluster centers.
      distance_metric: Distance metric used for clustering. Supported options:
        "squared_euclidean", "cosine".
      use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
        full batch.
      mini_batch_steps_per_iteration: Number of steps after which the updated
        cluster centers are synced back to a master copy.
      random_seed: Seed for PRNG used to initialize seeds.
      kmeans_plus_plus_num_retries: For each point that is sampled during
        kmeans++ initialization, this parameter specifies the number of
        additional points to draw from the current distribution before selecting
        the best. If a negative value is specified, a heuristic is used to
        sample O(log(num_to_sample)) additional points.
      kmc2_chain_length: Determines how many candidate points are used by the
        k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
        contains fewer points, one new cluster center is generated from the
        (mini-)batch.

    Raises:
      ValueError: An invalid argument was passed to initial_clusters or
        distance_metric.
    """
    if isinstance(initial_clusters, str) and initial_clusters not in [
        RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
    ]:
      raise ValueError(
          "Unsupported initialization algorithm '%s'" % initial_clusters)
    if distance_metric not in [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE]:
      raise ValueError("Unsupported distance metric '%s'" % distance_metric)
    # Internally, inputs are always handled as a list of shards.
    self._inputs = inputs if isinstance(inputs, list) else [inputs]
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._use_mini_batch = use_mini_batch
    self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
    self._seed = random_seed_ops.get_seed(random_seed)[0]
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length

  @classmethod
  def _distance_graph(cls, inputs, clusters, distance_metric):
    """Computes distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.
      distance_metric: distance metric used for clustering; one of
        SQUARED_EUCLIDEAN_DISTANCE or COSINE_DISTANCE.

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the distance of each row to all the cluster centers.
      Currently only squared Euclidean distance and cosine distance are
      supported.

    Raises:
      ValueError: if `distance_metric` is not supported.
    """
    assert isinstance(inputs, list)
    if distance_metric == SQUARED_EUCLIDEAN_DISTANCE:
      return cls._compute_euclidean_distance(inputs, clusters)
    elif distance_metric == COSINE_DISTANCE:
      # Assumes the caller already L2-normalized inputs and clusters
      # (training_graph does so), so only 1 - dot product is needed.
      return cls._compute_cosine_distance(
          inputs, clusters, inputs_normalized=True)
    else:
      # Raise rather than `assert False`: an assert is stripped under
      # `python -O`, which would make this silently return None.
      raise ValueError("Unsupported distance metric '%s'" % distance_metric)

  @classmethod
  def _compute_euclidean_distance(cls, inputs, clusters):
    """Computes squared Euclidean distance between inputs and cluster centers.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the squared Euclidean distance of each row to all the
      cluster centers.
    """
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        # Uses the expansion ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2. Note the
        # first and third terms are broadcast additions.
        squared_distance = (
            math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
            2 * math_ops.matmul(inp, clusters, transpose_b=True) +
            array_ops.transpose(
                math_ops.reduce_sum(
                    math_ops.square(clusters), 1, keepdims=True)))
        output.append(squared_distance)

    return output

  @classmethod
  def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True):
    """Computes cosine distance between each input and each cluster center.

    Args:
      inputs: list of input Tensor.
      clusters: cluster Tensor
      inputs_normalized: if True, it assumes that inp and clusters are
        normalized, so the cosine distance is one minus their dot product.
        Else it L2 normalizes the inputs and clusters first.

    Returns:
      list of Tensors, where each element corresponds to each element in inp.
      The value is the distance of each row to all the cluster centers.
    """
    output = []
    if not inputs_normalized:
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, dim=1)
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        if not inputs_normalized:
          inp = nn_impl.l2_normalize(inp, dim=1)
        output.append(1 - math_ops.matmul(inp, clusters, transpose_b=True))
    return output

  def _infer_graph(self, inputs, clusters):
    """Maps input to closest cluster and the score.

    Args:
      inputs: list of input Tensors.
      clusters: Tensor of cluster centers.

    Returns:
      An iterable of three tuples (all_scores, score, cluster_idx). Each tuple
      has one entry per element of `inputs`:
        all_scores: distance of each input to each cluster center.
        score: distance of each input to closest cluster center.
        cluster_idx: index of cluster center closest to the corresponding input.
    """
    assert isinstance(inputs, list)
    # Pairwise distances are used only by transform(). In all other cases, this
    # sub-graph is not evaluated.
    scores = self._distance_graph(inputs, clusters, self._distance_metric)
    output = []
    if (self._distance_metric == COSINE_DISTANCE and
        not self._clusters_l2_normalized()):
      # For L2-normalized vectors x and y, ||x - y||^2 = 2 * (1 - x.y), i.e.
      # the squared Euclidean distance is twice the cosine distance. We are
      # using this fact and reusing the nearest_neighbors op.
      # TODO(ands): Support COSINE distance in nearest_neighbors and remove
      # this.
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, dim=1)
    for inp, score in zip(inputs, scores):
      with ops.colocate_with(inp, ignore_existing=True):
        (indices, distances) = gen_clustering_ops.nearest_neighbors(
            inp, clusters, 1)
        if self._distance_metric == COSINE_DISTANCE:
          # nearest_neighbors distances are presumably squared Euclidean (see
          # note above); halving them yields the cosine distance.
          distances *= 0.5
        output.append((score, array_ops.squeeze(distances, [-1]),
                       array_ops.squeeze(indices, [-1])))
    return zip(*output)

  def _clusters_l2_normalized(self):
    """Returns True if clusters centers are kept normalized.

    That is the case for cosine distance in full-batch mode (the training op
    normalizes the new centers) and in mini-batch mode with more than one step
    per iteration (the sync op normalizes the centers it copies back).
    """
    return (self._distance_metric == COSINE_DISTANCE and
            (not self._use_mini_batch or
             self._mini_batch_steps_per_iteration > 1))

  def _create_variables(self, num_clusters):
    """Creates variables.

    Args:
      num_clusters: an integer Tensor providing the number of clusters.

    Returns:
      Tuple with following elements:
      - cluster_centers: a Tensor for storing cluster centers. Its initial
        value is empty (unless fed) and it is created with
        validate_shape=False so the init op can grow it.
      - cluster_centers_initialized: bool Variable indicating whether clusters
            are initialized.
      - cluster_counts: a Tensor for storing counts of points assigned to this
        cluster. This is used by mini-batch training; None in full-batch mode.
      - cluster_centers_updated: Tensor representing copy of cluster centers
        that are updated every step. This is the same variable as
        cluster_centers unless mini-batch training uses more than one step
        per iteration.
      - update_in_steps: numbers of steps left before we sync
        cluster_centers_updated back to cluster_centers, or None when no
        separate copy is kept.
    """
    init_value = array_ops.placeholder_with_default([], shape=None)
    cluster_centers = variable_scope.variable(
        init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
    cluster_centers_initialized = variable_scope.variable(
        False, dtype=dtypes.bool, name='initialized')

    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      # Copy of cluster centers actively updated each step according to
      # mini-batch update rule.
      cluster_centers_updated = variable_scope.variable(
          init_value, name='clusters_updated', validate_shape=False)
      # How many steps till we copy the updated clusters to cluster_centers.
      update_in_steps = variable_scope.variable(
          self._mini_batch_steps_per_iteration,
          dtype=dtypes.int64,
          name='update_in_steps')
      # Count of points assigned to cluster_centers_updated.
      cluster_counts = variable_scope.variable(
          array_ops.zeros([num_clusters], dtype=dtypes.int64))
    else:
      cluster_centers_updated = cluster_centers
      update_in_steps = None
      cluster_counts = (
          variable_scope.variable(
              array_ops.ones([num_clusters], dtype=dtypes.int64))
          if self._use_mini_batch else None)
    return (cluster_centers, cluster_centers_initialized, cluster_counts,
            cluster_centers_updated, update_in_steps)

  @classmethod
  def _l2_normalize_data(cls, inputs):
    """Returns a list with each input Tensor L2-normalized along axis 1."""
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        output.append(nn_impl.l2_normalize(inp, dim=1))
    return output

  def training_graph(self):
    """Generate a training graph for kmeans algorithm.

    This returns, among other things, an op that chooses initial centers
    (init_op), a boolean variable that is set to True when the initial centers
    are chosen (cluster_centers_initialized), and an op to perform either an
    entire Lloyd iteration or a mini-batch of a Lloyd iteration (training_op).
    The caller should use these components as follows. A single worker should
    execute init_op multiple times until cluster_centers_initialized becomes
    True. Then multiple workers may execute training_op any number of times.

    Returns:
      A tuple consisting of:
      all_scores: A matrix (or list of matrices) of dimensions (num_input,
        num_clusters) where the value is the distance of an input vector and a
        cluster center.
      cluster_idx: A vector (or list of vectors). Each element in the vector
        corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      scores: Similar to cluster_idx but specifies the distance to the
        assigned cluster instead.
      cluster_centers_initialized: scalar indicating whether clusters have been
        initialized.
      init_op: an op to initialize the clusters.
      training_op: an op that runs an iteration of training.
    """
    # Implementation of kmeans.
    if (isinstance(self._initial_clusters, str) or
        callable(self._initial_clusters)):
      initial_clusters = self._initial_clusters
      num_clusters = ops.convert_to_tensor(self._num_clusters)
    else:
      # Explicit centers were given: their row count defines num_clusters.
      initial_clusters = ops.convert_to_tensor(self._initial_clusters)
      num_clusters = array_ops.shape(initial_clusters)[0]

    inputs = self._inputs
    (cluster_centers_var, cluster_centers_initialized, total_counts,
     cluster_centers_updated,
     update_in_steps) = self._create_variables(num_clusters)
    init_op = _InitializeClustersOpFactory(
        self._inputs, num_clusters, initial_clusters, self._distance_metric,
        self._seed, self._kmeans_plus_plus_num_retries, self._kmc2_chain_length,
        cluster_centers_var, cluster_centers_updated,
        cluster_centers_initialized).op()
    cluster_centers = cluster_centers_var

    if self._distance_metric == COSINE_DISTANCE:
      inputs = self._l2_normalize_data(inputs)
      if not self._clusters_l2_normalized():
        cluster_centers = nn_impl.l2_normalize(cluster_centers, dim=1)

    all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
    if self._use_mini_batch:
      sync_updates_op = self._mini_batch_sync_updates_op(
          update_in_steps, cluster_centers_var, cluster_centers_updated,
          total_counts)
      assert sync_updates_op is not None
      with ops.control_dependencies([sync_updates_op]):
        training_op = self._mini_batch_training_op(
            inputs, cluster_idx, cluster_centers_updated, total_counts)
    else:
      # In full-batch mode the centers are never re-normalized above, so this
      # must still be the variable itself. Compare by identity: `==` on TF
      # tensors/variables may be overloaded to an element-wise op.
      assert cluster_centers is cluster_centers_var
      training_op = self._full_batch_training_op(
          inputs, num_clusters, cluster_idx, cluster_centers_var)

    return (all_scores, cluster_idx, scores, cluster_centers_initialized,
            init_op, training_op)

  def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
                                  cluster_centers_updated, total_counts):
    """Creates an op that periodically syncs updated centers to the master copy.

    This only does work when mini-batch training keeps a separate copy of the
    centers (mini_batch_steps_per_iteration > 1). In that case the op
    decrements `update_in_steps` while it is positive. Once it is <= 0, the op
    does the following instead:
    - resets `update_in_steps` to `mini_batch_steps_per_iteration - 1`;
    - copies `cluster_centers_updated` (L2-normalized for cosine distance)
      into `cluster_centers_var`;
    - zeroes `total_counts`.

    Args:
      update_in_steps: int64 Variable counting steps until the next sync, or
        None when no separate copy of the centers is kept.
      cluster_centers_var: Variable holding the master cluster centers.
      cluster_centers_updated: Variable holding the per-step updated centers.
      total_counts: Variable holding the per-cluster point counts.

    Returns:
      An op implementing the sync logic above, or a no-op.
    """
    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      assert update_in_steps is not None
      with ops.colocate_with(update_in_steps, ignore_existing=True):

        def _f():
          # Note that there is a race condition here, so we do best-effort
          # updates here. We reset update_in_steps first so that other workers
          # don't duplicate the updates. Also we update cluster_center_vars
          # before resetting total_counts to avoid large updates to
          # cluster_centers_updated based on partially updated
          # cluster_center_vars.
          with ops.control_dependencies([
              state_ops.assign(update_in_steps,
                               self._mini_batch_steps_per_iteration - 1)
          ]):
            with ops.colocate_with(
                cluster_centers_updated, ignore_existing=True):
              if self._distance_metric == COSINE_DISTANCE:
                cluster_centers = nn_impl.l2_normalize(
                    cluster_centers_updated, dim=1)
              else:
                cluster_centers = cluster_centers_updated
            with ops.colocate_with(cluster_centers_var, ignore_existing=True):
              with ops.control_dependencies(
                  [state_ops.assign(cluster_centers_var, cluster_centers)]):
                with ops.colocate_with(None, ignore_existing=True):
                  with ops.control_dependencies([
                      state_ops.assign(total_counts,
                                       array_ops.zeros_like(total_counts))
                  ]):
                    return array_ops.identity(update_in_steps)

        return control_flow_ops.cond(
            update_in_steps <= 0, _f,
            lambda: state_ops.assign_sub(update_in_steps, 1))
    else:
      return control_flow_ops.no_op()

  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                              total_counts):
    """Creates an op for training for mini batch case.

    For a cluster with old center x and old count n that receives k new points
    d_1..d_k, the center is moved to x + (sum_i(d_i) - k * x) / (n + k) and the
    count becomes n + k.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.
      total_counts: Tensor Ref of cluster counts.

    Returns:
      An op for doing an update of mini-batch k-means.
    """
    update_ops = []
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        assert total_counts is not None
        cluster_idx = array_ops.reshape(cluster_idx, [-1])
        # Dedupe the unique ids of cluster_centers being updated so that updates
        # can be locally aggregated.
        unique_ids, unique_idx = array_ops.unique(cluster_idx)
        num_unique_cluster_idx = array_ops.size(unique_ids)
        # Fetch the old values of counts and cluster_centers.
        with ops.colocate_with(total_counts, ignore_existing=True):
          old_counts = array_ops.gather(total_counts, unique_ids)
        # TODO(agarwal): This colocation seems to run into problems. Fix it.
        with ops.colocate_with(cluster_centers, ignore_existing=True):
          old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
        # Locally aggregate the increment to counts.
        count_updates = math_ops.unsorted_segment_sum(
            array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
            unique_idx, num_unique_cluster_idx)
        # Locally compute the sum of inputs mapped to each id.
        # For a cluster with old cluster value x, old count n, and with data
        # d_1,...d_k newly assigned to it, we recompute the new value as
        # \\(x += (sum_i(d_i) - k * x) / (n + k)\\).
        # Compute \\(sum_i(d_i)\\), see comment above.
        cluster_center_updates = math_ops.unsorted_segment_sum(
            inp, unique_idx, num_unique_cluster_idx)
        # Shape to enable broadcasting count_updates and learning_rate to inp.
        # It extends the shape with 1's to match the rank of inp.
        broadcast_shape = array_ops.concat([
            array_ops.reshape(num_unique_cluster_idx, [1]),
            array_ops.ones(
                array_ops.reshape(array_ops.rank(inp) - 1, [1]),
                dtype=dtypes.int32)
        ], 0)
        # Subtract k * x, see comment above.
        cluster_center_updates -= math_ops.cast(
            array_ops.reshape(count_updates, broadcast_shape),
            inp.dtype) * old_cluster_centers
        learning_rate = math_ops.reciprocal(
            math_ops.cast(old_counts + count_updates, inp.dtype))
        learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
        # scale by 1 / (n + k), see comment above.
        cluster_center_updates *= learning_rate
      # Apply the updates.
      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                            count_updates)
      update_cluster_centers = state_ops.scatter_add(
          cluster_centers, unique_ids, cluster_center_updates)
      update_ops.extend([update_counts, update_cluster_centers])
    return control_flow_ops.group(*update_ops)

  def _full_batch_training_op(self, inputs, num_clusters, cluster_idx_list,
                              cluster_centers):
    """Creates an op for training for full batch case.

    Each new center is the mean of the points assigned to it, across all input
    shards (L2-normalized when centers are kept normalized).

    Args:
      inputs: list of input Tensors.
      num_clusters: an integer Tensor providing the number of clusters.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.

    Returns:
      An op for doing an update of full-batch k-means.
    """
    cluster_sums = []
    cluster_counts = []
    # Guards against division by zero. A cluster with no assigned points gets
    # a zero sum, so its new center becomes the zero vector.
    epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        cluster_sums.append(
            math_ops.unsorted_segment_sum(inp, cluster_idx, num_clusters))
        cluster_counts.append(
            math_ops.unsorted_segment_sum(
                array_ops.reshape(
                    array_ops.ones(
                        array_ops.reshape(array_ops.shape(inp)[0], [-1])),
                    [-1, 1]), cluster_idx, num_clusters))
    with ops.colocate_with(cluster_centers, ignore_existing=True):
      new_clusters_centers = math_ops.add_n(cluster_sums) / (
          math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
          epsilon)
      if self._clusters_l2_normalized():
        new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
    return state_ops.assign(cluster_centers, new_clusters_centers)


class _InitializeClustersOpFactory(object):
  """Internal class to create the op to initialize the clusters.

  The op performs this algorithm (see constructor args):

  num_remaining = num_clusters - length(cluster_centers)
  if num_remaining == 0:
    assert that cluster_centers_initialized is true
  else:
    assert that num_remaining > 0
    new_centers = choose up to num_remaining initial centers
    l2-normalize new_centers if using cosine distance
    all_centers = concat(cluster_centers, new_centers)
    cluster_centers := all_centers
    if there is a cluster_centers_updated variable:
      cluster_centers_updated := cluster_centers
    num_now_remaining = num_clusters - length(cluster_centers)
    if num_now_remaining == 0:
      cluster_centers_initialized := true
  """

  # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.

  def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
               random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
               cluster_centers, cluster_centers_updated,
               cluster_centers_initialized):
    """Creates an op factory.

    Args:
      inputs: See KMeans constructor.
      num_clusters: An integer Tensor providing the number of clusters.
      initial_clusters: See KMeans constructor.
      distance_metric: See KMeans constructor.
      random_seed: See KMeans constructor.
      kmeans_plus_plus_num_retries: See KMeans constructor.
      kmc2_chain_length: See KMeans constructor.
      cluster_centers: The TF variable holding the initial centers. It may
          already contain some centers when the op is executed.
      cluster_centers_updated: A second TF variable to hold a copy of the
          initial centers. It is a distinct variable only for mini-batch
          training with more than one step per iteration; otherwise (including
          full-batch mode) it is the same variable as cluster_centers.
      cluster_centers_initialized: A boolean TF variable that will be set
          to true when all the initial centers have been chosen.
    """
    # All of these instance variables are constants.
    self._inputs = inputs
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._seed = random_seed
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length
    self._cluster_centers = cluster_centers
    self._cluster_centers_updated = cluster_centers_updated
    self._cluster_centers_initialized = cluster_centers_initialized

    # Graph tensors derived from the state above: centers already chosen,
    # centers still needed, and total number of rows across all input shards.
    self._num_selected = array_ops.shape(self._cluster_centers)[0]
    self._num_remaining = self._num_clusters - self._num_selected
    self._num_data = math_ops.add_n(
        [array_ops.shape(i)[0] for i in self._inputs])

  def _random(self):
    """Samples `num_remaining` input rows uniformly at random.

    Indices are drawn independently, so the same row may be chosen more than
    once. They index the concatenation of all input shards.
    """
    indices = random_ops.random_uniform(
        array_ops.reshape(self._num_remaining, [-1]),
        minval=0,
        maxval=math_ops.cast(self._num_data, dtypes.int64),
        seed=self._seed,
        dtype=dtypes.int64)
    # NOTE(review): partition_strategy='div' assumes a particular split of
    # rows across shards; confirm it matches how inputs are sharded.
    return embedding_lookup(self._inputs, indices, partition_strategy='div')

  def _kmeans_plus_plus(self):
    """Chooses `num_remaining` centers from the first shard with kmeans++."""
    # Points from only the first shard are used for initializing centers.
    # TODO(ands): Use all points.
    inp = self._inputs[0]
    if self._distance_metric == COSINE_DISTANCE:
      inp = nn_impl.l2_normalize(inp, dim=1)
    return gen_clustering_ops.kmeans_plus_plus_initialization(
        inp, math_ops.cast(self._num_remaining, dtypes.int64), self._seed,
        self._kmeans_plus_plus_num_retries)

  def _kmc2_multiple_centers(self):
    """Adds new initial cluster centers using the k-MC2 algorithm.

    In each call to the op, the provided batch is split into subsets based on
    the specified `kmc2_chain_length`. On each subset, a single Markov chain of
    the k-MC2 algorithm is used to add *one* new cluster center. If there
    are fewer than `kmc2_chain_length` points in the subset, a single center is
    added using one Markov chain on the full input. It is assumed that the
    provided batch has previously been randomly permuted. Otherwise, k-MC2 may
    return suboptimal centers.

    Returns:
      An op that adds new cluster centers.
    """
    # The op only operates on the first shard of data.
    first_shard = self._inputs[0]
    # Number of points in the input that can be used.
    batch_size = array_ops.shape(first_shard)[0]
    # Maximum number of subsets of exactly `kmc2_chain_length` points. Points
    # beyond the last full subset are not used; if the batch is smaller than
    # `kmc2_chain_length`, the single subset is the whole batch.
    max_to_sample = math_ops.cast(
        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
    # We sample at least one new center and at most all remaining centers.
    num_to_sample = math_ops.maximum(
        math_ops.minimum(self._num_remaining, max_to_sample), 1)

    def _cond(i, _):
      """Stopping condition for the while loop."""
      return math_ops.less(i, num_to_sample)

    def _body(i, _):
      """Body that adds a single new center based on a subset."""

      def _sample_random():
        """Returns a random point as a cluster center."""
        # By assumption the batch is reshuffled and _sample_random is always
        # called for i=0. Hence, we simply return the first point.
        new_center = array_ops.reshape(first_shard[0], [1, -1])
        if self._distance_metric == COSINE_DISTANCE:
          new_center = nn_impl.l2_normalize(new_center, dim=1)
        return new_center

      def _sample_kmc2_chain():
        """Returns previous centers as well as a new center sampled using k-MC2.
        """
        # Extract the subset from the underlying batch.
        start = i * self._kmc2_chain_length
        end = start + self._kmc2_chain_length
        subset = first_shard[start:end]
        # Compute the distances from points in the subset to previous centers.
        _, distances = gen_clustering_ops.nearest_neighbors(
            subset, self._cluster_centers, 1)
        # Sample index of new center using k-MC2 Markov chain.
        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
            array_ops.squeeze(distances), self._seed)
        # Extract actual new center.
        newly_sampled_center = array_ops.reshape(subset[new_center_index],
                                                 [1, -1])
        # Return concatenation with previously sampled centers.
        if self._distance_metric == COSINE_DISTANCE:
          newly_sampled_center = nn_impl.l2_normalize(
              newly_sampled_center, dim=1)
        return array_ops.concat([self._cluster_centers, newly_sampled_center],
                                0)

      # Obtain a random point if there are no previously sampled centers.
      # Otherwise, construct a k-MC2 Markov chain.
      # NOTE(review): self._num_selected is a tensor created outside this
      # while_loop, so it appears to be loop-invariant. When the op starts with
      # no centers, every iteration takes _sample_random, which replaces the
      # centers with a single point. Further centers are then added on later
      # runs of the init op.
      new_centers = control_flow_ops.cond(
          math_ops.equal(self._num_selected, 0), _sample_random,
          _sample_kmc2_chain)
      # Assign new cluster centers to underlying variable.
      assigned_centers = state_ops.assign(
          self._cluster_centers, new_centers, validate_shape=False)
      if self._cluster_centers_updated is not self._cluster_centers:
        assigned_centers = state_ops.assign(
            self._cluster_centers_updated,
            assigned_centers,
            validate_shape=False)
      return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]

    # Add num_to_sample new data points.
    _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
    return num_remaining

  def _greedy_batch_sampler(self, sampler):
    """Returns all inputs if they do not exceed num_remaining, else sampler()."""
    # If the input dataset size is smaller than the number of centers
    # remaining, choose the entire input dataset as centers. This can happen
    # with mini-batch. Otherwise, sample the batch according to the provided
    # sampler.
    return control_flow_ops.cond(self._num_data <= self._num_remaining,
                                 lambda: array_ops.concat(self._inputs, 0),
                                 sampler)

  def _single_batch_sampler(self, sampler):
    """Returns sampler(), after asserting num_data >= num_remaining."""
    # Enforce that there are at least as many data points as centers
    # remaining. This gives the provided sampler the chance to select all
    # remaining centers from a single batch.
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(self._num_data, self._num_remaining)]):
      return sampler()

  def _choose_initial_centers(self):
    """Returns new centers chosen according to `initial_clusters`.

    The k-MC2 initializer never reaches here; `_initialize` handles it
    separately.
    """
    if isinstance(self._initial_clusters, str):
      if self._initial_clusters == RANDOM_INIT:
        return self._greedy_batch_sampler(self._random)
      else:  # self._initial_clusters == KMEANS_PLUS_PLUS_INIT
        return self._single_batch_sampler(self._kmeans_plus_plus)
    elif callable(self._initial_clusters):
      return self._initial_clusters(self._inputs, self._num_remaining)
    else:
      # Explicit centers: their row count must match the remaining count.
      with ops.control_dependencies([
          check_ops.assert_equal(self._num_remaining,
                                 array_ops.shape(self._initial_clusters)[0])
      ]):
        return self._initial_clusters

  def _add_new_centers(self):
    """Adds some centers and returns the number of centers remaining."""
    new_centers = self._choose_initial_centers()
    if self._distance_metric == COSINE_DISTANCE:
      new_centers = nn_impl.l2_normalize(new_centers, dim=1)
    # If cluster_centers is empty, it doesn't have the right shape for concat.
    all_centers = control_flow_ops.cond(
        math_ops.equal(self._num_selected, 0), lambda: new_centers,
        lambda: array_ops.concat([self._cluster_centers, new_centers], 0))
    # TODO(ccolby): De-dupe all_centers?
    a = state_ops.assign(
        self._cluster_centers, all_centers, validate_shape=False)
    if self._cluster_centers_updated is not self._cluster_centers:
      a = state_ops.assign(
          self._cluster_centers_updated, a, validate_shape=False)
    return self._num_clusters - array_ops.shape(a)[0]

  def _initialize(self):
    """Adds new centers and marks initialization done once none remain."""
    with ops.control_dependencies([
        check_ops.assert_positive(self._num_remaining),
    ]):
      # `_initial_clusters` may be a Tensor. Only compare it to the string
      # constant when it is a str, because Tensor `==` may be element-wise and
      # its result cannot be used as a Python bool.
      if (isinstance(self._initial_clusters, str) and
          self._initial_clusters == KMC2_INIT):
        num_now_remaining = self._kmc2_multiple_centers()
      else:
        num_now_remaining = self._add_new_centers()
      return control_flow_ops.cond(
          math_ops.equal(num_now_remaining, 0),
          lambda: state_ops.assign(self._cluster_centers_initialized, True),
          control_flow_ops.no_op)

  def op(self):
    """Returns the cluster initializer op.

    If no centers remain to be chosen, the op only asserts that
    cluster_centers_initialized is True. Otherwise it runs `_initialize`.
    """
    return control_flow_ops.cond(
        math_ops.equal(self._num_remaining, 0),
        lambda: check_ops.assert_equal(self._cluster_centers_initialized, True),
        self._initialize)