# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras training and evaluation routines for eager execution."""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _eager_loss_fn(outputs, targets, loss_fn, output_name):
  with backend.name_scope(output_name + '_loss'):
    loss = loss_fn(targets, outputs)
  return loss


def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None):
  """Calculates the metrics for each output of the given model.

  Arguments:
      model: The model on which metrics are being calculated.
      outputs: The outputs of the given model.
      targets: The targets of the given model.
      sample_weights: Optional list of sample weights for each output.
      masks: Optional list of masks for each output.

  Returns:
      Returns the metric results for each output of the model.
  """
  outputs = nest.flatten(outputs)
  targets = nest.flatten(targets)
  # TODO(psv): Consider supporting skip target indices in eager mode?
  metric_results = model._handle_metrics(
      outputs, targets=targets, sample_weights=sample_weights, masks=masks)
  return [backend.mean(t) for t in metric_results]
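

# `_eager_metrics_fn` flattens possibly nested output/target structures and
# reduces each metric result to a scalar. The function below is a minimal,
# hypothetical sketch of that flatten-then-average pattern using the public
# TF API; it is illustrative only and not part of this module's API.
def _example_flatten_and_mean():
  import tensorflow as tf  # Public API, used only in this sketch.
  # Hypothetical results for a two-output model, keyed by output name.
  outputs = {'a': tf.constant([[1.0], [2.0]]), 'b': tf.constant([[3.0]])}
  flat = tf.nest.flatten(outputs)  # Flattens the dict in sorted-key order.
  # One scalar per output, mirroring `backend.mean(t)` above.
  return [tf.reduce_mean(t) for t in flat]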


def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
  """Calculates the loss for a given model.

  Arguments:
      model: The model on which the loss is being calculated.
      inputs: Either a dictionary of inputs to the model or a list of input
        arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the model should be run in training mode (True) or
        inference mode (False).

  Returns:
      Returns the model output, total loss, loss value calculated using the
      specified loss function and masks for each output. The total loss
      includes regularization losses and applies masking and sample weighting
      to the loss value.
  """
  # Used to keep track of the total loss value (stateless).
  # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
  #                    loss_weight_2 * output_2_loss_fn(...) +
  #                    layer losses.
  total_loss = 0
  kwargs = {}
  if model._expects_training_arg:
    kwargs['training'] = training
  if len(inputs) == 1 and not isinstance(inputs, dict):
    inputs = inputs[0]

  # Allow mixed `NumPy` and `EagerTensor` input here.
  if any(
      isinstance(input_t, (np.ndarray, float, int))
      for input_t in nest.flatten(inputs)):
    inputs = nest.map_structure(ops.convert_to_tensor, inputs)

  outs = model(inputs, **kwargs)

  outs = nest.flatten(outs)
  # `None` by default for `EagerTensors`.
  masks = [t._keras_mask for t in outs]
  targets = nest.flatten(targets)

  # Used to keep track of individual output losses.
  output_losses = []

  with backend.name_scope('loss'):
    for i, loss_fn in enumerate(model.loss_functions):
      weights = sample_weights[i] if sample_weights else None
      mask = masks[i]
      with backend.name_scope(model.output_names[i] + '_loss'):
        if mask is not None:
          mask = math_ops.cast(mask, outs[i].dtype)
          # Update weights with mask.
          if weights is None:
            weights = mask
          else:
            # Update dimensions of weights to match with mask if possible.
            mask, _, weights = (
                losses_utils.squeeze_or_expand_dimensions(mask, None, weights))
            weights *= mask

        # Reset reduction on the loss so that we can get the per-sample loss
        # value. We use this to get both the stateless and stateful loss
        # values without having to compute the underlying loss function
        # twice.
        weighted_losses = None
        if hasattr(loss_fn, 'reduction'):
          current_loss_reduction = loss_fn.reduction
          loss_fn.reduction = losses_utils.ReductionV2.NONE
          weighted_losses = loss_fn(targets[i], outs[i], sample_weight=weights)
          loss_fn.reduction = current_loss_reduction

          # Compute the stateless loss value.
          output_loss = losses_utils.reduce_weighted_loss(weighted_losses)
        else:
          # Compute the stateless loss value for a custom loss class.
          # Here we assume that the class takes care of loss reduction,
          # because if this class returns a vector value we cannot
          # differentiate between the case where a custom optimizer expects
          # a vector loss value and an unreduced per-sample loss value.
          output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)

      # If the number of outputs is 1 then we don't append the loss metric
      # associated with each model output. When there are multiple outputs
      # associated with a model, each output's loss is calculated and returned
      # as part of the loss_metrics.
      if len(model.outputs) > 1:
        # Compute the stateful loss value.
        if weighted_losses is not None:
          aggregated_output_loss = output_loss_metrics[i](weighted_losses)
        else:
          # Custom loss class.
          aggregated_output_loss = training_utils.call_metric_function(
              output_loss_metrics[i], targets[i], outs[i], weights=weights)
        # Keep track of the stateful output loss result.
        output_losses.append(aggregated_output_loss)

      total_loss += model.loss_weights_list[i] * output_loss

  total_loss = backend.mean(total_loss)
  # Add regularization losses.
  custom_losses = model.losses
  if custom_losses:
    total_loss += losses_utils.scale_loss_for_distribution(
        math_ops.add_n(custom_losses))

  return outs, total_loss, output_losses, masks
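

# A minimal sketch of the reduction-override trick used in `_model_loss`:
# temporarily setting a loss's `reduction` to NONE yields per-sample weighted
# losses, so one forward computation can feed both the stateless mean loss and
# a stateful loss metric. This sketch assumes the TF 2.x public API and is
# illustrative only, not part of this module's API.
def _example_reduction_override():
  import tensorflow as tf  # Public API, used only in this sketch.
  loss_fn = tf.keras.losses.MeanSquaredError()
  y_true = tf.constant([[0.0], [1.0]])
  y_pred = tf.constant([[0.5], [0.5]])
  saved_reduction = loss_fn.reduction
  loss_fn.reduction = tf.keras.losses.Reduction.NONE
  per_sample = loss_fn(y_true, y_pred)  # Shape [2]: one value per sample.
  loss_fn.reduction = saved_reduction  # Restore, as `_model_loss` does.
  return tf.reduce_mean(per_sample)  # The stateless scalar loss value.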


def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
  """Calculate the loss and gradient for one input batch.

  The model weights are updated if training is set to True.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: List of input arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the weights of the model should be updated.
        'fit' methods will set this to True while 'evaluate' methods will
        set this to False.

  Returns:
      output of the model, total loss, the loss and the mask
      associated with each output.

  Raises:
      ValueError: If the model has no loss to optimize.
  """
  with backend.eager_learning_phase_scope(1 if training else 0):
    with GradientTape() as tape:
      outs, total_loss, output_losses, masks = (
          _model_loss(
              model,
              inputs,
              targets,
              output_loss_metrics=output_loss_metrics,
              sample_weights=sample_weights,
              training=training))
      if total_loss is None:
        raise ValueError('The model cannot be run '
                         'because it has no loss to optimize.')
    if training:
      if not model.trainable_weights:
        logging.warning('The list of trainable weights is empty. Make sure that'
                        ' you are not setting model.trainable to False before '
                        'compiling the model.')
      else:
        grads = tape.gradient(total_loss, model.trainable_weights)
        model.optimizer.apply_gradients(zip(grads,
                                            model.trainable_weights))
    return outs, total_loss, output_losses, masks


def train_on_batch(model,
                   inputs,
                   targets,
                   sample_weights=None,
                   output_loss_metrics=None):
  """Calculates the loss and gradient updates for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      total loss and the loss associated with each output.
  """
  if isinstance(inputs, collections.Sequence):
    if len(inputs) and tensor_util.is_tensor(inputs[0]):
      inputs = training_utils.cast_if_floating_dtype(inputs)
      targets = training_utils.cast_if_floating_dtype(targets)
    else:
      inputs = training_utils.cast_if_floating_dtype(
          [ops.convert_to_tensor(val) for val in inputs])
      targets = training_utils.cast_if_floating_dtype(
          [ops.convert_to_tensor(val) for val in targets])
  if sample_weights:
    sample_weights = [
        training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
        if val is not None else None for val in sample_weights
    ]

  outs, total_loss, output_losses, masks = (
      _process_single_batch(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=True,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  results = total_loss + output_losses + metrics_results

  return [tensor_util.constant_value(v) for v in results]
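

# A minimal sketch of the tape-then-apply pattern `_process_single_batch` uses
# when `training=True`: record the forward pass on a GradientTape, then apply
# gradients of the total loss. The model, optimizer, loss and data arguments
# below are hypothetical placeholders; this is illustrative only.
def _example_single_train_step(model, optimizer, loss_fn, x, y):
  import tensorflow as tf  # Public API, used only in this sketch.
  with tf.GradientTape() as tape:
    preds = model(x, training=True)  # Forward pass recorded on the tape.
    loss = loss_fn(y, preds)
  grads = tape.gradient(loss, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  return loss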


def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
  """Calculates the loss for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      total loss, loss and metrics associated with each output.
  """
  if isinstance(inputs, collections.Sequence):
    if len(inputs) and tensor_util.is_tensor(inputs[0]):
      inputs = training_utils.cast_if_floating_dtype(inputs)
      targets = training_utils.cast_if_floating_dtype(targets)
    else:
      inputs = training_utils.cast_if_floating_dtype(
          [ops.convert_to_tensor(val) for val in inputs])
      targets = training_utils.cast_if_floating_dtype(
          [ops.convert_to_tensor(val) for val in targets])
  if sample_weights:
    sample_weights = [
        training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
        if val is not None else None for val in sample_weights
    ]
  outs, total_loss, output_losses, masks = (
      _model_loss(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=False,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  results = total_loss + output_losses + metrics_results

  return [tensor_util.constant_value(v) for v in results]
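

# A minimal, hypothetical usage sketch: with eager execution enabled, a
# compiled model's public `train_on_batch`/`test_on_batch` methods dispatch to
# the `train_on_batch`/`test_on_batch` routines defined above. The toy model
# and random data here are illustrative assumptions, not part of this module.
def _example_usage():
  import numpy as np
  import tensorflow as tf  # Public API, used only in this sketch.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer='sgd', loss='mse', metrics=['mae'])
  x = np.random.rand(8, 4).astype('float32')
  y = np.random.rand(8, 1).astype('float32')
  train_results = model.train_on_batch(x, y)  # [loss, mae] for this batch.
  test_results = model.test_on_batch(x, y)    # [loss, mae] for this batch.
  return train_results, test_results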