1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Constraints: functions that impose constraints on weight values.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import six
23
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.keras import backend as K
26from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
27from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.util.tf_export import keras_export
32from tensorflow.tools.docs import doc_controls
33
34
@keras_export('keras.constraints.Constraint')
class Constraint(object):
  """Base class for weight constraints.

  A constraint is callable: it takes a weight tensor and returns a
  projected tensor of the same shape. The base class is the identity
  projection; subclasses override `__call__` and, when they carry
  configuration, `get_config`.
  """

  def __call__(self, w):
    """Returns `w` unchanged (identity constraint)."""
    return w

  def get_config(self):
    """Returns the constraint's configuration (empty for the base class)."""
    return dict()
43
44
@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Also available via the shortcut function `tf.keras.constraints.max_norm`.

  Args:
    max_value: the maximum norm value for the incoming weights.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # L2 norm of each weight vector along the constrained axis/axes.
    squared_sum = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    norms = K.sqrt(squared_sum)
    clipped_norms = K.clip(norms, 0, self.max_value)
    # Rescale each vector to its clipped norm; epsilon guards against
    # division by zero for all-zero weight vectors.
    return w * (clipped_norms / (K.epsilon() + norms))

  @doc_controls.do_not_generate_docs
  def get_config(self):
    config = {'max_value': self.max_value, 'axis': self.axis}
    return config
84
85
@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.

  Also available via the shortcut function `tf.keras.constraints.non_neg`.
  """

  # Consistency fix: every other constraint in this module hides its
  # `__call__` from the generated docs; NonNeg now does the same.
  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # Zero out negative entries: build a 0/1 mask from `w >= 0`, cast it to
    # the Keras default float type, and multiply elementwise.
    return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())
95
96
@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Also available via the shortcut function `tf.keras.constraints.unit_norm`.

  Args:
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # Divide each weight vector by its L2 norm along the constrained
    # axis/axes; epsilon avoids division by zero for all-zero vectors.
    squared_sum = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    norms = K.sqrt(squared_sum)
    return w / (K.epsilon() + norms)

  @doc_controls.do_not_generate_docs
  def get_config(self):
    config = {'axis': self.axis}
    return config
130
131
@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have the norm between a lower bound and an upper bound.

  Also available via the shortcut function `tf.keras.constraints.min_max_norm`.

  Args:
    min_value: the minimum norm for the incoming weights.
    max_value: the maximum norm for the incoming weights.
    rate: rate for enforcing the constraint: weights will be
      rescaled to yield
      `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
      Effectively, this means that rate=1.0 stands for strict
      enforcement of the constraint, while rate<1.0 means that
      weights will be rescaled at each step to slowly move
      towards a value inside the desired interval.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # Current L2 norm of each weight vector along the constrained axis/axes.
    squared_sum = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    norms = K.sqrt(squared_sum)
    clipped = K.clip(norms, self.min_value, self.max_value)
    # Blend between the current and the clipped norm: rate=1.0 enforces the
    # bound exactly, rate<1.0 nudges the norm toward the interval per step.
    desired = self.rate * clipped + (1 - self.rate) * norms
    # Epsilon guards against division by zero for all-zero vectors.
    return w * (desired / (K.epsilon() + norms))

  @doc_controls.do_not_generate_docs
  def get_config(self):
    config = dict(
        min_value=self.min_value,
        max_value=self.max_value,
        rate=self.rate,
        axis=self.axis)
    return config
187
188
@keras_export('keras.constraints.RadialConstraint',
              'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  Also available via the shortcut function
  `tf.keras.constraints.radial_constraint`.

  For example, the desired output for the following 4-by-4 kernel:

  ```
      kernel = [[v_00, v_01, v_02, v_03],
                [v_10, v_11, v_12, v_13],
                [v_20, v_21, v_22, v_23],
                [v_30, v_31, v_32, v_33]]
  ```

  is this::

  ```
      kernel = [[v_11, v_11, v_11, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_11, v_11, v_11]]
  ```

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)`.
  """

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # The constraint is only defined for 4-D convolution kernels.
    # (The `rank is None` test is for unknown static ranks; `None != 4` would
    # also catch it, but the explicit check documents intent.)
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    # Fold the two channel dimensions together, then move the folded channel
    # axis to the front so map_fn applies `_kernel_constraint` to one
    # (height, width) kernel slice at a time.
    height, width, channels, kernels = w_shape
    w = K.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
    # is supported.
    w = K.map_fn(
        self._kernel_constraint,
        K.stack(array_ops.unstack(w, axis=-1), axis=0))
    # Undo the transpose/reshape to restore the original
    # (height, width, channels, kernels) layout.
    return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1),
                     (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constraints a kernel with shape (height, width, channels)."""
    # Each while-loop iteration pads one ring (1 row/col on every side)
    # around the current center block.
    padding = K.constant([[1, 1], [1, 1]], dtype='int32')

    # NOTE(review): only dimension 0 is inspected, i.e. the kernel is
    # assumed square (height == width) — confirm with callers.
    kernel_shape = K.shape(kernel)[0]
    start = K.cast(kernel_shape / 2, 'int32')

    # Seed the output with the center of the kernel: a 1x1 block for odd
    # sizes, a 2x2 block for even sizes (the `+ K.zeros((2, 2))` broadcast
    # fixes the static shape of the even branch).
    kernel_new = K.switch(
        K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: kernel[start - 1:start, start - 1:start],
        lambda: kernel[start - 1:start, start - 1:start] + K.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    # Odd kernels start growing rings at offset 0, even kernels at offset 1.
    index = K.switch(
        K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: K.constant(0, dtype='int32'),
        lambda: K.constant(1, dtype='int32'))
    while_condition = lambda index, *args: K.less(index, start)

    def body_fn(i, array):
      # Grow the kernel outward by one ring, filling the new ring with the
      # diagonal value kernel[start + i, start + i].
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[start + i, start + i])

    # The spatial shape grows each iteration, so it must be left unknown in
    # the shape invariants.
    _, kernel_new = control_flow_ops.while_loop(
        while_condition,
        body_fn,
        [index, kernel_new],
        shape_invariants=[index.get_shape(),
                          tensor_shape.TensorShape([None, None])])
    return kernel_new
269
270
# Aliases.
# Lowercase shortcut names; string identifiers passed to `get`/`deserialize`
# resolve against this module's `globals()`, so these names are part of the
# public lookup surface.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases.
# Kept for backwards compatibility with older Keras spellings.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm
284
@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Serializes a constraint object via the generic Keras serializer."""
  serialized = serialize_keras_object(constraint)
  return serialized
288
289
@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Returns a constraint object from its config (dict or class name)."""
  # Built-in constraints (and their lowercase aliases) are looked up in this
  # module's globals; `custom_objects` supplies user-defined classes.
  kwargs = {
      'module_objects': globals(),
      'custom_objects': custom_objects,
      'printable_module_name': 'constraint',
  }
  return deserialize_keras_object(config, **kwargs)
297
298
@keras_export('keras.constraints.get')
def get(identifier):
  """Retrieves a constraint from an identifier.

  Accepts `None` (passed through), a config dict, a string class name, or an
  already-callable constraint object. Raises `ValueError` otherwise.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, six.string_types):
    # A bare class name means "default-construct that constraint".
    return deserialize({'class_name': str(identifier), 'config': {}})
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))
313