1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Constraints: functions that impose constraints on weight values.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import six
23
24from tensorflow.python.keras import backend as K
25from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
26from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
27from tensorflow.python.ops import math_ops
28from tensorflow.python.util.tf_export import keras_export
29
30
@keras_export('keras.constraints.Constraint')
class Constraint(object):
  """Base class for weight constraints.

  A constraint is a callable that maps a weight tensor to a (projected)
  weight tensor. Subclasses override `__call__` to enforce their
  constraint and `get_config` to support serialization; this base class
  is the identity constraint.
  """

  def __call__(self, w):
    """Returns `w` unchanged (identity constraint)."""
    return w

  def get_config(self):
    """Returns the constraint's configuration as a JSON-serializable dict."""
    return {}
39
40
@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Arguments:
      max_value: the maximum norm for the incoming weights.
      axis: integer, axis along which to calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  def __call__(self, w):
    """Rescales `w` so every vector along `axis` has norm <= `max_value`."""
    # L2 norm of each weight vector along `axis`; keepdims=True keeps the
    # result broadcastable against `w`.
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    # Target norms: unchanged if already within the bound, clipped otherwise.
    desired = K.clip(norms, 0, self.max_value)
    # K.epsilon() in the denominator guards against division by zero for
    # all-zero weight vectors.
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    """Returns `{'max_value': ..., 'axis': ...}` for serialization."""
    return {'max_value': self.max_value, 'axis': self.axis}
76
77
@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.
  """

  def __call__(self, w):
    """Zeroes out every strictly negative entry of `w`."""
    # Boolean mask (cast to the floatx dtype) that is 1 where w >= 0
    # and 0 where w < 0; multiplying keeps non-negative entries intact.
    nonneg_mask = math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())
    return w * nonneg_mask
85
86
@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Arguments:
      axis: integer, axis along which to calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  def __call__(self, w):
    """Divides `w` by the L2 norm taken along `axis`."""
    # keepdims=True makes the norm broadcastable against `w`.
    norms = K.sqrt(
        math_ops.reduce_sum(
            math_ops.square(w), axis=self.axis, keepdims=True))
    # K.epsilon() prevents division by zero for all-zero weight vectors.
    return w / (K.epsilon() + norms)

  def get_config(self):
    """Returns `{'axis': ...}` for serialization."""
    return {'axis': self.axis}
116
117
@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have the norm between a lower bound and an upper bound.

  Arguments:
      min_value: the minimum norm for the incoming weights.
      max_value: the maximum norm for the incoming weights.
      rate: rate for enforcing the constraint: weights will be
          rescaled to yield
          `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
          Effectively, this means that rate=1.0 stands for strict
          enforcement of the constraint, while rate<1.0 means that
          weights will be rescaled at each step to slowly move
          towards a value inside the desired interval.
      axis: integer, axis along which to calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  def __call__(self, w):
    """Rescales `w` toward norms in `[min_value, max_value]` at `rate`."""
    # L2 norms along `axis`, broadcastable against `w` via keepdims=True.
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    # Interpolate between the current norms and their clipped values:
    # rate=1.0 enforces the bounds exactly; rate<1.0 moves toward them.
    clipped = K.clip(norms, self.min_value, self.max_value)
    desired = self.rate * clipped + (1 - self.rate) * norms
    # K.epsilon() guards against division by zero for all-zero vectors.
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    """Returns the four constructor arguments for serialization."""
    return {
        'min_value': self.min_value,
        'max_value': self.max_value,
        'rate': self.rate,
        'axis': self.axis
    }
169
170
# Aliases.
# Lowercase snake_case names for the constraint classes. Because
# `deserialize` looks names up in this module's `globals()`, these aliases
# also let configs reference constraints by their lowercase names.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm

# Legacy aliases.
# Kept for backward compatibility with older Keras spellings; note there is
# no legacy (unseparated) alias for MinMaxNorm.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm
182
183
@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Returns a JSON-serializable config dict for `constraint`."""
  return serialize_keras_object(constraint)
187
188
@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Instantiates a constraint from a config dict.

  Looks the class name up among this module's globals (the built-in
  constraints and their aliases) and in `custom_objects`, if given.
  """
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='constraint')
196
197
@keras_export('keras.constraints.get')
def get(identifier):
  """Retrieves a constraint instance from an identifier.

  The identifier may be `None` (no constraint), a config dict, a string
  naming a constraint class, or an already-callable constraint; anything
  else raises `ValueError`.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, six.string_types):
    # A bare name deserializes with an empty (all-defaults) config.
    return deserialize({'class_name': str(identifier), 'config': {}})
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))
212