1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Scikit-learn API wrapper."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python import keras
24from tensorflow.python.keras import testing_utils
25from tensorflow.python.platform import test
26
# Model dimensions shared by the build functions and the sanity checks.
INPUT_DIM = 5  # width of each input feature vector
HIDDEN_DIM = 5  # forwarded to the build functions as `hidden_dim`
NUM_CLASSES = 2  # number of classes requested from the test-data generator

# Dataset sizes and training settings used by every wrapper under test.
TRAIN_SAMPLES = 10
TEST_SAMPLES = 5
BATCH_SIZE = 5
EPOCHS = 1
34
35
def build_fn_clf(hidden_dim):
  """Builds and compiles a small softmax classifier for the wrapper tests.

  Args:
    hidden_dim: Number of units in the second dense layer.

  Returns:
    A compiled `keras.models.Sequential` mapping `INPUT_DIM` features to a
    `NUM_CLASSES`-way softmax, optimized with SGD on categorical
    cross-entropy and reporting accuracy.
  """
  layers = keras.layers
  stack = [
      layers.Dense(INPUT_DIM, input_shape=(INPUT_DIM,)),
      layers.Activation('relu'),
      layers.Dense(hidden_dim),
      layers.Activation('relu'),
      layers.Dense(NUM_CLASSES),
      layers.Activation('softmax'),
  ]
  classifier = keras.models.Sequential(stack)
  # NOTE(review): the 'accuracy' metric is presumably what
  # KerasClassifier.score reports; confirm before removing it.
  classifier.compile(
      optimizer='sgd',
      loss='categorical_crossentropy',
      metrics=['accuracy'])
  return classifier
47
48
def assert_classification_works(clf):
  """Fits `clf` on generated data and sanity-checks its sklearn-style outputs.

  Seeds NumPy's global RNG, generates a small dataset, fits `clf`, then checks
  that `score` is a finite scalar, that `predict` yields one label per test
  sample with every label in `range(NUM_CLASSES)`, and that `predict_proba`
  yields `NUM_CLASSES` probabilities per sample whose rows sum to one.

  Args:
    clf: A classifier exposing `fit`, `score`, `predict` and `predict_proba`
      (e.g. a `KerasClassifier`).
  """
  np.random.seed(42)
  train_split, test_split = testing_utils.get_test_data(
      train_samples=TRAIN_SAMPLES,
      test_samples=TEST_SAMPLES,
      input_shape=(INPUT_DIM,),
      num_classes=NUM_CLASSES)
  x_train, y_train = train_split
  x_test, _ = test_split

  clf.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS)

  score = clf.score(x_train, y_train, batch_size=BATCH_SIZE)
  assert np.isscalar(score)
  assert np.isfinite(score)

  labels = clf.predict(x_test, batch_size=BATCH_SIZE)
  assert labels.shape == (TEST_SAMPLES,)
  valid_labels = range(NUM_CLASSES)
  assert all(label in valid_labels for label in np.unique(labels))

  probabilities = clf.predict_proba(x_test, batch_size=BATCH_SIZE)
  assert probabilities.shape == (TEST_SAMPLES, NUM_CLASSES)
  row_totals = np.sum(probabilities, axis=1)
  assert np.allclose(row_totals, np.ones(TEST_SAMPLES))
70
71
def build_fn_reg(hidden_dim):
  """Builds and compiles a small one-output regressor for the wrapper tests.

  Args:
    hidden_dim: Number of units in the second dense layer.

  Returns:
    A compiled `keras.models.Sequential` mapping `INPUT_DIM` features to a
    single linear output, optimized with SGD on mean absolute error.
  """
  model = keras.models.Sequential()
  model.add(keras.layers.Dense(INPUT_DIM, input_shape=(INPUT_DIM,)))
  model.add(keras.layers.Activation('relu'))
  model.add(keras.layers.Dense(hidden_dim))
  model.add(keras.layers.Activation('relu'))
  model.add(keras.layers.Dense(1))
  model.add(keras.layers.Activation('linear'))
  # Accuracy is not meaningful for a continuous, single-output regressor;
  # track a regression error metric instead. The loss is still the first
  # value returned by evaluate().
  model.compile(
      optimizer='sgd',
      loss='mean_absolute_error',
      metrics=['mean_squared_error'])
  return model
