# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Constraints: functions that impose constraints on weight values."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


@keras_export('keras.constraints.Constraint')
class Constraint(object):
  """Base class for weight constraints.

  Subclasses override `__call__` to project a weight tensor onto the
  constrained set, and `get_config` to expose their hyperparameters.
  """

  def __call__(self, w):
    # Base behavior is the identity projection: weights are unconstrained.
    return w

  def get_config(self):
    # The base constraint has no hyperparameters to serialize.
    return {}


@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Also available via the shortcut function `tf.keras.constraints.max_norm`.

  Args:
    max_value: the maximum norm value for the incoming weights.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # L2 norm of each weight vector along `axis`, kept broadcastable
    # against `w` via keepdims.
    weight_norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    # Norms above `max_value` are pulled back onto the ball of that radius;
    # norms already inside it are left untouched by the clip.
    target_norms = K.clip(weight_norms, 0, self.max_value)
    # Epsilon guards the division for all-zero weight vectors.
    return w * (target_norms / (K.epsilon() + weight_norms))

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return {'max_value': self.max_value, 'axis': self.axis}


@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.

  Also available via the shortcut function `tf.keras.constraints.non_neg`.
  """

  def __call__(self, w):
    # Build a {0, 1} mask in the backend float dtype and zero out every
    # strictly negative entry.
    nonneg_mask = math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())
    return w * nonneg_mask


@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Also available via the shortcut function `tf.keras.constraints.unit_norm`.

  Args:
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # Divide each weight vector by its own L2 norm along `axis`; epsilon
    # keeps the division finite for all-zero vectors.
    weight_norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    return w / (K.epsilon() + weight_norms)

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return {'axis': self.axis}


@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have the norm between a lower bound and an upper bound.

  Also available via the shortcut function `tf.keras.constraints.min_max_norm`.

  Args:
    min_value: the minimum norm for the incoming weights.
    max_value: the maximum norm for the incoming weights.
    rate: rate for enforcing the constraint: weights will be
      rescaled to yield
      `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
      Effectively, this means that rate=1.0 stands for strict
      enforcement of the constraint, while rate<1.0 means that
      weights will be rescaled at each step to slowly move
      towards a value inside the desired interval.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    weight_norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    clipped_norms = K.clip(weight_norms, self.min_value, self.max_value)
    # Blend between the current norm and its clipped value: rate=1.0 is
    # strict enforcement, rate<1.0 moves the norm gradually into range.
    target_norms = (
        self.rate * clipped_norms + (1 - self.rate) * weight_norms)
    return w * (target_norms / (K.epsilon() + weight_norms))

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return {
        'min_value': self.min_value,
        'max_value': self.max_value,
        'rate': self.rate,
        'axis': self.axis
    }


@keras_export('keras.constraints.RadialConstraint',
              'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  Also available via the shortcut function
  `tf.keras.constraints.radial_constraint`.

  For example, the desired output for the following 4-by-4 kernel:

  ```
    kernel = [[v_00, v_01, v_02, v_03],
              [v_10, v_11, v_12, v_13],
              [v_20, v_21, v_22, v_23],
              [v_30, v_31, v_32, v_33]]
  ```

  is this::

  ```
    kernel = [[v_11, v_11, v_11, v_11],
              [v_11, v_33, v_33, v_11],
              [v_11, v_33, v_33, v_11],
              [v_11, v_11, v_11, v_11]]
  ```

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)`.
  """

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    shape = w.shape
    if shape.rank is None or shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % shape)

    height, width, channels, kernels = shape
    # Fold the two depth axes together so each 2-D kernel slice can be
    # constrained independently by map_fn.
    flat = K.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
    # is supported.
    constrained = K.map_fn(
        self._kernel_constraint,
        K.stack(array_ops.unstack(flat, axis=-1), axis=0))
    return K.reshape(
        K.stack(array_ops.unstack(constrained, axis=0), axis=-1),
        (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constraints a kernel with shape (height, width, channels)."""
    # Pad spec that grows the working array by one ring on every side.
    ring_padding = K.constant([[1, 1], [1, 1]], dtype='int32')

    side = K.shape(kernel)[0]
    center = K.cast(side / 2, 'int32')

    # Seed the result with the innermost element (odd side length) or the
    # innermost value broadcast to a 2x2 block (even side length).
    kernel_new = K.switch(
        K.cast(math_ops.floormod(side, 2), 'bool'),
        lambda: kernel[center - 1:center, center - 1:center],
        lambda: kernel[center - 1:center, center - 1:center] + K.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    # Even-sized kernels start one step further out than odd-sized ones.
    step = K.switch(
        K.cast(math_ops.floormod(side, 2), 'bool'),
        lambda: K.constant(0, dtype='int32'),
        lambda: K.constant(1, dtype='int32'))

    keep_growing = lambda step, *args: K.less(step, center)

    def grow_one_ring(i, array):
      # Wrap the current core in one ring filled with the diagonal value
      # kernel[center + i, center + i].
      return i + 1, array_ops.pad(
          array,
          ring_padding,
          constant_values=kernel[center + i, center + i])

    _, kernel_new = control_flow_ops.while_loop(
        keep_growing,
        grow_one_ring,
        [step, kernel_new],
        shape_invariants=[step.get_shape(),
                          tensor_shape.TensorShape([None, None])])
    return kernel_new


# Aliases.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm


@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Returns the config dict serialization of a constraint instance."""
  return serialize_keras_object(constraint)


@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Instantiates a constraint from its config dict serialization."""
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='constraint')


@keras_export('keras.constraints.get')
def get(identifier):
  """Resolves `identifier` (None, dict, string, or callable) to a constraint.

  Raises:
    ValueError: if `identifier` is of an unsupported type.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, six.string_types):
    # A bare class name is treated as a config with empty kwargs.
    return deserialize({'class_name': str(identifier), 'config': {}})
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))