1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A simple smoke test that runs these examples for 1 training iteration.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import sys 22 23import pandas as pd 24 25from six.moves import StringIO 26 27import tensorflow.examples.get_started.regression.imports85 as imports85 28 29sys.modules["imports85"] = imports85 30 31# pylint: disable=g-bad-import-order,g-import-not-at-top 32import tensorflow.data as data 33 34import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression 35import tensorflow.examples.get_started.regression.linear_regression_categorical as linear_regression_categorical 36import tensorflow.examples.get_started.regression.custom_regression as custom_regression 37 38from tensorflow.python.platform import googletest 39from tensorflow.python.platform import test 40# pylint: disable=g-bad-import-order,g-import-not-at-top 41 42 43# pylint: disable=line-too-long 44FOUR_LINES = "\n".join([ 45 "1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500", 46 "2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950", 47 "2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450", 48 "2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250", 49]) 50 51# pylint: enable=line-too-long 52 53 54def four_lines_dataframe(): 55 text = StringIO(FOUR_LINES) 56 57 return pd.read_csv( 58 text, names=imports85.types.keys(), dtype=imports85.types, na_values="?") 59 60 61def four_lines_dataset(*args, **kwargs): 62 del args, kwargs 63 return data.Dataset.from_tensor_slices(FOUR_LINES.split("\n")) 64 65 66class RegressionTest(googletest.TestCase): 67 """Test the regression examples in this directory.""" 68 69 @test.mock.patch.dict(data.__dict__, {"TextLineDataset": four_lines_dataset}) 70 @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)}) 71 @test.mock.patch.dict(linear_regression_categorical.__dict__, {"STEPS": 1}) 72 def test_linear_regression_categorical(self): 73 linear_regression_categorical.main([""]) 74 75 @test.mock.patch.dict(data.__dict__, {"TextLineDataset": four_lines_dataset}) 76 @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)}) 77 @test.mock.patch.dict(dnn_regression.__dict__, {"STEPS": 1}) 78 def test_dnn_regression(self): 79 dnn_regression.main([""]) 80 81 @test.mock.patch.dict(data.__dict__, {"TextLineDataset": four_lines_dataset}) 82 @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)}) 83 @test.mock.patch.dict(custom_regression.__dict__, {"STEPS": 1}) 84 def test_custom_regression(self): 85 custom_regression.main([""]) 86 87 88if __name__ == "__main__": 89 googletest.main() 90