# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""sklearn cross-support (deprecated)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# `collections` is no longer used below; the import is kept so that existing
# importers of this module's namespace are unaffected.
import collections  # pylint: disable=unused-import
import os

import numpy as np
import six


def _pprint(d):
  """Formats a mapping as a comma-separated `key=value` string."""
  return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])


class _BaseEstimator(object):
  """This is a cross-import when sklearn is not available.

  Adopted from sklearn.BaseEstimator implementation.
  https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/base.py
  """

  def get_params(self, deep=True):
    """Get parameters for this estimator.

    Parameters are the instance attributes whose names do not start with an
    underscore and whose values are not callable.

    Args:
      deep: boolean, optional

        If `True`, will return the parameters for this estimator and
        contained subobjects that are estimators.

    Returns:
      params : mapping of string to any
      Parameter names mapped to their values.
    """
    out = dict()
    param_names = [name for name in self.__dict__ if not name.startswith('_')]
    for key in param_names:
      value = getattr(self, key, None)

      # `collections.Callable` was removed in Python 3.10; the builtin
      # `callable` performs the same check on every supported version.
      if callable(value):
        continue

      # XXX: should we rather test if instance of estimator?
      if deep and hasattr(value, 'get_params'):
        deep_items = value.get_params().items()
        out.update((key + '__' + k, val) for k, val in deep_items)
      out[key] = value
    return out

  def set_params(self, **params):
    """Set the parameters of this estimator.

    The method works on simple estimators as well as on nested objects
    (such as pipelines). The latter have parameters of the form
    ``<component>__<parameter>`` so that it's possible to update each
    component of a nested object.

    Args:
      **params: Parameters.

    Returns:
      self

    Raises:
      ValueError: If params contain invalid names.
    """
    if not params:
      # Nothing to set; skip the parameter lookup entirely.
      return self
    valid_params = self.get_params(deep=True)
    for key, value in six.iteritems(params):
      split = key.split('__', 1)
      if len(split) > 1:
        # nested objects case
        name, sub_name = split
        if name not in valid_params:
          raise ValueError('Invalid parameter %s for estimator %s. '
                           'Check the list of available parameters '
                           'with `estimator.get_params().keys()`.' %
                           (name, self))
        # Delegate to the sub-object so it validates its own parameter name.
        sub_object = valid_params[name]
        sub_object.set_params(**{sub_name: value})
      else:
        # simple objects case
        if key not in valid_params:
          raise ValueError('Invalid parameter %s for estimator %s. '
                           'Check the list of available parameters '
                           'with `estimator.get_params().keys()`.' %
                           (key, self.__class__.__name__))
        setattr(self, key, value)
    return self

  def __repr__(self):
    class_name = self.__class__.__name__
    return '%s(%s)' % (class_name,
                       _pprint(self.get_params(deep=False)),)


# pylint: disable=old-style-class
class _ClassifierMixin():
  """Mixin class for all classifiers."""
  pass


class _RegressorMixin():
  """Mixin class for all regression estimators."""
  pass


class _TransformerMixin():
  """Mixin class for all transformer estimators."""


class NotFittedError(ValueError, AttributeError):
  """Exception class to raise if estimator is used before fitting.

  USE OF THIS EXCEPTION IS DEPRECATED.

  This class inherits from both ValueError and AttributeError to help with
  exception handling and backward compatibility.

  Examples:
  >>> from sklearn.svm import LinearSVC
  >>> from sklearn.exceptions import NotFittedError
  >>> try:
  ...     LinearSVC().predict([[1, 2], [2, 3], [3, 4]])
  ... except NotFittedError as e:
  ...     print(repr(e))
  ...                        # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
  NotFittedError('This LinearSVC instance is not fitted yet',)

  Copied from
  https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/exceptions.py
  """

# pylint: enable=old-style-class


def _accuracy_score(y_true, y_pred):
  """Returns the fraction of positions where `y_true` equals `y_pred`.

  Args:
    y_true: array-like of reference labels.
    y_pred: array-like of predicted labels.

  Returns:
    The mean of the element-wise equality of the two inputs.
  """
  # Coerce to arrays so plain sequences are compared element-wise; `==` on two
  # lists would yield a single bool for the whole sequence.
  score = np.asarray(y_true) == np.asarray(y_pred)
  return np.average(score)


def _mean_squared_error(y_true, y_pred):
  """Returns the mean of the squared differences of `y_true` and `y_pred`.

  Args:
    y_true: array-like of reference values.
    y_pred: array-like of predicted values.

  Returns:
    The average of `(y_true - y_pred) ** 2`.
  """
  y_true = np.asarray(y_true)
  y_pred = np.asarray(y_pred)
  # Drop size-1 axes from multi-dimensional inputs so that e.g. an (n, 1)
  # column and an (n,) vector subtract element-wise instead of broadcasting.
  if len(y_true.shape) > 1:
    y_true = np.squeeze(y_true)
  if len(y_pred.shape) > 1:
    y_pred = np.squeeze(y_pred)
  return np.average((y_true - y_pred)**2)


def _train_test_split(*args, **options):
  """Randomly splits each array in `args` into train and test parts.

  All arrays are indexed along axis 0 with the same permutation, whose length
  is taken from the first array.

  Args:
    *args: arrays exposing `.shape` and `.take(indices, axis=0)`.
    **options: optional keywords:
      test_size: fraction of rows for the test part. Only consulted when
        `train_size` is not given.
      train_size: fraction of rows for the train part. Defaults to 0.75 when
        neither size is given.
      random_state: seed for the shuffle, or an `np.random.RandomState`
        instance to draw from. `None` gives a non-deterministic shuffle.

  Returns:
    A tuple `(a_train, a_test, b_train, b_test, ...)` with one train/test pair
    per input array, in input order.
  """
  test_size = options.pop('test_size', None)
  train_size = options.pop('train_size', None)
  random_state = options.pop('random_state', None)

  if test_size is None and train_size is None:
    train_size = 0.75
  elif train_size is None:
    train_size = 1 - test_size
  num_rows = args[0].shape[0]
  train_size = int(train_size * num_rows)

  # Use a private generator rather than `np.random.seed`, which would
  # overwrite the process-wide RNG state (and, for `random_state=None`,
  # discard any seed the caller had set). For a given seed the permutation is
  # the same as the one the global generator would produce.
  if isinstance(random_state, np.random.RandomState):
    rng = random_state
  else:
    rng = np.random.RandomState(random_state)
  indices = rng.permutation(num_rows)
  train_idx, test_idx = indices[:train_size], indices[train_size:]
  result = []
  for x in args:
    result += [x.take(train_idx, axis=0), x.take(test_idx, axis=0)]
  return tuple(result)


# If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn.
# Any non-empty value of the variable enables the sklearn imports.
TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
if TRY_IMPORT_SKLEARN:
  # pylint: disable=g-import-not-at-top,g-multiple-import,unused-import
  from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
  from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
  from sklearn.model_selection import train_test_split
  try:
    from sklearn.exceptions import NotFittedError
  except ImportError:
    try:
      from sklearn.utils.validation import NotFittedError
    except ImportError:
      # Neither location provides it; keep the NotFittedError defined above.
      pass
else:
  # Naive implementations of sklearn classes and functions.
  BaseEstimator = _BaseEstimator
  ClassifierMixin = _ClassifierMixin
  RegressorMixin = _RegressorMixin
  TransformerMixin = _TransformerMixin
  accuracy_score = _accuracy_score
  log_loss = None
  mean_squared_error = _mean_squared_error
  train_test_split = _train_test_split