83
84
def assert_regression_works(reg):
  """Fits `reg` on generated data and sanity-checks its sklearn-style outputs.

  Seeds NumPy's global RNG and draws a dataset from the same generator (and
  `num_classes=NUM_CLASSES`) used by the classification check. It then fits
  `reg` and checks that `score` is a finite scalar and that `predict` returns
  one value per test sample.

  Args:
    reg: A regressor exposing `fit`, `score` and `predict` (e.g. a
      `KerasRegressor`).
  """
  np.random.seed(42)
  dataset = testing_utils.get_test_data(
      train_samples=TRAIN_SAMPLES,
      test_samples=TEST_SAMPLES,
      input_shape=(INPUT_DIM,),
      num_classes=NUM_CLASSES)
  (x_train, y_train), (x_test, _) = dataset

  batching = {'batch_size': BATCH_SIZE}
  reg.fit(x_train, y_train, epochs=EPOCHS, **batching)

  score = reg.score(x_train, y_train, **batching)
  assert np.isscalar(score)
  assert np.isfinite(score)

  predictions = reg.predict(x_test, **batching)
  assert predictions.shape == (TEST_SAMPLES,)
100
101
class ScikitLearnAPIWrapperTest(test.TestCase):
  """Tests the Keras scikit-learn wrappers with each kind of build_fn.

  `KerasClassifier` and `KerasRegressor` are each exercised with a plain
  function, an instance of a callable class, and a wrapper subclass that
  defines `__call__` itself and is constructed with `build_fn=None`.
  """

  @staticmethod
  def _sk_params():
    """Returns a fresh dict of the keyword arguments every wrapper receives."""
    return dict(hidden_dim=HIDDEN_DIM, batch_size=BATCH_SIZE, epochs=EPOCHS)

  def test_classify_build_fn(self):
    with self.cached_session():
      estimator = keras.wrappers.scikit_learn.KerasClassifier(
          build_fn=build_fn_clf, **self._sk_params())
      assert_classification_works(estimator)

  def test_classify_class_build_fn(self):

    class CallableClassifierBuilder(object):
      """Callable object passed as `build_fn` instead of a plain function."""

      def __call__(self, hidden_dim):
        return build_fn_clf(hidden_dim)

    with self.cached_session():
      estimator = keras.wrappers.scikit_learn.KerasClassifier(
          build_fn=CallableClassifierBuilder(), **self._sk_params())
      assert_classification_works(estimator)

  def test_classify_inherit_class_build_fn(self):

    class SelfBuildingClassifier(keras.wrappers.scikit_learn.KerasClassifier):
      """KerasClassifier subclass that defines its own `__call__` builder."""

      def __call__(self, hidden_dim):
        return build_fn_clf(hidden_dim)

    with self.cached_session():
      estimator = SelfBuildingClassifier(build_fn=None, **self._sk_params())
      assert_classification_works(estimator)

  def test_regression_build_fn(self):
    with self.cached_session():
      estimator = keras.wrappers.scikit_learn.KerasRegressor(
          build_fn=build_fn_reg, **self._sk_params())
      assert_regression_works(estimator)

  def test_regression_class_build_fn(self):

    class CallableRegressorBuilder(object):
      """Callable object passed as `build_fn` instead of a plain function."""

      def __call__(self, hidden_dim):
        return build_fn_reg(hidden_dim)

    with self.cached_session():
      estimator = keras.wrappers.scikit_learn.KerasRegressor(
          build_fn=CallableRegressorBuilder(), **self._sk_params())
      assert_regression_works(estimator)

  def test_regression_inherit_class_build_fn(self):

    class SelfBuildingRegressor(keras.wrappers.scikit_learn.KerasRegressor):
      """KerasRegressor subclass that defines its own `__call__` builder."""

      def __call__(self, hidden_dim):
        return build_fn_reg(hidden_dim)

    with self.cached_session():
      estimator = SelfBuildingRegressor(build_fn=None, **self._sk_params())
      assert_regression_works(estimator)
187
188
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when run as a script.
  test.main()
191