1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Version 2 of class Optimizer.""" 17# pylint: disable=g-bad-name 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import abc 24import functools 25 26import six 27 28from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx 29from tensorflow.python.distribute import reduce_util as ds_reduce_util 30from tensorflow.python.distribute import values as distributed_values 31from tensorflow.python.eager import backprop 32from tensorflow.python.eager import context 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import ops 35from tensorflow.python.keras import backend 36from tensorflow.python.keras import initializers 37from tensorflow.python.keras.engine import base_layer_utils 38from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule 39from tensorflow.python.keras.utils import tf_utils 40from tensorflow.python.ops import array_ops 41from tensorflow.python.ops import clip_ops 42from tensorflow.python.ops import gradients 43from tensorflow.python.ops import math_ops 44from tensorflow.python.ops import resource_variable_ops 45from tensorflow.python.ops import variables as tf_variables 46from tensorflow.python.platform import tf_logging as logging 47from tensorflow.python.saved_model import revived_types 48from tensorflow.python.training.tracking import base as trackable 49from tensorflow.python.util import nest 50from tensorflow.python.util.tf_export import keras_export 51 52 53def _deduplicate_indexed_slices(values, indices): 54 """Sums `values` associated with any non-unique `indices`. 55 56 Args: 57 values: A `Tensor` with rank >= 1. 58 indices: A one-dimensional integer `Tensor`, indexing into the first 59 dimension of `values` (as in an IndexedSlices object). 60 61 Returns: 62 A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a 63 de-duplicated version of `indices` and `summed_values` contains the sum of 64 `values` slices associated with each unique index. 65 """ 66 unique_indices, new_index_positions = array_ops.unique(indices) 67 summed_values = math_ops.unsorted_segment_sum( 68 values, new_index_positions, 69 array_ops.shape(unique_indices)[0]) 70 return (summed_values, unique_indices) 71 72 73@six.add_metaclass(abc.ABCMeta) 74@keras_export("keras.optimizers.Optimizer") 75class OptimizerV2(trackable.Trackable): 76 """Updated base class for optimizers. 77 78 This class defines the API to add Ops to train a model. You never use this 79 class directly, but instead instantiate one of its subclasses such as 80 `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`. 81 82 ### Usage 83 84 ```python 85 # Create an optimizer with the desired parameters. 86 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 87 # `loss` is a callable that takes no argument and returns the value 88 # to minimize. 89 loss = lambda: 3 * var1 * var1 + 2 * var2 * var2 90 # In graph mode, returns op that minimizes the loss by updating the listed 91 # variables. 92 opt_op = opt.minimize(loss, var_list=[var1, var2]) 93 opt_op.run() 94 # In eager mode, simply call minimize to update the list of variables. 95 opt.minimize(loss, var_list=[var1, var2]) 96 ``` 97 98 ### Processing gradients before applying them. 99 100 Calling `minimize()` takes care of both computing the gradients and 101 applying them to the variables. If you want to process the gradients 102 before applying them you can instead use the optimizer in three steps: 103 104 1. Compute the gradients with `tf.GradientTape`. 105 2. Process the gradients as you wish. 106 3. Apply the processed gradients with `apply_gradients()`. 107 108 Example: 109 110 ```python 111 # Create an optimizer. 112 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 113 114 # Compute the gradients for a list of variables. 115 with tf.GradientTape() as tape: 116 loss = <call_loss_function> 117 vars = <list_of_variables> 118 grads = tape.gradient(loss, vars) 119 processed_grads = [process_gradient(g) for g in grads] 120 grads_and_vars = zip(processed_grads, var_list) 121 122 # grads_and_vars is a list of tuples (gradient, variable). Do whatever you 123 # need to the 'gradient' part, for example cap them, etc. 124 capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars] 125 126 # Ask the optimizer to apply the capped gradients. 127 opt.apply_gradients(capped_grads_and_vars) 128 ``` 129 130 ### Use with `tf.distribute.Strategy`. 131 132 This optimizer class is `tf.distribute.Strategy` aware, which means it 133 automatically sums gradients across all replicas. To average gradients, 134 you divide your loss by the global batch size, which is done automatically 135 if you use a member of `tf.keras.losses` or `tf.losses`. See the 136 `reduction` argument of your loss which should be set to 137 `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or 138 `tf.keras.losses.Reduction.SUM` for not. 139 140 If you are not using these and you want to average gradients, you should use 141 `tf.math.reduce_sum` to add up your per-example losses and then divide by the 142 global batch size. Note that when using `tf.distribute.Strategy`, the first 143 component of a tensor's shape is the *replica-local* batch size, which is off 144 by a factor equal to the number of replicas being used to compute a single 145 step. As a result, using `tf.math.reduce_mean` will give the wrong answer, 146 resulting in gradients that can be many times too big. 147 148 ### Variable Constraint 149 150 All Keras optimizers respect variable constraints. If constraint function is 151 passed to any variable, the constraint will be applied to the variable after 152 the gradient has been applied to the variable. 153 Important: If gradient is sparse tensor, variable constraint is not supported. 154 155 ### Thread Compatibility 156 157 The entire optimizer is currently thread compatible, not thread-safe. The user 158 needs to perform synchronization if necessary. 159 160 ### Slots 161 162 Many optimizer subclasses, such as `Adam` and `Adagrad` allocate and manage 163 additional variables associated with the variables to train. These are called 164 <i>Slots</i>. Slots have names and you can ask the optimizer for the names of 165 the slots that it uses. Once you have a slot name you can ask the optimizer 166 for the variable it created to hold the slot value. 167 168 This can be useful if you want to log debug a training algorithm, report stats 169 about the slots, etc. 170 171 ### Hyper parameters 172 173 These are arguments passed to the optimizer subclass constructor 174 (the `__init__` method), and then passed to `self._set_hyper()`. 175 They can be either regular Python values (like 1.0), tensors, or 176 callables. If they are callable, the callable will be called during 177 `apply_gradients()` to get the value for the hyper parameter. 178 179 Hyper parameters can be overwritten through user code: 180 181 Example: 182 183 ```python 184 # Create an optimizer with the desired parameters. 185 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 186 # `loss` is a callable that takes no argument and returns the value 187 # to minimize. 188 loss = lambda: 3 * var1 + 2 * var2 189 # In eager mode, simply call minimize to update the list of variables. 190 opt.minimize(loss, var_list=[var1, var2]) 191 # update learning rate 192 opt.learning_rate = 0.05 193 opt.minimize(loss, var_list=[var1, var2]) 194 ``` 195 196 ### Write a customized optimizer. 197 If you intend to create your own optimization algorithm, simply inherit from 198 this class and override the following methods: 199 200 - resource_apply_dense (update variable given gradient tensor is dense) 201 - resource_apply_sparse (update variable given gradient tensor is sparse) 202 - create_slots (if your optimizer algorithm requires additional variables) 203 - get_config (serialization of the optimizer, include all hyper parameters) 204 """ 205 206 def __init__(self, name, **kwargs): 207 """Create a new Optimizer. 208 209 This must be called by the constructors of subclasses. 210 Note that Optimizer instances should not bind to a single graph, 211 and so shouldn't keep Tensors as member variables. Generally 212 you should be able to use the _set_hyper()/state.get_hyper() 213 facility instead. 214 215 This class in stateful and thread-compatible. 216 217 Args: 218 name: A non-empty string. The name to use for accumulators created 219 for the optimizer. 220 **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, 221 `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip 222 gradients by value, `decay` is included for backward compatibility to 223 allow time inverse decay of learning rate. `lr` is included for backward 224 compatibility, recommended to use `learning_rate` instead. 225 226 Raises: 227 ValueError: If name is malformed. 228 RuntimeError: If _create_slots has been overridden instead of 229 _create_vars. 230 """ 231 allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay"} 232 for k in kwargs: 233 if k not in allowed_kwargs: 234 raise TypeError("Unexpected keyword argument " 235 "passed to optimizer: " + str(k)) 236 # checks that all keyword arguments are non-negative. 237 if kwargs[k] < 0: 238 raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k])) 239 240 self._use_locking = True 241 self._name = name 242 self._hyper = {} 243 # dict: {variable name : {slot name : variable}} 244 self._slots = {} 245 self._slot_names = [] 246 self._weights = [] 247 self._iterations = None 248 249 # For implementing Trackable. Stores information about how to restore 250 # slot variables which have not yet been created 251 # (trackable._CheckpointPosition objects). 252 # {slot_name : 253 # {_var_key(variable_to_train): [checkpoint_position, ... ], ... }, 254 # ... } 255 self._deferred_slot_restorations = {} 256 257 decay = kwargs.pop("decay", 0.0) 258 if decay < 0.: 259 raise ValueError("decay cannot be less than 0: {}".format(decay)) 260 self._initial_decay = decay 261 if "clipnorm" in kwargs: 262 self.clipnorm = kwargs.pop("clipnorm") 263 if "clipvalue" in kwargs: 264 self.clipvalue = kwargs.pop("clipvalue") 265 266 self._hypers_created = False 267 268 def minimize(self, loss, var_list, grad_loss=None, name=None): 269 """Add operations to minimize `loss` by updating `var_list`. 270 271 This method simply computes gradient using `tf.GradientTape` and calls 272 `apply_gradients()`. If you want to process the gradient before applying 273 then call `tf.GradientTape` and `apply_gradients()` explicitly instead 274 of using this function. 275 276 Args: 277 loss: A callable taking no arguments which returns the value to minimize. 278 var_list: list or tuple of `Variable` objects to update to minimize 279 `loss`. 280 grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. 281 name: Optional name for the returned operation. 282 283 Returns: 284 An Operation that updates the variables in `var_list`. If `global_step` 285 was not `None`, that operation also increments `global_step`. 286 287 Raises: 288 ValueError: If some of the variables are not `Variable` objects. 289 290 @compatibility(eager) 291 When eager execution is enabled, `loss` should be a Python function that 292 takes no arguments and computes the value to be minimized. Minimization (and 293 gradient computation) is done with respect to the elements of `var_list`. 294 `grad_loss` is ignored when eager execution is enabled. 295 @end_compatibility 296 """ 297 grads_and_vars = self._compute_gradients( 298 loss, var_list=var_list, grad_loss=grad_loss) 299 300 return self.apply_gradients(grads_and_vars, name=name) 301 302 def _compute_gradients(self, loss, var_list, grad_loss=None): 303 """Compute gradients of `loss` for the variables in `var_list`. 304 305 This is the first part of `minimize()`. It returns a list 306 of (gradient, variable) pairs where "gradient" is the gradient 307 for "variable". Note that "gradient" can be a `Tensor`, an 308 `IndexedSlices`, or `None` if there is no gradient for the 309 given variable. 310 311 Args: 312 loss: A callable taking no arguments which returns the value to minimize. 313 var_list: List or tuple of `tf.Variable` to update to minimize 314 `loss`. Defaults to the list of variables collected in the graph under 315 the key `GraphKeys.TRAINABLE_VARIABLES`. 316 grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. 317 318 Returns: 319 A list of (gradient, variable) pairs. Variable is always present, but 320 gradient can be `None`. 321 322 Raises: 323 TypeError: If `var_list` contains anything else than `Variable` objects. 324 ValueError: If some arguments are invalid, or var_list is None. 325 """ 326 var_list = nest.flatten(var_list) 327 # TODO(josh11b): Test that we handle weight decay in a reasonable way. 328 with backprop.GradientTape() as tape: 329 tape.watch(var_list) 330 loss_value = loss() 331 grads = tape.gradient(loss_value, var_list, grad_loss) 332 333 if hasattr(self, "clipnorm"): 334 grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads] 335 if hasattr(self, "clipvalue"): 336 grads = [ 337 clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue) 338 for g in grads 339 ] 340 341 grads_and_vars = list(zip(grads, var_list)) 342 self._assert_valid_dtypes([ 343 v for g, v in grads_and_vars 344 if g is not None and v.dtype != dtypes.resource 345 ]) 346 347 return grads_and_vars 348 349 def get_gradients(self, loss, params): 350 """Returns gradients of `loss` with respect to `params`. 351 352 Arguments: 353 loss: Loss tensor. 354 params: List of variables. 355 356 Returns: 357 List of gradient tensors. 358 359 Raises: 360 ValueError: In case any gradient cannot be computed (e.g. if gradient 361 function not implemented). 362 """ 363 with backend.get_graph().as_default(): 364 grads = gradients.gradients(loss, params) 365 if None in grads: 366 raise ValueError("An operation has `None` for gradient. " 367 "Please make sure that all of your ops have a " 368 "gradient defined (i.e. are differentiable). " 369 "Common ops without gradient: " 370 "K.argmax, K.round, K.eval.") 371 if hasattr(self, "clipnorm"): 372 grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads] 373 if hasattr(self, "clipvalue"): 374 grads = [ 375 clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue) 376 for g in grads 377 ] 378 return grads 379 380 def apply_gradients(self, grads_and_vars, name=None): 381 """Apply gradients to variables. 382 383 This is the second part of `minimize()`. It returns an `Operation` that 384 applies gradients. 385 386 Args: 387 grads_and_vars: List of (gradient, variable) pairs. 388 name: Optional name for the returned operation. Default to the name 389 passed to the `Optimizer` constructor. 390 391 Returns: 392 An `Operation` that applies the specified gradients. If `global_step` 393 was not None, that operation also increments `global_step`. 394 395 Raises: 396 TypeError: If `grads_and_vars` is malformed. 397 ValueError: If none of the variables have gradients. 398 """ 399 grads_and_vars = _filter_grads(grads_and_vars) 400 var_list = [v for (_, v) in grads_and_vars] 401 402 # Create iteration if necessary. 403 _ = self.iterations 404 self._create_hypers() 405 with ops.init_scope(): 406 self._create_slots(var_list) 407 408 self._prepare(var_list) 409 410 return distribute_ctx.get_replica_context().merge_call( 411 self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name}) 412 413 def _distributed_apply(self, distribution, grads_and_vars, name): 414 """`apply_gradients` using a `DistributionStrategy`.""" 415 reduced_grads = distribution.extended.batch_reduce_to( 416 ds_reduce_util.ReduceOp.SUM, grads_and_vars) 417 var_list = [v for _, v in grads_and_vars] 418 grads_and_vars = zip(reduced_grads, var_list) 419 420 def apply_grad_to_update_var(var, grad): 421 """Apply gradient to variable.""" 422 if isinstance(var, ops.Tensor): 423 raise NotImplementedError("Trying to update a Tensor ", var) 424 if isinstance(grad, ops.IndexedSlices): 425 if var.constraint is not None: 426 raise RuntimeError( 427 "Cannot use a constraint function on a sparse variable.") 428 return self._resource_apply_sparse_duplicate_indices( 429 grad.values, var, grad.indices) 430 update_op = self._resource_apply_dense(grad, var) 431 if var.constraint is not None: 432 with ops.control_dependencies([update_op]): 433 return var.assign(var.constraint(var)) 434 else: 435 return update_op 436 437 update_ops = [] 438 with ops.name_scope(name, self._name) as name: 439 for grad, var in grads_and_vars: 440 scope_name = ("" if ops.executing_eagerly_outside_functions() else 441 "_" + var.op.name) 442 with ops.name_scope("update" + scope_name): 443 update_ops.extend( 444 distribution.extended.update( 445 var, apply_grad_to_update_var, args=(grad,), group=False)) 446 447 any_symbolic = any(isinstance(i, ops.Operation) or 448 tf_utils.is_symbolic_tensor(i) for i in update_ops) 449 if not context.executing_eagerly() or any_symbolic: 450 # If the current context is graph mode or any of the update ops are 451 # symbolic then the step update should be carried out under a graph 452 # context. (eager updates execute immediately) 453 with ops._get_graph_from_inputs(update_ops).as_default(): # pylint: disable=protected-access 454 with ops.control_dependencies(update_ops): 455 return self._iterations.assign_add(1).op 456 457 return self._iterations.assign_add(1) 458 459 def get_updates(self, loss, params): 460 grads = self.get_gradients(loss, params) 461 grads_and_vars = list(zip(grads, params)) 462 self._assert_valid_dtypes([ 463 v for g, v in grads_and_vars 464 if g is not None and v.dtype != dtypes.resource 465 ]) 466 return [self.apply_gradients(grads_and_vars)] 467 468 def _set_hyper(self, name, value): 469 """set hyper `name` to value. value can be callable, tensor, numeric.""" 470 if isinstance(value, trackable.Trackable): 471 self._track_trackable(value, name, overwrite=True) 472 if name not in self._hyper: 473 self._hyper[name] = value 474 else: 475 prev_value = self._hyper[name] 476 if (callable(prev_value) 477 or isinstance(prev_value, 478 (ops.Tensor, int, float, 479 learning_rate_schedule.LearningRateSchedule)) 480 or isinstance(value, learning_rate_schedule.LearningRateSchedule)): 481 self._hyper[name] = value 482 else: 483 backend.set_value(self._hyper[name], value) 484 485 def _get_hyper(self, name, dtype=None): 486 if not self._hypers_created: 487 self._create_hypers() 488 value = self._hyper[name] 489 if isinstance(value, learning_rate_schedule.LearningRateSchedule): 490 return value 491 if callable(value): 492 value = value() 493 if dtype: 494 return math_ops.cast(value, dtype) 495 else: 496 return value 497 498 def __getattribute__(self, name): 499 """Overridden to support hyperparameter access.""" 500 try: 501 return super(OptimizerV2, self).__getattribute__(name) 502 except AttributeError as e: 503 # Needed to avoid infinite recursion with __setattr__. 504 if name == "_hyper": 505 raise e 506 # Backwards compatibility with Keras optimizers. 507 if name == "lr": 508 name = "learning_rate" 509 if name in self._hyper: 510 return self._get_hyper(name) 511 raise e 512 513 def __setattr__(self, name, value): 514 """Override setattr to support dynamic hyperparameter setting.""" 515 # Backwards compatibility with Keras optimizers. 516 if name == "lr": 517 name = "learning_rate" 518 if hasattr(self, "_hyper") and name in self._hyper: 519 self._set_hyper(name, value) 520 else: 521 super(OptimizerV2, self).__setattr__(name, value) 522 523 def get_slot_names(self): 524 """A list of names for this optimizer's slots.""" 525 return self._slot_names 526 527 def add_slot(self, var, slot_name, initializer="zeros"): 528 """Add a new slot variable for `var`.""" 529 if slot_name not in self._slot_names: 530 self._slot_names.append(slot_name) 531 var_key = _var_key(var) 532 slot_dict = self._slots.setdefault(var_key, {}) 533 weight = slot_dict.get(slot_name, None) 534 if weight is None: 535 if isinstance(initializer, six.string_types) or callable(initializer): 536 initializer = initializers.get(initializer) 537 initial_value = functools.partial( 538 initializer, shape=var.shape, dtype=var.dtype) 539 else: 540 initial_value = initializer 541 strategy = distribute_ctx.get_strategy() 542 with strategy.colocate_vars_with(var): 543 weight = tf_variables.Variable( 544 name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access 545 dtype=var.dtype, 546 trainable=False, 547 initial_value=initial_value) 548 backend.track_variable(weight) 549 slot_dict[slot_name] = weight 550 self._restore_slot_variable( 551 slot_name=slot_name, variable=var, 552 slot_variable=weight) 553 self._weights.append(weight) 554 return weight 555 556 def get_slot(self, var, slot_name): 557 var_key = _var_key(var) 558 slot_dict = self._slots[var_key] 559 return slot_dict[slot_name] 560 561 def _prepare(self, var_list): 562 pass 563 564 def _create_hypers(self): 565 if self._hypers_created: 566 return 567 # Iterate hyper values deterministically. 568 for name, value in sorted(self._hyper.items()): 569 if isinstance(value, ops.Tensor) or callable(value): 570 continue 571 else: 572 self._hyper[name] = self.add_weight( 573 name, 574 shape=[], 575 trainable=False, 576 initializer=value, 577 aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA) 578 self._hypers_created = True 579 580 @property 581 def iterations(self): 582 """Variable. The number of training steps this Optimizer has run.""" 583 if self._iterations is None: 584 self._iterations = self.add_weight( 585 "iter", 586 shape=[], 587 dtype=dtypes.int64, 588 trainable=False, 589 aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA) 590 self._weights.append(self._iterations) 591 return self._iterations 592 593 @iterations.setter 594 def iterations(self, variable): 595 if self._iterations is not None: 596 raise RuntimeError("Cannot set `iterations` to a new Variable after" 597 "the Optimizer weights have been created") 598 self._iterations = variable 599 self._weights.append(self._iterations) 600 601 def _decayed_lr(self, var_dtype): 602 """Get decayed learning rate as a Tensor with dtype=var_dtype.""" 603 lr_t = self._get_hyper("learning_rate", var_dtype) 604 if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule): 605 local_step = math_ops.cast(self.iterations, var_dtype) 606 lr_t = math_ops.cast(lr_t(local_step), var_dtype) 607 if self._initial_decay > 0.: 608 local_step = math_ops.cast(self.iterations, var_dtype) 609 decay_t = self._get_hyper("decay", var_dtype) 610 lr_t = lr_t / (1. + decay_t * local_step) 611 return lr_t 612 613 @abc.abstractmethod 614 def get_config(self): 615 """Returns the config of the optimimizer. 616 617 An optimizer config is a Python dictionary (serializable) 618 containing the configuration of an optimizer. 619 The same optimizer can be reinstantiated later 620 (without any saved state) from this configuration. 621 622 Returns: 623 Python dictionary. 624 """ 625 config = {"name": self._name} 626 if hasattr(self, "clipnorm"): 627 config["clipnorm"] = self.clipnorm 628 if hasattr(self, "clipvalue"): 629 config["clipvalue"] = self.clipvalue 630 return config 631 632 @classmethod 633 def from_config(cls, config, custom_objects=None): 634 """Creates an optimizer from its config. 635 636 This method is the reverse of `get_config`, 637 capable of instantiating the same optimizer from the config 638 dictionary. 639 640 Arguments: 641 config: A Python dictionary, typically the output of get_config. 642 custom_objects: A Python dictionary mapping names to additional Python 643 objects used to create this optimizer, such as a function used for a 644 hyperparameter. 645 646 Returns: 647 An optimizer instance. 648 """ 649 if "lr" in config: 650 config["learning_rate"] = config.pop("lr") 651 if "learning_rate" in config: 652 if isinstance(config["learning_rate"], dict): 653 config["learning_rate"] = learning_rate_schedule.deserialize( 654 config["learning_rate"]) 655 return cls(**config) 656 657 def _serialize_hyperparameter(self, hyperparameter_name): 658 """Serialize a hyperparameter that can be a float, callable, or Tensor.""" 659 value = self._hyper[hyperparameter_name] 660 if isinstance(value, learning_rate_schedule.LearningRateSchedule): 661 return learning_rate_schedule.serialize(value) 662 if callable(value): 663 return value() 664 if isinstance(value, (ops.Tensor, tf_variables.Variable, 665 distributed_values.TPUMirroredVariable, 666 distributed_values.DistributedVariable)): 667 return backend.get_value(value) 668 return value 669 670 def variables(self): 671 """Returns variables of this Optimizer based on the order created.""" 672 return self._weights 673 674 @property 675 def weights(self): 676 """Returns variables of this Optimizer based on the order created.""" 677 return self._weights 678 679 def get_weights(self): 680 params = self.weights 681 return backend.batch_get_value(params) 682 683 # TODO(tanzheny): Maybe share this logic with base_layer. 684 def set_weights(self, weights): 685 params = self.weights 686 if len(params) != len(weights): 687 raise ValueError( 688 "You called `set_weights(weights)` on optimizer " + self._name + 689 " with a weight list of length " + str(len(weights)) + 690 ", but the optimizer was expecting " + str(len(params)) + 691 " weights. Provided weights: " + str(weights)[:50] + "...") 692 if not params: 693 return 694 weight_value_tuples = [] 695 param_values = backend.batch_get_value(params) 696 for pv, p, w in zip(param_values, params, weights): 697 if pv.shape != w.shape: 698 raise ValueError("Optimizer weight shape " + str(pv.shape) + 699 " not compatible with " 700 "provided weight shape " + str(w.shape)) 701 weight_value_tuples.append((p, w)) 702 backend.batch_set_value(weight_value_tuples) 703 704 def add_weight(self, 705 name, 706 shape, 707 dtype=None, 708 initializer="zeros", 709 trainable=None, 710 synchronization=tf_variables.VariableSynchronization.AUTO, 711 aggregation=tf_variables.VariableAggregation.NONE): 712 713 if dtype is None: 714 dtype = dtypes.float32 715 if isinstance(initializer, six.string_types) or callable(initializer): 716 initializer = initializers.get(initializer) 717 718 if synchronization == tf_variables.VariableSynchronization.ON_READ: 719 if trainable: 720 raise ValueError( 721 "Synchronization value can be set to " 722 "VariableSynchronization.ON_READ only for non-trainable variables. " 723 "You have specified trainable=True and " 724 "synchronization=VariableSynchronization.ON_READ.") 725 else: 726 # Set trainable to be false when variable is to be synced on read. 727 trainable = False 728 elif trainable is None: 729 trainable = True 730 731 variable = self._add_variable_with_custom_getter( 732 name=name, 733 shape=shape, 734 getter=base_layer_utils.make_variable, 735 overwrite=True, 736 initializer=initializer, 737 dtype=dtype, 738 trainable=trainable, 739 use_resource=True, 740 synchronization=synchronization, 741 aggregation=aggregation) 742 backend.track_variable(variable) 743 744 return variable 745 746 def _assert_valid_dtypes(self, tensors): 747 """Asserts tensors are all valid types (see `_valid_dtypes`). 748 749 Args: 750 tensors: Tensors to check. 751 752 Raises: 753 ValueError: If any tensor is not a valid type. 754 """ 755 valid_dtypes = self._valid_dtypes() 756 for t in tensors: 757 dtype = t.dtype.base_dtype 758 if dtype not in valid_dtypes: 759 raise ValueError("Invalid type %r for %s, expected: %s." % 760 (dtype, t.name, [v for v in valid_dtypes])) 761 762 def _valid_dtypes(self): 763 """Valid types for loss, variables and gradients. 764 765 Subclasses should override to allow other float types. 766 767 Returns: 768 Valid types for loss, variables and gradients. 769 """ 770 return set( 771 [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]) 772 773 def _call_if_callable(self, param): 774 """Call the function if param is callable.""" 775 return param() if callable(param) else param 776 777 def _resource_apply_dense(self, grad, handle): 778 """Add ops to apply dense gradients to the variable `handle`. 779 780 Args: 781 grad: a `Tensor` representing the gradient. 782 handle: a `Tensor` of dtype `resource` which points to the variable to be 783 updated. 784 785 Returns: 786 An `Operation` which updates the value of the variable. 787 """ 788 raise NotImplementedError() 789 790 def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): 791 """Add ops to apply sparse gradients to `handle`, with repeated indices. 792 793 Optimizers which override this method must deal with repeated indices. See 794 the docstring of `_apply_sparse_duplicate_indices` for details. By default 795 the correct behavior, to sum non-unique indices and their associated 796 gradients, is enforced by first pre-processing `grad` and `indices` and 797 passing them on to `_resource_apply_sparse`. Optimizers which deal correctly 798 with duplicate indices may instead override this method to avoid the 799 overhead of summing. 800 801 Args: 802 grad: a `Tensor` representing the gradient for the affected indices. 803 handle: a `Tensor` of dtype `resource` which points to the variable to be 804 updated. 805 indices: a `Tensor` of integral type representing the indices for which 806 the gradient is nonzero. Indices may be repeated. 807 808 Returns: 809 An `Operation` which updates the value of the variable. 810 """ 811 summed_grad, unique_indices = _deduplicate_indexed_slices( 812 values=grad, indices=indices) 813 return self._resource_apply_sparse(summed_grad, handle, unique_indices) 814 815 def _resource_apply_sparse(self, grad, handle, indices): 816 """Add ops to apply sparse gradients to the variable `handle`. 817 818 Similar to `_apply_sparse`, the `indices` argument to this method has been 819 de-duplicated. Optimizers which deal correctly with non-unique indices may 820 instead override `_resource_apply_sparse_duplicate_indices` to avoid this 821 overhead. 822 823 Args: 824 grad: a `Tensor` representing the gradient for the affected indices. 825 handle: a `Tensor` of dtype `resource` which points to the variable to be 826 updated. 827 indices: a `Tensor` of integral type representing the indices for which 828 the gradient is nonzero. Indices are unique. 829 830 Returns: 831 An `Operation` which updates the value of the variable. 832 """ 833 raise NotImplementedError() 834 835 def _resource_scatter_add(self, x, i, v): 836 with ops.control_dependencies( 837 [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): 838 return x.value() 839 840 def _resource_scatter_update(self, x, i, v): 841 with ops.control_dependencies( 842 [resource_variable_ops.resource_scatter_update(x.handle, i, v)]): 843 return x.value() 844 845 # --------------- 846 # For implementing the trackable interface 847 # --------------- 848 849 def _restore_slot_variable(self, slot_name, variable, slot_variable): 850 """Restore a newly created slot variable's value.""" 851 variable_key = _var_key(variable) 852 deferred_restorations = self._deferred_slot_restorations.get( 853 slot_name, {}).pop(variable_key, []) 854 # Iterate over restores, highest restore UID first to minimize the number 855 # of assignments. 856 deferred_restorations.sort(key=lambda position: position.restore_uid, 857 reverse=True) 858 for checkpoint_position in deferred_restorations: 859 checkpoint_position.restore(slot_variable) 860 861 def _create_or_restore_slot_variable( 862 self, slot_variable_position, slot_name, variable): 863 """Restore a slot variable's value, possibly creating it. 864 865 Called when a variable which has an associated slot variable is created or 866 restored. When executing eagerly, we create the slot variable with a 867 restoring initializer. 868 869 No new variables are created when graph building. Instead, 870 _restore_slot_variable catches these after normal creation and adds restore 871 ops to the graph. This method is nonetheless important when graph building 872 for the case when a slot variable has already been created but `variable` 873 has just been added to a dependency graph (causing us to realize that the 874 slot variable needs to be restored). 875 876 Args: 877 slot_variable_position: A `trackable._CheckpointPosition` object 878 indicating the slot variable `Trackable` object to be restored. 879 slot_name: The name of this `Optimizer`'s slot to restore into. 880 variable: The variable object this slot is being created for. 881 """ 882 variable_key = _var_key(variable) 883 slot_dict = self._slots.get(variable_key, {}) 884 slot_variable = slot_dict.get(slot_name, None) 885 if (slot_variable is None and context.executing_eagerly() and 886 slot_variable_position.is_simple_variable() 887 # Defer slot variable creation if there is an active variable creator 888 # scope. Generally we'd like to eagerly create/restore slot variables 889 # when possible, but this may mean that scopes intended to catch 890 # `variable` also catch its eagerly created slot variable 891 # unintentionally (specifically make_template would add a dependency on 892 # a slot variable if not for this case). Deferring is mostly harmless 893 # (aside from double initialization), and makes variable creator scopes 894 # behave the same way they do when graph building. 895 and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access 896 initializer = trackable.CheckpointInitialValue( 897 checkpoint_position=slot_variable_position) 898 slot_variable = self.add_slot( 899 var=variable, 900 initializer=initializer, 901 slot_name=slot_name) 902 # Slot variables are not owned by any one object (because we don't want to 903 # save the slot variable if the optimizer is saved without the non-slot 904 # variable, or if the non-slot variable is saved without the optimizer; 905 # it's a dependency hypergraph with edges of the form (optimizer, non-slot 906 # variable, variable)). So we don't _track_ slot variables anywhere, and 907 # instead special-case this dependency and otherwise pretend it's a normal 908 # graph. 909 if slot_variable is not None: 910 # If we've either made this slot variable, or if we've pulled out an 911 # existing slot variable, we should restore it. 912 slot_variable_position.restore(slot_variable) 913 else: 914 # We didn't make the slot variable. Defer restoring until it gets created 915 # normally. We keep a list rather than the one with the highest restore 916 # UID in case slot variables have their own dependencies, in which case 917 # those could differ between restores. 918 self._deferred_slot_restorations.setdefault( 919 slot_name, {}).setdefault(variable_key, []).append( 920 slot_variable_position) 921 922 923def _filter_grads(grads_and_vars): 924 """Filter out iterable with grad equal to None.""" 925 grads_and_vars = tuple(grads_and_vars) 926 if not grads_and_vars: 927 return grads_and_vars 928 filtered = [] 929 vars_with_empty_grads = [] 930 for grad, var in grads_and_vars: 931 if grad is None: 932 vars_with_empty_grads.append(var) 933 else: 934 filtered.append((grad, var)) 935 filtered = tuple(filtered) 936 if not filtered: 937 raise ValueError("No gradients provided for any variable: %s." % 938 ([v.name for _, v in grads_and_vars],)) 939 if vars_with_empty_grads: 940 logging.warning( 941 ("Gradients does not exist for variables %s when minimizing the loss."), 942 ([v.name for v in vars_with_empty_grads])) 943 return filtered 944 945 946def _var_key(var): 947 """Key for representing a primary variable, for looking up slots. 948 949 In graph mode the name is derived from the var shared name. 950 In eager mode the name is derived from the var unique id. 951 If distribution strategy exists, get the primary variable first. 952 953 Args: 954 var: the variable. 955 956 Returns: 957 the unique name of the variable. 958 """ 959 960 # pylint: disable=protected-access 961 # Get the distributed variable if it exists. 962 if getattr(var, "_distributed_container", None) is not None: 963 var = var._distributed_container() 964 if var._in_graph_mode: 965 return var._shared_name 966 return var._unique_id 967 968 969def _get_slot_key_from_var(var, slot_name): 970 """Get the slot key for the variable: var_name/slot_name.""" 971 972 name = _var_key(var) 973 return name + "/" + slot_name 974 975 976class _RestoredOptimizer(OptimizerV2): 977 """A non-functional Optimizer implementation for checkpoint compatibility. 978 979 Holds slot variables and hyperparameters when an optimizer is restored from a 980 SavedModel. These variables may be referenced in functions along with ops 981 created by the original optimizer, but currently we do not support using the 982 optimizer object iself (e.g. through `apply_gradients`). 983 """ 984 # TODO(allenl): Make the restored optimizer functional by tracing its apply 985 # methods. 986 987 def __init__(self): 988 super(_RestoredOptimizer, self).__init__("_RestoredOptimizer") 989 self._hypers_created = True 990 991 def get_config(self): 992 # TODO(allenl): Save and restore the Optimizer's config 993 raise NotImplementedError( 994 "Restoring functional Optimzers from SavedModels is not currently " 995 "supported. Please file a feature request if this limitation bothers " 996 "you.") 997 998revived_types.register_revived_type( 999 "optimizer", 1000 lambda obj: isinstance(obj, OptimizerV2), 1001 versions=[revived_types.VersionedTypeRegistration( 1002 object_factory=lambda proto: _RestoredOptimizer(), 1003 version=1, 1004 min_producer_version=1, 1005 min_consumer_version=1, 1006 setter=_RestoredOptimizer._set_hyper # pylint: disable=protected-access 1007 )]) 1008