# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalization preprocessing layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.experimental.preprocessing.Normalization', v1=[])
class Normalization(base_preprocessing_layer.PreprocessingLayer):
  """Feature-wise normalization of the data.

  This layer will coerce its inputs into a distribution centered around 0
  with standard deviation 1. It accomplishes this by precomputing the mean
  and variance of the data, and calling `(input - mean) / sqrt(var)` at
  runtime.

  What happens in `adapt`: Compute the mean and variance of the data and
  store them as the layer's weights. `adapt` should be called before `fit`,
  `evaluate`, or `predict`.

  Args:
    axis: Integer or tuple of integers, the axis or axes that should be
      "kept". These axes are not summed over when calculating the
      normalization statistics. By default the last axis, the `features`
      axis, is kept and any `space` or `time` axes are summed. Each element
      in the kept axes is normalized independently. If `axis` is set to
      `None`, the layer will perform scalar normalization (dividing the
      input by a single scalar value). The `batch` axis, 0, is always summed
      over (`axis=0` is not allowed).
    mean: The mean value(s) to use during normalization. The passed value(s)
      will be broadcast to the shape of the kept axes above; if the value(s)
      cannot be broadcast, an error will be raised when this layer's
      `build()` method is called.
    variance: The variance value(s) to use during normalization. The passed
      value(s) will be broadcast to the shape of the kept axes above; if the
      value(s) cannot be broadcast, an error will be raised when this
      layer's `build()` method is called.
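
  Raises:
    ValueError: If `axis` contains 0, if a `Variable` is passed for `mean`
      or `variance`, or if only one of `mean` and `variance` is set.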

  Examples:

  Calculate the mean and variance by analyzing the dataset in `adapt`.

  >>> adapt_data = np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32)
  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
  >>> layer = Normalization()
  >>> layer.adapt(adapt_data)
  >>> layer(input_data)
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[-1.4142135 ],
         [-0.70710677],
         [ 0.        ]], dtype=float32)>

  Pass the mean and variance directly.

  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
  >>> layer = Normalization(mean=3., variance=2.)
  >>> layer(input_data)
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[-1.4142135 ],
         [-0.70710677],
         [ 0.        ]], dtype=float32)>
  """

  def __init__(self, axis=-1, mean=None, variance=None, **kwargs):
    super(Normalization, self).__init__(stateful=True, streaming=True,
                                        **kwargs)
    base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)

    # Standardize `axis` to a tuple.
    if axis is None:
      axis = ()
    elif isinstance(axis, int):
      axis = (axis,)
    else:
      axis = tuple(axis)
    if 0 in axis:
      raise ValueError('The argument \'axis\' may not be 0.')
    self.axis = axis

    # Set `mean` and `variance` if passed.
    if isinstance(mean, variables.Variable):
      raise ValueError('Normalization does not support passing a Variable '
                       'for the `mean` init arg.')
    if isinstance(variance, variables.Variable):
      raise ValueError('Normalization does not support passing a Variable '
                       'for the `variance` init arg.')
    if mean is not None and variance is not None:
      mean = convert_to_ndarray(mean)
      variance = convert_to_ndarray(variance)
    elif mean is not None or variance is not None:
      raise ValueError(
          'When setting values directly, both `mean` and `variance` '
          'must be set. Got mean: {} and variance: {}'.format(mean, variance))
    self.mean_val = mean
    self.variance_val = variance

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if len(input_shape) == 1:
      input_shape = input_shape + [1]
    ndim = len(input_shape)

    if any(a < 1 - ndim or a >= ndim for a in self.axis):
      raise ValueError('All `axis` values must be in the range '
                       '[1 - ndim, ndim - 1]. Found '
                       'ndim: `{}`, axis: {}'.format(ndim, self.axis))

    # Axes to be kept, replacing negative values with positive equivalents.
    # Sorted to avoid transposing axes.
    self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])
    # Axes to be reduced.
    self._reduce_axis = [d for d in range(ndim) if d not in self._keep_axis]
    # 1 if an axis should be reduced, 0 otherwise.
    self._reduce_axis_mask = [
        0 if d in self._keep_axis else 1 for d in range(ndim)
    ]
    # Broadcast any reduced axes.
    self._broadcast_shape = [
        input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)
    ]
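    # For example, with `input_shape=[None, 5]` and the default `axis=-1`,
    # `self._keep_axis` is `[1]`, `self._reduce_axis` is `[0]`, and
    # `self._broadcast_shape` is `[1, 5]`; the mean and variance weights
    # created below then have shape `(5,)`.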
    # Create variables without keeping reduced axes.
    mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)

    self.mean = self.add_weight(
        name='mean',
        shape=mean_and_var_shape,
        dtype=self.dtype,
        initializer=init_ops.zeros_initializer,
        trainable=False)
    self.variance = self.add_weight(
        name='variance',
        shape=mean_and_var_shape,
        dtype=self.dtype,
        initializer=init_ops.ones_initializer,
        trainable=False)
    self.count = self.add_weight(
        name='count',
        shape=(),
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer,
        trainable=False)

    super(Normalization, self).build(input_shape)

    if self.mean_val is not None and self.variance_val is not None:
      mean_val = self.mean_val * np.ones(mean_and_var_shape)
      variance_val = self.variance_val * np.ones(mean_and_var_shape)
      self.mean.assign(mean_val)
      self.variance.assign(variance_val)

    self.built = True

  def update_state(self, data):
    if not self.built:
      raise RuntimeError('`build` must be called before `update_state`.')

    data = self._standardize_inputs(data)
    batch_mean, batch_variance = nn_impl.moments_v2(
        data, axes=self._reduce_axis)
    # The batch count is the product of the batch's reduced dimensions.
    batch_shape = array_ops.shape(data, out_type=self.count.dtype)
    batch_reduce_shape = array_ops.gather(batch_shape, self._reduce_axis)
    batch_count = math_ops.reduce_prod(batch_reduce_shape)

    # Weight the existing and batch statistics by their share of the total
    # element count.
    total_count = batch_count + self.count
    batch_weight = (
        math_ops.cast(batch_count, dtype=self.dtype) /
        math_ops.cast(total_count, dtype=self.dtype))
    existing_weight = 1. - batch_weight

    total_mean = self.mean * existing_weight + batch_mean * batch_weight
    # The variance is computed using the lack-of-fit sum of squares
    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
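    # This combination is exact for population variances: the pooled
    # variance is the count-weighted mean of each source's variance plus
    # the squared deviation of its mean from the pooled mean.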
    total_variance = ((self.variance +
                       (self.mean - total_mean)**2) * existing_weight +
                      (batch_variance +
                       (batch_mean - total_mean)**2) * batch_weight)
    self.mean.assign(total_mean)
    self.variance.assign(total_variance)
    self.count.assign(total_count)

  def merge_state(self, layers):
    layers = layers + [self]
    if any(not l.built for l in layers):
      raise ValueError(
          'All layers to be merged must have been adapted to some inputs '
          'first (otherwise they have no state).')

    layer_counts = [l.count for l in layers]
    layer_means = [l.mean for l in layers]
    layer_variances = [l.variance for l in layers]

    # Weight each layer's statistics by its share of the total count, then
    # combine them with the same pooled mean/variance formulas used in
    # `update_state`.
    total_count = math_ops.reduce_sum(layer_counts)
    layer_weightings = (
        math_ops.cast(layer_counts, self.dtype) /
        math_ops.cast(total_count, self.dtype))
    layer_weightings = array_ops.reshape(
        layer_weightings, shape=[len(layers)] + [1] * self.mean.shape.rank)

    total_mean = math_ops.reduce_sum(layer_means * layer_weightings, axis=0)
    inter_layer_variances = (layer_means - total_mean)**2
    total_variance = math_ops.reduce_sum(
        ((layer_variances + inter_layer_variances) * layer_weightings), axis=0)

    self.mean.assign(total_mean)
    self.variance.assign(total_variance)
    self.count.assign(total_count)

  def reset_state(self):  # pylint: disable=method-hidden
    if self.built:
      self.mean.assign(array_ops.zeros_like(self.mean))
      self.variance.assign(array_ops.ones_like(self.variance))
      self.count.assign(array_ops.zeros_like(self.count))

  def call(self, inputs):
    inputs = self._standardize_inputs(inputs)
    # We need to reshape the mean and variance data to ensure that TensorFlow
    # broadcasts the data correctly.
    mean = array_ops.reshape(self.mean, self._broadcast_shape)
    variance = array_ops.reshape(self.variance, self._broadcast_shape)
    return ((inputs - mean) /
            math_ops.maximum(math_ops.sqrt(variance), K.epsilon()))

  def compute_output_shape(self, input_shape):
    return input_shape

  def compute_output_signature(self, input_spec):
    return input_spec

  def get_config(self):
    config = super(Normalization, self).get_config()
    config.update({'axis': self.axis})
    return config

  def set_weights(self, weights):
    """Override for set_weights to ensure we can set just mean/var weights."""
    if len(weights) == 2:
      # Add a zero count so the full (mean, variance, count) triplet is set.
      weights.append(np.array(0))
    super(Normalization, self).set_weights(weights)

  def _standardize_inputs(self, inputs):
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
    # Scalars and vectors are reshaped to rank-2 (batch, features) inputs.
    if inputs.shape.rank == 0:
      inputs = array_ops.reshape(inputs, [1, 1])
    elif inputs.shape.rank == 1:
      inputs = array_ops.expand_dims(inputs, 1)

    if inputs.dtype != self.dtype:
      inputs = math_ops.cast(inputs, self.dtype)
    return inputs


def convert_to_ndarray(values):
  if isinstance(values, np.ndarray):
    return values
  elif isinstance(values, ops.Tensor):
    return K.get_value(values)
  else:
    return np.array(values)
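

# The following is an illustrative usage sketch, not part of the library: it
# adapts a layer to data and also builds one from precomputed statistics via
# the public `tf.keras` export. It assumes a TensorFlow 2.x installation and
# only runs when this module is executed directly.
if __name__ == '__main__':
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  preprocessing = tf.keras.layers.experimental.preprocessing

  # Learn mean/variance from data with `adapt`, then normalize new inputs.
  adapted = preprocessing.Normalization()
  adapted.adapt(np.array([[1.], [2.], [3.], [4.], [5.]], np.float32))
  print(adapted(np.array([[1.], [2.], [3.]], np.float32)))

  # Or supply the statistics directly (here: mean=3, variance=2).
  fixed = preprocessing.Normalization(mean=3., variance=2.)
  print(fixed(np.array([[1.], [2.], [3.]], np.float32)))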