1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Wrapper for using the Scikit-Learn API with Keras models. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import types 23 24import numpy as np 25 26from tensorflow.python.keras import losses 27from tensorflow.python.keras.models import Sequential 28from tensorflow.python.keras.utils.generic_utils import has_arg 29from tensorflow.python.keras.utils.np_utils import to_categorical 30from tensorflow.python.util.tf_export import keras_export 31 32 33class BaseWrapper(object): 34 """Base class for the Keras scikit-learn wrapper. 35 36 Warning: This class should not be used directly. 37 Use descendant classes instead. 38 39 Args: 40 build_fn: callable function or class instance 41 **sk_params: model parameters & fitting parameters 42 43 The `build_fn` should construct, compile and return a Keras model, which 44 will then be used to fit/predict. One of the following 45 three values could be passed to `build_fn`: 46 1. A function 47 2. An instance of a class that implements the `__call__` method 48 3. None. This means you implement a class that inherits from either 49 `KerasClassifier` or `KerasRegressor`. The `__call__` method of the 50 present class will then be treated as the default `build_fn`. 51 52 `sk_params` takes both model parameters and fitting parameters. Legal model 53 parameters are the arguments of `build_fn`. Note that like all other 54 estimators in scikit-learn, `build_fn` should provide default values for 55 its arguments, so that you could create the estimator without passing any 56 values to `sk_params`. 57 58 `sk_params` could also accept parameters for calling `fit`, `predict`, 59 `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`). 60 fitting (predicting) parameters are selected in the following order: 61 62 1. Values passed to the dictionary arguments of 63 `fit`, `predict`, `predict_proba`, and `score` methods 64 2. Values passed to `sk_params` 65 3. The default values of the `keras.models.Sequential` 66 `fit`, `predict`, `predict_proba` and `score` methods 67 68 When using scikit-learn's `grid_search` API, legal tunable parameters are 69 those you could pass to `sk_params`, including fitting parameters. 70 In other words, you could use `grid_search` to search for the best 71 `batch_size` or `epochs` as well as the model parameters. 72 """ 73 74 def __init__(self, build_fn=None, **sk_params): 75 self.build_fn = build_fn 76 self.sk_params = sk_params 77 self.check_params(sk_params) 78 79 def check_params(self, params): 80 """Checks for user typos in `params`. 81 82 Args: 83 params: dictionary; the parameters to be checked 84 85 Raises: 86 ValueError: if any member of `params` is not a valid argument. 87 """ 88 legal_params_fns = [ 89 Sequential.fit, Sequential.predict, Sequential.predict_classes, 90 Sequential.evaluate 91 ] 92 if self.build_fn is None: 93 legal_params_fns.append(self.__call__) 94 elif (not isinstance(self.build_fn, types.FunctionType) and 95 not isinstance(self.build_fn, types.MethodType)): 96 legal_params_fns.append(self.build_fn.__call__) 97 else: 98 legal_params_fns.append(self.build_fn) 99 100 for params_name in params: 101 for fn in legal_params_fns: 102 if has_arg(fn, params_name): 103 break 104 else: 105 if params_name != 'nb_epoch': 106 raise ValueError('{} is not a legal parameter'.format(params_name)) 107 108 def get_params(self, **params): # pylint: disable=unused-argument 109 """Gets parameters for this estimator. 110 111 Args: 112 **params: ignored (exists for API compatibility). 113 114 Returns: 115 Dictionary of parameter names mapped to their values. 116 """ 117 res = self.sk_params.copy() 118 res.update({'build_fn': self.build_fn}) 119 return res 120 121 def set_params(self, **params): 122 """Sets the parameters of this estimator. 123 124 Args: 125 **params: Dictionary of parameter names mapped to their values. 126 127 Returns: 128 self 129 """ 130 self.check_params(params) 131 self.sk_params.update(params) 132 return self 133 134 def fit(self, x, y, **kwargs): 135 """Constructs a new model with `build_fn` & fit the model to `(x, y)`. 136 137 Args: 138 x : array-like, shape `(n_samples, n_features)` 139 Training samples where `n_samples` is the number of samples 140 and `n_features` is the number of features. 141 y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` 142 True labels for `x`. 143 **kwargs: dictionary arguments 144 Legal arguments are the arguments of `Sequential.fit` 145 146 Returns: 147 history : object 148 details about the training history at each epoch. 149 """ 150 if self.build_fn is None: 151 self.model = self.__call__(**self.filter_sk_params(self.__call__)) 152 elif (not isinstance(self.build_fn, types.FunctionType) and 153 not isinstance(self.build_fn, types.MethodType)): 154 self.model = self.build_fn( 155 **self.filter_sk_params(self.build_fn.__call__)) 156 else: 157 self.model = self.build_fn(**self.filter_sk_params(self.build_fn)) 158 159 if (losses.is_categorical_crossentropy(self.model.loss) and 160 len(y.shape) != 2): 161 y = to_categorical(y) 162 163 fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit)) 164 fit_args.update(kwargs) 165 166 history = self.model.fit(x, y, **fit_args) 167 168 return history 169 170 def filter_sk_params(self, fn, override=None): 171 """Filters `sk_params` and returns those in `fn`'s arguments. 172 173 Args: 174 fn : arbitrary function 175 override: dictionary, values to override `sk_params` 176 177 Returns: 178 res : dictionary containing variables 179 in both `sk_params` and `fn`'s arguments. 180 """ 181 override = override or {} 182 res = {} 183 for name, value in self.sk_params.items(): 184 if has_arg(fn, name): 185 res.update({name: value}) 186 res.update(override) 187 return res 188 189 190@keras_export('keras.wrappers.scikit_learn.KerasClassifier') 191class KerasClassifier(BaseWrapper): 192 """Implementation of the scikit-learn classifier API for Keras. 193 """ 194 195 def fit(self, x, y, **kwargs): 196 """Constructs a new model with `build_fn` & fit the model to `(x, y)`. 197 198 Args: 199 x : array-like, shape `(n_samples, n_features)` 200 Training samples where `n_samples` is the number of samples 201 and `n_features` is the number of features. 202 y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` 203 True labels for `x`. 204 **kwargs: dictionary arguments 205 Legal arguments are the arguments of `Sequential.fit` 206 207 Returns: 208 history : object 209 details about the training history at each epoch. 210 211 Raises: 212 ValueError: In case of invalid shape for `y` argument. 213 """ 214 y = np.array(y) 215 if len(y.shape) == 2 and y.shape[1] > 1: 216 self.classes_ = np.arange(y.shape[1]) 217 elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1: 218 self.classes_ = np.unique(y) 219 y = np.searchsorted(self.classes_, y) 220 else: 221 raise ValueError('Invalid shape for y: ' + str(y.shape)) 222 self.n_classes_ = len(self.classes_) 223 return super(KerasClassifier, self).fit(x, y, **kwargs) 224 225 def predict(self, x, **kwargs): 226 """Returns the class predictions for the given test data. 227 228 Args: 229 x: array-like, shape `(n_samples, n_features)` 230 Test samples where `n_samples` is the number of samples 231 and `n_features` is the number of features. 232 **kwargs: dictionary arguments 233 Legal arguments are the arguments 234 of `Sequential.predict_classes`. 235 236 Returns: 237 preds: array-like, shape `(n_samples,)` 238 Class predictions. 239 """ 240 kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs) 241 classes = self.model.predict_classes(x, **kwargs) 242 return self.classes_[classes] 243 244 def predict_proba(self, x, **kwargs): 245 """Returns class probability estimates for the given test data. 246 247 Args: 248 x: array-like, shape `(n_samples, n_features)` 249 Test samples where `n_samples` is the number of samples 250 and `n_features` is the number of features. 251 **kwargs: dictionary arguments 252 Legal arguments are the arguments 253 of `Sequential.predict_classes`. 254 255 Returns: 256 proba: array-like, shape `(n_samples, n_outputs)` 257 Class probability estimates. 258 In the case of binary classification, 259 to match the scikit-learn API, 260 will return an array of shape `(n_samples, 2)` 261 (instead of `(n_sample, 1)` as in Keras). 262 """ 263 kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs) 264 probs = self.model.predict(x, **kwargs) 265 266 # check if binary classification 267 if probs.shape[1] == 1: 268 # first column is probability of class 0 and second is of class 1 269 probs = np.hstack([1 - probs, probs]) 270 return probs 271 272 def score(self, x, y, **kwargs): 273 """Returns the mean accuracy on the given test data and labels. 274 275 Args: 276 x: array-like, shape `(n_samples, n_features)` 277 Test samples where `n_samples` is the number of samples 278 and `n_features` is the number of features. 279 y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` 280 True labels for `x`. 281 **kwargs: dictionary arguments 282 Legal arguments are the arguments of `Sequential.evaluate`. 283 284 Returns: 285 score: float 286 Mean accuracy of predictions on `x` wrt. `y`. 287 288 Raises: 289 ValueError: If the underlying model isn't configured to 290 compute accuracy. You should pass `metrics=["accuracy"]` to 291 the `.compile()` method of the model. 292 """ 293 y = np.searchsorted(self.classes_, y) 294 kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) 295 296 loss_name = self.model.loss 297 if hasattr(loss_name, '__name__'): 298 loss_name = loss_name.__name__ 299 if loss_name == 'categorical_crossentropy' and len(y.shape) != 2: 300 y = to_categorical(y) 301 302 outputs = self.model.evaluate(x, y, **kwargs) 303 if not isinstance(outputs, list): 304 outputs = [outputs] 305 for name, output in zip(self.model.metrics_names, outputs): 306 if name in ['accuracy', 'acc']: 307 return output 308 raise ValueError('The model is not configured to compute accuracy. ' 309 'You should pass `metrics=["accuracy"]` to ' 310 'the `model.compile()` method.') 311 312 313@keras_export('keras.wrappers.scikit_learn.KerasRegressor') 314class KerasRegressor(BaseWrapper): 315 """Implementation of the scikit-learn regressor API for Keras. 316 """ 317 318 def predict(self, x, **kwargs): 319 """Returns predictions for the given test data. 320 321 Args: 322 x: array-like, shape `(n_samples, n_features)` 323 Test samples where `n_samples` is the number of samples 324 and `n_features` is the number of features. 325 **kwargs: dictionary arguments 326 Legal arguments are the arguments of `Sequential.predict`. 327 328 Returns: 329 preds: array-like, shape `(n_samples,)` 330 Predictions. 331 """ 332 kwargs = self.filter_sk_params(Sequential.predict, kwargs) 333 return np.squeeze(self.model.predict(x, **kwargs)) 334 335 def score(self, x, y, **kwargs): 336 """Returns the mean loss on the given test data and labels. 337 338 Args: 339 x: array-like, shape `(n_samples, n_features)` 340 Test samples where `n_samples` is the number of samples 341 and `n_features` is the number of features. 342 y: array-like, shape `(n_samples,)` 343 True labels for `x`. 344 **kwargs: dictionary arguments 345 Legal arguments are the arguments of `Sequential.evaluate`. 346 347 Returns: 348 score: float 349 Mean accuracy of predictions on `x` wrt. `y`. 350 """ 351 kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) 352 loss = self.model.evaluate(x, y, **kwargs) 353 if isinstance(loss, list): 354 return -loss[0] 355 return -loss 356