1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras layers that implement explicit (approximate) kernel feature maps.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22import six 23 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.keras import initializers 28from tensorflow.python.keras.engine import base_layer 29from tensorflow.python.keras.engine import input_spec 30from tensorflow.python.ops import gen_math_ops 31from tensorflow.python.ops import init_ops 32from tensorflow.python.ops import nn 33 34_SUPPORTED_RBF_KERNEL_TYPES = ['gaussian', 'laplacian'] 35 36 37class RandomFourierFeatures(base_layer.Layer): 38 r"""Layer that maps its inputs using random Fourier features. 39 40 This layer implements a feature map \\(\phi: \mathbb{R}^d \rightarrow 41 \mathbb{R}^D\\) which approximates shift-invariant kernels. A kernel function 42 K(x, y) defined over \\(\mathbb{R}^d x \mathbb{R}^d\\) is shift-invariant if 43 K(x, y) = k(x-y) for some function defined over \\(\mathbb{R}^d\\). Many 44 popular Radial Basis Functions (in short RBF), including gaussian and 45 laplacian kernels are shift-invariant. 46 47 The layer approximates a (shift invariant) kernel K in the following sense: 48 up to a scaling factor, for all inputs \\(x, y \in \mathbb{R}^d\\) 49 \\(\phi(x)^T \cdot \phi(y) \approx K(x, y)\\) 50 51 The implementation of this layer is based on the following paper: 52 "Random Features for Large-Scale Kernel Machines" by Ali Rahimi and Ben Recht. 53 (link: https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf) 54 55 The distribution from which the parameters of the random features map (layer) 56 are sampled, determines which shift-invariant kernel the layer approximates 57 (see paper for more details). The users can use the distribution of their 58 choice. Due to their popularity, the layer supports the out-of-the-box 59 approximation of the following RBF kernels: 60 - Gaussian: \\(K(x, y) = e^{-\frac{\|x-y\|_2^2}{2 \cdot scale^2}}\\) 61 - Laplacian: \\(K(x, y) = e^{-\frac{\|x-y\|_1}{scale}}\\) 62 63 NOTE: Unlike the map described in the paper and the scikit-learn 64 implementation, the output of this layer does not apply the sqrt(2/D) 65 normalization factor. 66 67 Usage for ML: Typically, this layer is used to "kernelize" linear models by 68 applying a non-linear transformation (this layer) to the input features and 69 then training a linear model on top of the transformed features. Depending on 70 the loss function of the linear model, the composition of this layer and the 71 linear model results to models that are equivalent (up to approximation) to 72 kernel SVMs (for hinge loss), kernel logistic regression (for logistic loss), 73 kernel linear regression (for squared loss) etc. 74 75 Example of building a kernel multinomial logistic regression model with 76 Gaussian kernel in keras: 77 ```python 78 random_features_layer = RandomFourierFeatures( 79 output_dim=500, 80 kernel_initializer='gaussian', 81 scale=5.0, 82 ...) 83 84 model = tf.keras.models.Sequential() 85 model.add(random_features_layer) 86 model.add(tf.keras.layers.Dense(units=num_classes, activation='softmax') 87 88 model.compile(elif isinstance(identifier, six.string_types): 89 loss=tf.keras.losses.categorical_crossentropy, optimizer=..., metrics=...) 90 ``` 91 92 To use another kernel, replace the layer creation command with: 93 ```python 94 random_features_layer = RandomFourierFeatures( 95 output_dim=500, 96 kernel_initializer=<my_initializer>, 97 scale=..., 98 ...) 99 ``` 100 101 Arguments: 102 output_dim: Positive integer, the dimension of the layer's output, i.e., the 103 number of random features used to approximate the kernel. 104 kernel_initializer: Determines the distribution of the parameters of the 105 random features map (and therefore the kernel approximated by the layer). 106 It can be either a string or an instance of TensorFlow's Initializer 107 class. Currently only 'gaussian' and 'laplacian' are supported as string 108 initializers (case insensitive). Note that these parameters are not 109 trainable. 110 scale: For gaussian and laplacian kernels, this corresponds to a scaling 111 factor of the corresponding kernel approximated by the layer (see concrete 112 definitions above). When provided, it should be a positive float. If None, 113 the implementation chooses a default value (1.0 typically). Both the 114 approximation error of the kernel and the classification quality are 115 sensitive to this parameter. If trainable is set to True, this paramater 116 is learned end-to-end during training and the provided value serves as an 117 initialization value. 118 NOTE: When this layer is used to map the initial features and then the 119 transformed features are fed to a linear model, by making `scale` 120 trainable, the resulting optimization problem is no longer convex (even 121 if the loss function used by the linear model is convex). 122 trainable: Whether the scaling parameter of th layer is trainable. Defaults 123 to False. 124 name: name for the RandomFourierFeatures layer. 125 126 Raises: 127 ValueError: if output_dim or stddev are not positive or if the provided 128 kernel_initializer is not supported. 129 """ 130 131 def __init__(self, 132 output_dim, 133 kernel_initializer='gaussian', 134 scale=None, 135 trainable=False, 136 name=None, 137 **kwargs): 138 if output_dim <= 0: 139 raise ValueError( 140 '`output_dim` should be a positive integer. Given: {}.'.format( 141 output_dim)) 142 if isinstance(kernel_initializer, six.string_types): 143 if kernel_initializer.lower() not in _SUPPORTED_RBF_KERNEL_TYPES: 144 raise ValueError( 145 'Unsupported kernel type: \'{}\'. Supported kernel types: {}.' 146 .format(kernel_initializer, _SUPPORTED_RBF_KERNEL_TYPES)) 147 if scale is not None and scale <= 0.0: 148 raise ValueError('When provided, `scale` should be a positive float. ' 149 'Given: {}.'.format(scale)) 150 super(RandomFourierFeatures, self).__init__( 151 trainable=trainable, name=name, **kwargs) 152 self.output_dim = output_dim 153 self.kernel_initializer = kernel_initializer 154 self.scale = scale 155 156 def build(self, input_shape): 157 input_shape = tensor_shape.TensorShape(input_shape) 158 # TODO(sibyl-vie3Poto): Allow higher dimension inputs. Currently the input is expected 159 # to have shape [batch_size, dimension]. 160 if input_shape.rank != 2: 161 raise ValueError( 162 'The rank of the input tensor should be 2. Got {} instead.'.format( 163 input_shape.ndims)) 164 if input_shape.dims[1].value is None: 165 raise ValueError( 166 'The last dimension of the inputs to `RandomFourierFeatures` ' 167 'should be defined. Found `None`.') 168 self.input_spec = input_spec.InputSpec( 169 ndim=2, axes={1: input_shape.dims[1].value}) 170 input_dim = input_shape.dims[1].value 171 172 kernel_initializer = _get_random_features_initializer( 173 self.kernel_initializer, shape=(input_dim, self.output_dim)) 174 175 unscaled_kernel = self.add_weight( 176 name='unscaled_random_features', 177 shape=(input_dim, self.output_dim), 178 dtype=dtypes.float32, 179 initializer=kernel_initializer, 180 trainable=False) 181 182 self.bias = self.add_weight( 183 name='random_features_bias', 184 shape=(self.output_dim,), 185 dtype=dtypes.float32, 186 initializer=init_ops.random_uniform_initializer( 187 minval=0.0, maxval=2 * np.pi, dtype=dtypes.float32), 188 trainable=False) 189 190 if self.scale is None: 191 self.scale = _get_default_scale(self.kernel_initializer, input_dim) 192 scale = self.add_weight( 193 name='random_features_scale', 194 shape=(1,), 195 dtype=dtypes.float32, 196 initializer=init_ops.constant_initializer(self.scale), 197 trainable=True, 198 constraint='NonNeg') 199 self.kernel = (1.0 / scale) * unscaled_kernel 200 super(RandomFourierFeatures, self).build(input_shape) 201 202 def call(self, inputs): 203 inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) 204 inputs = gen_math_ops.cast(inputs, dtypes.float32) 205 outputs = gen_math_ops.mat_mul(inputs, self.kernel) 206 outputs = nn.bias_add(outputs, self.bias) 207 return gen_math_ops.cos(outputs) 208 209 def compute_output_shape(self, input_shape): 210 input_shape = tensor_shape.TensorShape(input_shape) 211 input_shape = input_shape.with_rank(2) 212 if input_shape.dims[-1].value is None: 213 raise ValueError( 214 'The innermost dimension of input shape must be defined. Given: %s' % 215 input_shape) 216 return input_shape[:-1].concatenate(self.output_dim) 217 218 def get_config(self): 219 kernel_initializer = self.kernel_initializer 220 if isinstance(self.kernel_initializer, init_ops.Initializer): 221 kernel_initializer = initializers.serialize(self.kernel_initializer) 222 config = { 223 'output_dim': self.output_dim, 224 'kernel_initializer': kernel_initializer, 225 'scale': self.scale, 226 } 227 base_config = super(RandomFourierFeatures, self).get_config() 228 return dict(list(base_config.items()) + list(config.items())) 229 230 231def _get_random_features_initializer(initializer, shape): 232 """Returns Initializer object for random features.""" 233 234 def _get_cauchy_samples(loc, scale, shape): 235 probs = np.random.uniform(low=0., high=1., size=shape) 236 return loc + scale * np.tan(np.pi * (probs - 0.5)) 237 238 random_features_initializer = initializer 239 if isinstance(initializer, six.string_types): 240 if initializer.lower() == 'gaussian': 241 random_features_initializer = init_ops.random_normal_initializer( 242 stddev=1.0) 243 elif initializer.lower() == 'laplacian': 244 random_features_initializer = init_ops.constant_initializer( 245 _get_cauchy_samples(loc=0.0, scale=1.0, shape=shape)) 246 247 else: 248 raise ValueError( 249 'Unsupported kernel type: \'{}\'. Supported kernel types: {}.'.format( 250 random_features_initializer, _SUPPORTED_RBF_KERNEL_TYPES)) 251 return random_features_initializer 252 253 254def _get_default_scale(initializer, input_dim): 255 if (isinstance(initializer, six.string_types) and 256 initializer.lower() == 'gaussian'): 257 return np.sqrt(input_dim / 2.0) 258 return 1.0 259