# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_function
# pylint: disable=protected-access


# Graph-collection key under which _record_variable_scope_and_name stores the
# mapping from embedding variable name to (variable scope, fc weight name).
_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
# Categorical column classes (both the legacy fc._* and the fc_lib public
# variants) that the TPU embedding columns below know how to wrap.
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
                                  fc._VocabularyFileCategoricalColumn,
                                  fc._VocabularyListCategoricalColumn,
                                  fc._WeightedCategoricalColumn,
                                  fc_lib.IdentityCategoricalColumn,
                                  fc_lib.VocabularyFileCategoricalColumn,
                                  fc_lib.VocabularyListCategoricalColumn,
                                  fc_lib.WeightedCategoricalColumn)


def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None):
  """TPU embedding_column for `tf.feature_column.embedding_column`.

  Note that the interface for TPU embedding_column is different from the non-TPU
  version. The following args available for the non-TPU version are NOT
  supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable.

  Args:
    categorical_column: A categorical_column returned from
        categorical_column_with_identity, weighted_categorical_column,
        categorical_column_with_vocabulary_list or
        categorical_column_with_vocabulary_file.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.

  Returns:
    A _TPUEmbeddingColumn.

  Raises:
    TypeError: if `categorical_column` is not a supported type.
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
  """
  if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
    raise TypeError(
        'categorical_column for tpu '
        ' embedding_column must be type %s, got %s.' % (' or '.join([
            cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
        ]), type(categorical_column)))
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  # Table shape is (vocabulary size, embedding width).
  embedding_shape = categorical_column._num_buckets, dimension  # pylint: disable=protected-access

  def _creator(weight_collections, scope):
    # Lazily creates the embedding weights via the standard (CPU)
    # _EmbeddingColumnLayer; used when running off-TPU.
    embedding_column_layer = fc._EmbeddingColumnLayer(
        embedding_shape=embedding_shape,
        initializer=initializer,
        weight_collections=weight_collections,
        trainable=True,
        name='embedding_column_layer')
    return embedding_column_layer(None, scope=scope)  # pylint: disable=not-callable

  column = _TPUEmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      layer_creator=_creator,
      ckpt_to_load_from=None,
      tensor_name_in_ckpt=None,
      max_norm=None,
      trainable=True)
  # For Embedding column, the initializer is hidden inside the creator Fn, which
  # is not accessible later. So, we attach it to a special field. Also note
  # that non-TPU Embedding column and non-TPU shared Embedding column handle the
  # initializer differently. See shared_embedding_columns for details.
  column._tpu_initializer = initializer
  return column


def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None):
  """List of dense columns that convert from sparse, categorical input.

  TPU version of `tf.feature_column.shared_embedding_columns`. All columns in
  the returned list share a single embedding table.

  Args:
    categorical_columns: List of categorical columns (of the types in
      _SUPPORTED_CATEGORICAL_COLUMNS) to share one embedding.
    dimension: An integer specifying dimension of the embedding.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row.
    initializer: A variable initializer function for the shared embedding
      variable; if None, the default chosen by
      `fc_lib.shared_embedding_columns` is used.
    shared_embedding_collection_name: Optional name of the collection the
      shared embedding variables are added to; if None, a name is derived by
      `fc_lib.shared_embedding_columns`.

  Returns:
    A list of _TPUSharedEmbeddingColumn, one per input categorical column.

  Raises:
    TypeError: if any element of `categorical_columns` is not a supported type.
  """
  for categorical_column in categorical_columns:
    if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
      raise TypeError(
          'categorical_column for tpu '
          ' shared_embedding_columns must be type %s, got %s.' % (' or '.join([
              cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
          ]), type(categorical_column)))
  # Delegate to the non-TPU implementation first; it resolves the default
  # initializer and the shared collection name, which we then reuse below.
  columns = fc_lib.shared_embedding_columns(
      categorical_columns,
      dimension,
      combiner=combiner,
      initializer=initializer,
      shared_embedding_collection_name=shared_embedding_collection_name,
      ckpt_to_load_from=None,
      tensor_name_in_ckpt=None,
      max_norm=None,
      trainable=True)

  # Use the initializer and shared_embedding_collection_name to create TPU
  # version
  initializer = columns[0].initializer
  shared_embedding_collection_name = columns[0].shared_embedding_collection_name
  tpu_columns = []

  # Create the state (_SharedEmbeddingColumnLayer) here.
  for categorical_column in categorical_columns:
    column = _TPUSharedEmbeddingColumn(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        shared_embedding_collection_name=shared_embedding_collection_name,
        ckpt_to_load_from=None,
        tensor_name_in_ckpt=None,
        max_norm=None,
        trainable=True)
    tpu_columns.append(column)

  return tpu_columns


