1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Wrapper for using the Scikit-Learn API with Keras models.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import types
23
24import numpy as np
25
26from tensorflow.python.keras import losses
27from tensorflow.python.keras.models import Sequential
28from tensorflow.python.keras.utils.generic_utils import has_arg
29from tensorflow.python.keras.utils.np_utils import to_categorical
30from tensorflow.python.util.tf_export import keras_export
31
32
class BaseWrapper(object):
  """Base class for the Keras scikit-learn wrapper.

  Warning: This class should not be used directly.
  Use descendant classes instead.

  Args:
      build_fn: callable function or class instance
      **sk_params: model parameters & fitting parameters

  The `build_fn` should construct, compile and return a Keras model, which
  will then be used to fit/predict. One of the following
  three values could be passed to `build_fn`:
  1. A function
  2. An instance of a class that implements the `__call__` method
  3. None. This means you implement a class that inherits from either
  `KerasClassifier` or `KerasRegressor`. The `__call__` method of the
  present class will then be treated as the default `build_fn`.

  `sk_params` takes both model parameters and fitting parameters. Legal model
  parameters are the arguments of `build_fn`. Note that like all other
  estimators in scikit-learn, `build_fn` should provide default values for
  its arguments, so that you could create the estimator without passing any
  values to `sk_params`.

  `sk_params` could also accept parameters for calling `fit`, `predict`,
  `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`).
  Fitting (predicting) parameters are selected in the following order:

  1. Values passed to the dictionary arguments of
  `fit`, `predict`, `predict_proba`, and `score` methods
  2. Values passed to `sk_params`
  3. The default values of the `keras.models.Sequential`
  `fit`, `predict`, `predict_classes`, `predict_proba` and `evaluate` methods

  When using scikit-learn's `grid_search` API, legal tunable parameters are
  those you could pass to `sk_params`, including fitting parameters.
  In other words, you could use `grid_search` to search for the best
  `batch_size` or `epochs` as well as the model parameters.
  """

  def __init__(self, build_fn=None, **sk_params):
    self.build_fn = build_fn
    self.sk_params = sk_params
    # Validate eagerly so that typos surface at construction time.
    self.check_params(sk_params)

  def _build_fn_signature_target(self):
    """Returns the callable whose arguments are the legal model parameters.

    - `build_fn` is None: the wrapper's own `__call__` (provided by a
      subclass).
    - `build_fn` is neither a plain function nor a method (e.g. an instance
      of a class defining `__call__`): that object's `__call__`.
    - Otherwise: `build_fn` itself.

    Returns:
        The callable to introspect with `has_arg`.
    """
    if self.build_fn is None:
      return self.__call__
    if (not isinstance(self.build_fn, types.FunctionType) and
        not isinstance(self.build_fn, types.MethodType)):
      return self.build_fn.__call__
    return self.build_fn

  def check_params(self, params):
    """Checks for user typos in `params`.

    A name is legal if it is an argument of `Sequential.fit`,
    `Sequential.predict`, `Sequential.predict_classes`,
    `Sequential.evaluate`, or of the model-building callable. The name
    `nb_epoch` is always accepted.

    Args:
        params: dictionary; the parameters to be checked

    Raises:
        ValueError: if any member of `params` is not a valid argument.
    """
    legal_params_fns = [
        Sequential.fit, Sequential.predict, Sequential.predict_classes,
        Sequential.evaluate, self._build_fn_signature_target()
    ]

    for params_name in params:
      for fn in legal_params_fns:
        if has_arg(fn, params_name):
          break
      else:
        # `nb_epoch` is special-cased as always legal (presumably a legacy
        # spelling kept for backward compatibility).
        if params_name != 'nb_epoch':
          raise ValueError('{} is not a legal parameter'.format(params_name))

  def get_params(self, **params):  # pylint: disable=unused-argument
    """Gets parameters for this estimator.

    Args:
        **params: ignored (exists for API compatibility).

    Returns:
        Dictionary of parameter names mapped to their values. This is a copy
        of `sk_params` plus the `build_fn` entry.
    """
    res = self.sk_params.copy()
    res.update({'build_fn': self.build_fn})
    return res

  def set_params(self, **params):
    """Sets the parameters of this estimator.

    Every key returned by `get_params` is accepted, including `build_fn`.
    When `build_fn` is given, the remaining parameters are validated against
    the new `build_fn`.

    Args:
        **params: Dictionary of parameter names mapped to their values.

    Returns:
        self

    Raises:
        ValueError: if any member of `params` is not a valid argument. If
            validation fails, `build_fn` and `sk_params` are left unchanged.
    """
    previous_build_fn = self.build_fn
    # `params` is a fresh dict built from **kwargs, so popping is safe.
    self.build_fn = params.pop('build_fn', previous_build_fn)
    try:
      self.check_params(params)
    except Exception:
      # Roll back so a failed call does not leave a half-applied update.
      self.build_fn = previous_build_fn
      raise
    self.sk_params.update(params)
    return self

  def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

    Args:
        x : array-like, shape `(n_samples, n_features)`
            Training samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.fit`

    Returns:
        history : object
            details about the training history at each epoch.
    """
    build_callable = self.__call__ if self.build_fn is None else self.build_fn
    self.model = build_callable(
        **self.filter_sk_params(self._build_fn_signature_target()))

    # One-hot encode non-2D labels when the model uses categorical
    # cross-entropy. `np.ndim` (unlike `y.shape`) also works for plain lists.
    if (losses.is_categorical_crossentropy(self.model.loss) and
        np.ndim(y) != 2):
      y = to_categorical(y)

    # Deep-copied, presumably so mutable values stored in `sk_params` are not
    # shared between successive fits. Explicit kwargs override `sk_params`.
    fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
    fit_args.update(kwargs)

    history = self.model.fit(x, y, **fit_args)

    return history

  def filter_sk_params(self, fn, override=None):
    """Filters `sk_params` and returns those in `fn`'s arguments.

    Args:
        fn : arbitrary function
        override: dictionary, values to override `sk_params`

    Returns:
        res : dictionary containing variables
            in both `sk_params` and `fn`'s arguments, updated with
            `override` (whose entries are included unconditionally).
    """
    override = override or {}
    res = {
        name: value
        for name, value in self.sk_params.items()
        if has_arg(fn, name)
    }
    res.update(override)
    return res
188
189
@keras_export('keras.wrappers.scikit_learn.KerasClassifier')
class KerasClassifier(BaseWrapper):
  """Implementation of the scikit-learn classifier API for Keras.
  """

  def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

    Also sets `classes_` (the known class labels) and `n_classes_`.

    Args:
        x : array-like, shape `(n_samples, n_features)`
            Training samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.fit`

    Returns:
        history : object
            details about the training history at each epoch.

    Raises:
        ValueError: In case of invalid shape for `y` argument.
    """
    y = np.array(y)
    if len(y.shape) == 2 and y.shape[1] > 1:
      # Multi-column targets: classes are the column indices.
      self.classes_ = np.arange(y.shape[1])
    elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
      # Label vector: keep the sorted unique labels and re-encode `y` as
      # indices into them.
      self.classes_ = np.unique(y)
      y = np.searchsorted(self.classes_, y)
    else:
      raise ValueError('Invalid shape for y: ' + str(y.shape))
    self.n_classes_ = len(self.classes_)
    return super(KerasClassifier, self).fit(x, y, **kwargs)

  def predict(self, x, **kwargs):
    """Returns the class predictions for the given test data.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments
            of `Sequential.predict_classes`.

    Returns:
        preds: array-like, shape `(n_samples,)`
            Class predictions.
    """
    kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
    classes = np.asarray(self.model.predict_classes(x, **kwargs))
    # `predict_classes` may return a single column of shape
    # `(n_samples, 1)` (e.g. for single-unit outputs). Flatten it so the
    # result has the documented `(n_samples,)` shape.
    if classes.ndim == 2 and classes.shape[1] == 1:
      classes = classes[:, 0]
    # Map index predictions back to the original labels seen in `fit`.
    return self.classes_[classes]

  def predict_proba(self, x, **kwargs):
    """Returns class probability estimates for the given test data.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments
            of `Sequential.predict_proba`.

    Returns:
        proba: array-like, shape `(n_samples, n_outputs)`
            Class probability estimates.
            In the case of binary classification,
            to match the scikit-learn API,
            will return an array of shape `(n_samples, 2)`
            (instead of `(n_samples, 1)` as in Keras).
    """
    kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
    probs = self.model.predict(x, **kwargs)

    # check if binary classification
    if probs.shape[1] == 1:
      # first column is probability of class 0 and second is of class 1
      probs = np.hstack([1 - probs, probs])
    return probs

  def score(self, x, y, **kwargs):
    """Returns the mean accuracy on the given test data and labels.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.evaluate`.

    Returns:
        score: float
            Mean accuracy of predictions on `x` wrt. `y`.

    Raises:
        ValueError: If the underlying model isn't configured to
            compute accuracy. You should pass `metrics=["accuracy"]` to
            the `.compile()` method of the model.
    """
    # Re-encode labels as indices into the classes learned in `fit`.
    y = np.searchsorted(self.classes_, y)
    kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)

    # Use the same check as `BaseWrapper.fit`, so training and scoring agree
    # on when labels must be one-hot encoded.
    if (losses.is_categorical_crossentropy(self.model.loss) and
        len(y.shape) != 2):
      # Fix the width to the number of classes seen in `fit`; the test
      # labels may not contain the highest class index.
      y = to_categorical(y, num_classes=self.n_classes_)

    outputs = self.model.evaluate(x, y, **kwargs)
    if not isinstance(outputs, list):
      outputs = [outputs]
    for name, output in zip(self.model.metrics_names, outputs):
      if name in ['accuracy', 'acc']:
        return output
    raise ValueError('The model is not configured to compute accuracy. '
                     'You should pass `metrics=["accuracy"]` to '
                     'the `model.compile()` method.')
311
312
@keras_export('keras.wrappers.scikit_learn.KerasRegressor')
class KerasRegressor(BaseWrapper):
  """Implementation of the scikit-learn regressor API for Keras.
  """

  def predict(self, x, **kwargs):
    """Returns predictions for the given test data.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.predict`.

    Returns:
        preds: array-like, shape `(n_samples,)`
            Predictions. Singleton output dimensions are squeezed away.
            When the model returns an array, the sample axis is always
            kept, so a single sample yields shape `(1,)` rather than a 0-d
            array.
    """
    kwargs = self.filter_sk_params(Sequential.predict, kwargs)
    preds = self.model.predict(x, **kwargs)
    if not isinstance(preds, np.ndarray):
      # Non-array results (e.g. a list of arrays): keep the plain
      # `np.squeeze` behavior.
      return np.squeeze(preds)
    # Squeeze singleton axes other than the sample axis (0), so predicting
    # on a single sample does not collapse the result to a 0-d array.
    squeeze_axes = tuple(
        axis for axis in range(1, preds.ndim) if preds.shape[axis] == 1)
    if squeeze_axes:
      return np.squeeze(preds, axis=squeeze_axes)
    return preds

  def score(self, x, y, **kwargs):
    """Returns the negated loss on the given test data and labels.

    The loss is negated so that a higher score means a better fit.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y: array-like, shape `(n_samples,)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.evaluate`.

    Returns:
        score: float
            Negated loss of predictions on `x` wrt. `y`. When
            `Sequential.evaluate` returns a list (loss plus metrics), its
            first entry is used.
    """
    kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
    loss = self.model.evaluate(x, y, **kwargs)
    if isinstance(loss, list):
      return -loss[0]
    return -loss
356