# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Constraints: functions that impose constraints on weight values.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.constraints.Constraint')
class Constraint(object):
  """Base class for weight constraints.

  A constraint is a callable that takes a weight tensor and returns a
  (possibly modified) weight tensor. This base implementation is the
  identity: it returns its input unchanged and has an empty config.
  Subclasses override `__call__` and, if they take arguments, `get_config`.
  """

  def __call__(self, w):
    """Returns `w` unchanged."""
    return w

  def get_config(self):
    """Returns the constructor arguments of this constraint as a dict."""
    return {}


@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Arguments:
      max_value: the maximum norm for the incoming weights.
      axis: integer or list of integers, axis along which to calculate
          weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  def __call__(self, w):
    # L2 norm along `axis`; keepdims=True so it broadcasts back against `w`.
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    desired = K.clip(norms, 0, self.max_value)
    # epsilon keeps the rescale finite when a slice of `w` is all zeros.
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {'max_value': self.max_value, 'axis': self.axis}


@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.

  Negative entries of the weight tensor are replaced by zero; non-negative
  entries are left unchanged.
  """

  def __call__(self, w):
    # Cast the mask to the weight's own dtype (not `K.floatx()`), so weights
    # stored in a different float type (e.g. float64 or float16) do not hit a
    # dtype mismatch in the multiplication.
    return w * math_ops.cast(math_ops.greater_equal(w, 0.), w.dtype)


@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Arguments:
      axis: integer or list of integers, axis along which to calculate
          weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  def __call__(self, w):
    # Divide by the L2 norm along `axis`; epsilon guards against all-zero
    # slices, which therefore stay (approximately) zero instead of NaN.
    return w / (
        K.epsilon() + K.sqrt(
            math_ops.reduce_sum(
                math_ops.square(w), axis=self.axis, keepdims=True)))

  def get_config(self):
    return {'axis': self.axis}


@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have the norm between a lower bound and an upper bound.

  Arguments:
      min_value: the minimum norm for the incoming weights.
      max_value: the maximum norm for the incoming weights.
      rate: rate for enforcing the constraint: weights will be
          rescaled to yield
          `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
          Effectively, this means that rate=1.0 stands for strict
          enforcement of the constraint, while rate<1.0 means that
          weights will be rescaled at each step to slowly move
          towards a value inside the desired interval.
      axis: integer or list of integers, axis along which to calculate
          weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  def __call__(self, w):
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    # Interpolate between the clipped norm and the current norm by `rate`.
    desired = (
        self.rate * K.clip(norms, self.min_value, self.max_value) +
        (1 - self.rate) * norms)
    # epsilon keeps the rescale finite when a slice of `w` is all zeros.
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {
        'min_value': self.min_value,
        'max_value': self.max_value,
        'rate': self.rate,
        'axis': self.axis
    }


# Aliases.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm

# Legacy aliases.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm


@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Serializes a constraint via `serialize_keras_object`."""
  return serialize_keras_object(constraint)


@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Reconstructs a constraint from its serialized `config`.

  Class names are resolved against this module's globals (including the
  aliases defined above) and then `custom_objects`.

  Arguments:
      config: serialized constraint, as accepted by `deserialize_keras_object`.
      custom_objects: optional dict mapping names to user-defined constraint
          classes or functions.

  Returns:
      A constraint instance.
  """
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='constraint')


@keras_export('keras.constraints.get')
def get(identifier):
  """Retrieves a constraint from a flexible identifier.

  Arguments:
      identifier: one of
          - `None`: returned as `None`;
          - a dict: treated as a serialized config and deserialized;
          - a string: treated as a class/alias name with an empty config;
          - a callable: returned unchanged.

  Returns:
      A constraint callable, or `None`.

  Raises:
      ValueError: if `identifier` is none of the above.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  elif callable(identifier):
    return identifier
  else:
    raise ValueError('Could not interpret constraint identifier: ' +
                     str(identifier))