class _TPUBaseEmbeddingColumn(object):
  """Base class for TPU Embedding Column.

  Defines the interface the TPU embedding infrastructure uses to query an
  embedding column; all methods must be overridden by subclasses.
  """

  def __init__(self, categorical_column):
    # Kept separately from the namedtuple state of the concrete subclasses.
    self._tpu_categorical_column = categorical_column

  def get_combiner(self):
    """Returns the embedding combiner."""
    raise NotImplementedError('not implemented')

  def get_embedding_table_size(self):
    """Returns the embedding table size, tuple of vocab size and dimension."""
    raise NotImplementedError('not implemented')

  def get_feature_key_name(self):
    """Returns the feature key name in the features dict."""
    raise NotImplementedError('not impl')

  def get_weight_key_name(self):
    """Return the key name for weights."""
    raise NotImplementedError('not impl')

  def get_embedding_var_name(self):
    """Returns the embedding variable name.

    Feature key name and embedding variable name are usually one-to-one mapping.
    But for shared embedding columns, it is many-to-one mapping.
    """
    raise NotImplementedError('not impl')

  def get_initializer(self):
    """Returns the initializer."""
    raise NotImplementedError('not impl')

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    raise NotImplementedError('not impl')


class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
  """Core Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              layer_creator=None,
              ckpt_to_load_from=None,
              tensor_name_in_ckpt=None,
              max_norm=None,
              trainable=True):
    # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
    # are not supported on TPU. They are solely for matching the signature of
    # __new__ of parent class fc._EmbeddingColumn.
    return fc._EmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        layer_creator=layer_creator,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable)

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               layer_creator=None,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True):
    # The namedtuple fields were populated in __new__; here we only set up the
    # TPU-specific base state. Remaining args are intentionally unused.
    _TPUBaseEmbeddingColumn.__init__(self, categorical_column)
    self._key = None

  def get_combiner(self):
    """Returns the embedding combiner string (e.g. 'mean')."""
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    # For weighted columns the feature key lives on the wrapped inner column.
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    # Only weighted categorical columns carry a separate weight feature.
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self.categorical_column.name

  def get_initializer(self):
    # Set by embedding_column() after construction; see comment there.
    return self._tpu_initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    # Under TPU inference, run the standard CPU lookup outside the TPU graph.
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(self.get_embedding_var_name(),
                                    'embedding_weights')

    return tensor


class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
                                fc._SharedEmbeddingColumn):
  """Core Shared Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              initializer=None,
              shared_embedding_collection_name=None,
              ckpt_to_load_from=None,
              tensor_name_in_ckpt=None,
              max_norm=None,
              trainable=True):
    # Args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable are
    # not supported on TPU; they only match the parent __new__ signature.
    return fc._SharedEmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        initializer=initializer,
        shared_embedding_collection_name=shared_embedding_collection_name,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable)

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               initializer=None,
               shared_embedding_collection_name=None,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True):

    # The namedtuple fields were populated in __new__; here we only set up the
    # TPU-specific base state. Remaining args are intentionally unused.
    _TPUBaseEmbeddingColumn.__init__(self, categorical_column)
    self._key = None

  def get_combiner(self):
    """Returns the embedding combiner string (e.g. 'mean')."""
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    # For weighted columns the feature key lives on the wrapped inner column.
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    # Only weighted categorical columns carry a separate weight feature.
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    # Shared columns are keyed by the collection name, not the feature name,
    # so that all sharers map to the same embedding variable.
    return self.shared_embedding_collection_name

  def get_initializer(self):
    # Unlike _TPUEmbeddingColumn, the initializer is a namedtuple field here.
    return self.initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    # Under TPU inference, run the standard CPU lookup outside the TPU graph.
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._SharedEmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._SharedEmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        is_shared_embedding=True)
    return tensor


def _record_variable_scope_and_name(embedding_var_name,
                                    embedding_var_name_in_fc,
                                    is_shared_embedding=False):
  """Add embedding variable name and scope to collection.

  Records, in a single-dict graph collection keyed by _TPU_FC_TO_SCOPE, the
  mapping embedding_var_name -> (variable scope name, fc weight name).

  Args:
    embedding_var_name: Name of the embedding variable (feature name, or
      shared collection name for shared embeddings).
    embedding_var_name_in_fc: The weight name used by the feature column
      (e.g. 'embedding_weights').
    is_shared_embedding: If True, differing variable scopes for the same
      embedding_var_name are allowed (many columns share one variable).

  Raises:
    ValueError: if the name was recorded before with a different scope (for
      non-shared embeddings) or a different fc weight name.
  """
  g = ops.get_default_graph()
  collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
  # Lazily create the single shared dict on first use.
  if not collection:
    collection.append({})

  var_def_dict = collection[0]

  captured_scope = variable_scope.get_variable_scope()
  captured_scope_name = captured_scope.name

  if embedding_var_name in var_def_dict:
    if (var_def_dict[embedding_var_name][0] != captured_scope_name
        and not is_shared_embedding):
      raise ValueError(
          'For embedding var name {}, the variable scope name is different, '
          'got {}; expected {}'.format(embedding_var_name,
                                       captured_scope_name,
                                       var_def_dict[embedding_var_name][0]))
    if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
      raise ValueError(
          'For embedding var name {}, the embedding name is different, '
          'got {}; expected {}'.format(embedding_var_name,
                                       embedding_var_name_in_fc,
                                       var_def_dict[embedding_var_name][1]))
  else:
    var_def_dict[embedding_var_name] = (captured_scope_name,
                                        embedding_var_name_in_fc)


def _is_running_on_cpu():
  """Returns True if the current context is CPU mode."""
  # Outside any TPU replication context, number_of_shards is None.
  return tpu_function.get_tpu_context().number_of_shards is None