# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in regularizers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.regularizers.Regularizer')
class Regularizer(object):
  """Regularizer base class.

  A regularizer is a callable that maps a weight tensor to a penalty term.
  The base implementation contributes no penalty (it returns `0.`);
  subclasses override `__call__` to compute one.
  """

  def __call__(self, x):
    """Returns the penalty for `x`; the base class always returns `0.`."""
    return 0.

  @classmethod
  def from_config(cls, config):
    """Creates a regularizer from a config dict.

    Arguments:
        config: Dict of keyword arguments passed unchanged to the
          constructor (e.g. the dict returned by a subclass's `get_config`).

    Returns:
        A new instance of `cls`.
    """
    return cls(**config)


@keras_export('keras.regularizers.L1L2')
class L1L2(Regularizer):
  """Regularizer for L1 and L2 regularization.

  The penalty is `l1 * sum(abs(x)) + l2 * sum(square(x))`.

  Arguments:
      l1: Float; L1 regularization factor.
      l2: Float; L2 regularization factor.
  """

  def __init__(self, l1=0., l2=0.):  # pylint: disable=redefined-outer-name
    # Factors are stored after conversion with `K.cast_to_floatx`.
    self.l1 = K.cast_to_floatx(l1)
    self.l2 = K.cast_to_floatx(l2)

  def __call__(self, x):
    """Computes the combined L1/L2 penalty for tensor `x`.

    Returns:
        The summed penalty. When both factors are zero, `K.constant(0.)` is
        returned without reading `x`.
    """
    if not self.l1 and not self.l2:
      return K.constant(0.)
    regularization = 0.
    # A term whose factor is zero is skipped entirely rather than computed
    # and multiplied by zero.
    if self.l1:
      regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
    if self.l2:
      regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
    return regularization

  def get_config(self):
    """Returns the factors as plain Python floats, keyed `'l1'` and `'l2'`."""
    return {'l1': float(self.l1), 'l2': float(self.l2)}


# Aliases.


@keras_export('keras.regularizers.l1')
def l1(l=0.01):
  """Returns an `L1L2` regularizer with L1 factor `l` and no L2 term."""
  return L1L2(l1=l)


@keras_export('keras.regularizers.l2')
def l2(l=0.01):
  """Returns an `L1L2` regularizer with L2 factor `l` and no L1 term."""
  return L1L2(l2=l)


@keras_export('keras.regularizers.l1_l2')
def l1_l2(l1=0.01, l2=0.01):  # pylint: disable=redefined-outer-name
  """Returns an `L1L2` regularizer with both factors set."""
  return L1L2(l1=l1, l2=l2)


@keras_export('keras.regularizers.serialize')
def serialize(regularizer):
  """Serializes `regularizer` with `serialize_keras_object`."""
  return serialize_keras_object(regularizer)


@keras_export('keras.regularizers.deserialize')
def deserialize(config, custom_objects=None):
  """Reconstructs a regularizer from its serialized form.

  Names are looked up among this module's globals; `custom_objects` is
  forwarded to `deserialize_keras_object` for user-defined entries.

  Arguments:
      config: Serialized regularizer, e.g. as produced by `serialize`.
      custom_objects: Optional dict of user-defined classes/functions.

  Returns:
      The deserialized regularizer.
  """
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='regularizer')


@keras_export('keras.regularizers.get')
def get(identifier):
  """Resolves `identifier` to a regularizer.

  Arguments:
      identifier: `None`, a serialized config dict, the string name of an
        object in this module (e.g. `'l2'` or `'L1L2'`), or a callable.

  Returns:
      `None` for `None`; the deserialized object for a dict or string;
      `identifier` itself when it is callable.

  Raises:
      ValueError: If `identifier` is none of the above.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, six.string_types):
    # A bare name is wrapped in a config dict with an empty `config`.
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  if callable(identifier):
    return identifier
  # Format the identifier into a single message. Passing it as a second
  # constructor argument would make `str(error)` render as a tuple.
  raise ValueError(
      'Could not interpret regularizer identifier: ' + str(identifier))