1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Normalization preprocessing layer."""
16# pylint: disable=g-classes-have-attributes
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.keras import backend as K
27from tensorflow.python.keras.engine import base_preprocessing_layer
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import init_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import nn_impl
32from tensorflow.python.ops import variables
33from tensorflow.python.util.tf_export import keras_export
34
35
36@keras_export('keras.layers.experimental.preprocessing.Normalization', v1=[])
37class Normalization(base_preprocessing_layer.PreprocessingLayer):
38  """Feature-wise normalization of the data.
39
40  This layer will coerce its inputs into a distribution centered around
41  0 with standard deviation 1. It accomplishes this by precomputing the mean and
42  variance of the data, and calling (input-mean)/sqrt(var) at runtime.
43
44  What happens in `adapt`: Compute mean and variance of the data and store them
45    as the layer's weights. `adapt` should be called before `fit`, `evaluate`,
46    or `predict`.
47
48  Args:
49      axis: Integer or tuple of integers, the axis or axes that should be
50        "kept". These axes are not be summed over when calculating the
51        normalization statistics. By default the last axis, the `features` axis
52        is kept and any `space` or `time` axes are summed. Each element in the
53        the axes that are kept is normalized independently. If `axis` is set to
54        'None', the layer will perform scalar normalization (dividing the input
55        by a single scalar value). The `batch` axis, 0, is always summed over
56        (`axis=0` is not allowed).
57      mean: The mean value(s) to use during normalization. The passed value(s)
58        will be broadcast to the shape of the kept axes above; if the value(s)
59        cannot be broadcast, an error will be raised when this layer's build()
60        method is called.
61      variance: The variance value(s) to use during normalization. The passed
62        value(s) will be broadcast to the shape of the kept axes above; if the
63        value(s)cannot be broadcast, an error will be raised when this layer's
64        build() method is called.
65
66  Examples:
67
68  Calculate the mean and variance by analyzing the dataset in `adapt`.
69
70  >>> adapt_data = np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32)
71  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
72  >>> layer = Normalization()
73  >>> layer.adapt(adapt_data)
74  >>> layer(input_data)
75  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
76  array([[-1.4142135 ],
77         [-0.70710677],
78         [ 0.        ]], dtype=float32)>
79
80  Pass the mean and variance directly.
81
82  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
83  >>> layer = Normalization(mean=3., variance=2.)
84  >>> layer(input_data)
85  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
86  array([[-1.4142135 ],
87         [-0.70710677],
88         [ 0.        ]], dtype=float32)>
89  """
90
91  def __init__(self, axis=-1, mean=None, variance=None, **kwargs):
92    super(Normalization, self).__init__(stateful=True, streaming=True, **kwargs)
93    base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)
94
95    # Standardize `axis` to a tuple.
96    if axis is None:
97      axis = ()
98    elif isinstance(axis, int):
99      axis = (axis,)
100    else:
101      axis = tuple(axis)
102    if 0 in axis:
103      raise ValueError('The argument \'axis\' may not be 0.')
104    self.axis = axis
105
106    # Set `mean` and `variance` if passed.
107    if isinstance(mean, variables.Variable):
108      raise ValueError('Normalization does not support passing a Variable '
109                       'for the `mean` init arg.')
110    if isinstance(variance, variables.Variable):
111      raise ValueError('Normalization does not support passing a Variable '
112                       'for the `variance` init arg.')
113    if mean is not None and variance is not None:
114      mean = convert_to_ndarray(mean)
115      variance = convert_to_ndarray(variance)
116    elif mean is not None or variance is not None:
117      raise ValueError(
118          'When setting values directly, both `mean` and `variance` '
119          'must be set. Got mean: {} and variance: {}'.format(mean, variance))
120    self.mean_val = mean
121    self.variance_val = variance
122
123  def build(self, input_shape):
124    input_shape = tensor_shape.TensorShape(input_shape).as_list()
125    if len(input_shape) == 1:
126      input_shape = input_shape + [1]
127    ndim = len(input_shape)
128
129    if any(a < 1 - ndim or a >= ndim for a in self.axis):
130      raise ValueError('All `axis` values must be in the range '
131                       '[1 - ndim, ndim - 1]. Found '
132                       'ndim: `{}`, axis: {}'.format(ndim, self.axis))
133
134    # Axes to be kept, replacing negative values with positive equivalents.
135    # Sorted to avoid transposing axes.
136    self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])
137    # Axes to be reduced.
138    self._reduce_axis = [d for d in range(ndim) if d not in self._keep_axis]
139    # 1 if an axis should be reduced, 0 otherwise.
140    self._reduce_axis_mask = [
141        0 if d in self._keep_axis else 1 for d in range(ndim)
142    ]
143    # Broadcast any reduced axes.
144    self._broadcast_shape = [
145        input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)
146    ]
147    # Create variables without keeping reduced axes.
148    mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)
149
150    self.mean = self.add_weight(
151        name='mean',
152        shape=mean_and_var_shape,
153        dtype=self.dtype,
154        initializer=init_ops.zeros_initializer,
155        trainable=False)
156    self.variance = self.add_weight(
157        name='variance',
158        shape=mean_and_var_shape,
159        dtype=self.dtype,
160        initializer=init_ops.ones_initializer,
161        trainable=False)
162    self.count = self.add_weight(
163        name='count',
164        shape=(),
165        dtype=dtypes.int64,
166        initializer=init_ops.zeros_initializer,
167        trainable=False)
168
169    super(Normalization, self).build(input_shape)
170
171    if (self.mean_val is not None and self.variance_val is not None):
172      mean_val = self.mean_val * np.ones(mean_and_var_shape)
173      variance_val = self.variance_val * np.ones(mean_and_var_shape)
174      self.mean.assign(mean_val)
175      self.variance.assign(variance_val)
176
177    self.built = True
178
179  def update_state(self, data):
180    if not self.built:
181      raise RuntimeError('`build` must be called before `update_state`.')
182
183    data = self._standardize_inputs(data)
184    batch_mean, batch_variance = nn_impl.moments_v2(
185        data, axes=self._reduce_axis)
186    batch_shape = array_ops.shape(data, out_type=self.count.dtype)
187    batch_reduce_shape = array_ops.gather(batch_shape, self._reduce_axis)
188    batch_count = math_ops.reduce_prod(batch_reduce_shape)
189
190    total_count = batch_count + self.count
191    batch_weight = (
192        math_ops.cast(batch_count, dtype=self.dtype) /
193        math_ops.cast(total_count, dtype=self.dtype))
194    existing_weight = 1. - batch_weight
195
196    total_mean = self.mean * existing_weight + batch_mean * batch_weight
197    # The variance is computed using the lack-of-fit sum of squares
198    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
199    total_variance = ((self.variance +
200                       (self.mean - total_mean)**2) * existing_weight +
201                      (batch_variance +
202                       (batch_mean - total_mean)**2) * batch_weight)
203    self.mean.assign(total_mean)
204    self.variance.assign(total_variance)
205    self.count.assign(total_count)
206
207  def merge_state(self, layers):
208    layers = layers + [self]
209    if any(not l.built for l in layers):
210      raise ValueError(
211          'All layers to be merged must have been adapted to some inputs '
212          'first (otherwise they have no state).')
213
214    layer_counts = [l.count for l in layers]
215    layer_means = [l.mean for l in layers]
216    layer_variances = [l.variance for l in layers]
217
218    total_count = math_ops.reduce_sum(layer_counts)
219    layer_weightings = (
220        math_ops.cast(layer_counts, self.dtype) /
221        math_ops.cast(total_count, self.dtype))
222    layer_weightings = array_ops.reshape(
223        layer_weightings, shape=[len(layers)] + [1] * self.mean.shape.rank)
224
225    total_mean = math_ops.reduce_sum(layer_means * layer_weightings, axis=0)
226    inter_layer_variances = (layer_means - total_mean)**2
227    total_variance = math_ops.reduce_sum(
228        ((layer_variances + inter_layer_variances) * layer_weightings), axis=0)
229
230    self.mean.assign(total_mean)
231    self.variance.assign(total_variance)
232    self.count.assign(total_count)
233
234  def reset_state(self):  # pylint: disable=method-hidden
235    if self.built:
236      self.mean.assign(array_ops.zeros_like(self.mean))
237      self.variance.assign(array_ops.ones_like(self.variance))
238      self.count.assign(array_ops.zeros_like(self.count))
239
240  def call(self, inputs):
241    inputs = self._standardize_inputs(inputs)
242    # We need to reshape the mean and variance data to ensure that Tensorflow
243    # broadcasts the data correctly.
244    mean = array_ops.reshape(self.mean, self._broadcast_shape)
245    variance = array_ops.reshape(self.variance, self._broadcast_shape)
246    return ((inputs - mean) /
247            math_ops.maximum(math_ops.sqrt(variance), K.epsilon()))
248
249  def compute_output_shape(self, input_shape):
250    return input_shape
251
252  def compute_output_signature(self, input_spec):
253    return input_spec
254
255  def get_config(self):
256    config = super(Normalization, self).get_config()
257    config.update({'axis': self.axis})
258    return config
259
260  def set_weights(self, weights):
261    """Override for set_weights to ensure we can set just mean/var weights."""
262    if len(weights) == 2:
263      weights.append(np.array(0))
264    super(Normalization, self).set_weights(weights)
265
266  def _standardize_inputs(self, inputs):
267    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
268    if inputs.shape.rank == 0:
269      inputs = array_ops.reshape(inputs, [1, 1])
270    elif inputs.shape.rank == 1:
271      inputs = array_ops.expand_dims(inputs, 1)
272
273    if inputs.dtype != self.dtype:
274      inputs = math_ops.cast(inputs, self.dtype)
275    return inputs
276
277
278def convert_to_ndarray(values):
279  if isinstance(values, np.ndarray):
280    return values
281  elif isinstance(values, ops.Tensor):
282    return K.get_value(values)
283  else:
284    return np.array(values)
285