1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Scikit-learn API wrapper.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python import keras 24from tensorflow.python.keras import testing_utils 25from tensorflow.python.platform import test 26 27INPUT_DIM = 5 28HIDDEN_DIM = 5 29TRAIN_SAMPLES = 10 30TEST_SAMPLES = 5 31NUM_CLASSES = 2 32BATCH_SIZE = 5 33EPOCHS = 1 34 35 36def build_fn_clf(hidden_dim): 37 model = keras.models.Sequential() 38 model.add(keras.layers.Dense(INPUT_DIM, input_shape=(INPUT_DIM,))) 39 model.add(keras.layers.Activation('relu')) 40 model.add(keras.layers.Dense(hidden_dim)) 41 model.add(keras.layers.Activation('relu')) 42 model.add(keras.layers.Dense(NUM_CLASSES)) 43 model.add(keras.layers.Activation('softmax')) 44 model.compile( 45 optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy']) 46 return model 47 48 49def assert_classification_works(clf): 50 np.random.seed(42) 51 (x_train, y_train), (x_test, _) = testing_utils.get_test_data( 52 train_samples=TRAIN_SAMPLES, 53 test_samples=TEST_SAMPLES, 54 input_shape=(INPUT_DIM,), 55 num_classes=NUM_CLASSES) 56 57 clf.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS) 58 59 score = clf.score(x_train, y_train, batch_size=BATCH_SIZE) 60 assert np.isscalar(score) and np.isfinite(score) 61 62 preds = clf.predict(x_test, batch_size=BATCH_SIZE) 63 assert preds.shape == (TEST_SAMPLES,) 64 for prediction in np.unique(preds): 65 assert prediction in range(NUM_CLASSES) 66 67 proba = clf.predict_proba(x_test, batch_size=BATCH_SIZE) 68 assert proba.shape == (TEST_SAMPLES, NUM_CLASSES) 69 assert np.allclose(np.sum(proba, axis=1), np.ones(TEST_SAMPLES)) 70 71 72def build_fn_reg(hidden_dim): 73 model = keras.models.Sequential() 74 model.add(keras.layers.Dense(INPUT_DIM, input_shape=(INPUT_DIM,))) 75 model.add(keras.layers.Activation('relu')) 76 model.add(keras.layers.Dense(hidden_dim)) 77 model.add(keras.layers.Activation('relu')) 78 model.add(keras.layers.Dense(1)) 79 model.add(keras.layers.Activation('linear')) 80 model.compile( 81 optimizer='sgd', loss='mean_absolute_error', metrics=['accuracy']) 82 return model 83 84 85def assert_regression_works(reg): 86 np.random.seed(42) 87 (x_train, y_train), (x_test, _) = testing_utils.get_test_data( 88 train_samples=TRAIN_SAMPLES, 89 test_samples=TEST_SAMPLES, 90 input_shape=(INPUT_DIM,), 91 num_classes=NUM_CLASSES) 92 93 reg.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS) 94 95 score = reg.score(x_train, y_train, batch_size=BATCH_SIZE) 96 assert np.isscalar(score) and np.isfinite(score) 97 98 preds = reg.predict(x_test, batch_size=BATCH_SIZE) 99 assert preds.shape == (TEST_SAMPLES,) 100 101 102class ScikitLearnAPIWrapperTest(test.TestCase): 103 104 def test_classify_build_fn(self): 105 with self.cached_session(): 106 clf = keras.wrappers.scikit_learn.KerasClassifier( 107 build_fn=build_fn_clf, 108 hidden_dim=HIDDEN_DIM, 109 batch_size=BATCH_SIZE, 110 epochs=EPOCHS) 111 112 assert_classification_works(clf) 113 114 def test_classify_class_build_fn(self): 115 116 class ClassBuildFnClf(object): 117 118 def __call__(self, hidden_dim): 119 return build_fn_clf(hidden_dim) 120 121 with self.cached_session(): 122 clf = keras.wrappers.scikit_learn.KerasClassifier( 123 build_fn=ClassBuildFnClf(), 124 hidden_dim=HIDDEN_DIM, 125 batch_size=BATCH_SIZE, 126 epochs=EPOCHS) 127 128 assert_classification_works(clf) 129 130 def test_classify_inherit_class_build_fn(self): 131 132 class InheritClassBuildFnClf(keras.wrappers.scikit_learn.KerasClassifier): 133 134 def __call__(self, hidden_dim): 135 return build_fn_clf(hidden_dim) 136 137 with self.cached_session(): 138 clf = InheritClassBuildFnClf( 139 build_fn=None, 140 hidden_dim=HIDDEN_DIM, 141 batch_size=BATCH_SIZE, 142 epochs=EPOCHS) 143 144 assert_classification_works(clf) 145 146 def test_regression_build_fn(self): 147 with self.cached_session(): 148 reg = keras.wrappers.scikit_learn.KerasRegressor( 149 build_fn=build_fn_reg, 150 hidden_dim=HIDDEN_DIM, 151 batch_size=BATCH_SIZE, 152 epochs=EPOCHS) 153 154 assert_regression_works(reg) 155 156 def test_regression_class_build_fn(self): 157 158 class ClassBuildFnReg(object): 159 160 def __call__(self, hidden_dim): 161 return build_fn_reg(hidden_dim) 162 163 with self.cached_session(): 164 reg = keras.wrappers.scikit_learn.KerasRegressor( 165 build_fn=ClassBuildFnReg(), 166 hidden_dim=HIDDEN_DIM, 167 batch_size=BATCH_SIZE, 168 epochs=EPOCHS) 169 170 assert_regression_works(reg) 171 172 def test_regression_inherit_class_build_fn(self): 173 174 class InheritClassBuildFnReg(keras.wrappers.scikit_learn.KerasRegressor): 175 176 def __call__(self, hidden_dim): 177 return build_fn_reg(hidden_dim) 178 179 with self.cached_session(): 180 reg = InheritClassBuildFnReg( 181 build_fn=None, 182 hidden_dim=HIDDEN_DIM, 183 batch_size=BATCH_SIZE, 184 epochs=EPOCHS) 185 186 assert_regression_works(reg) 187 188 189if __name__ == '__main__': 190 test.main() 191