1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A simple smoke test that runs these examples for 1 training iteration."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import sys
22
23import pandas as pd
24
25from six.moves import StringIO
26
27import tensorflow.examples.get_started.regression.imports85 as imports85
28
# Also register the already-imported module under the bare top-level name
# "imports85". NOTE(review): presumably the example modules imported below do
# a plain `import imports85`, which would otherwise fail from this package
# path — confirm against the example sources. This has to run before those
# imports, hence the g-import-not-at-top suppression that follows.
sys.modules["imports85"] = imports85
30
31# pylint: disable=g-bad-import-order,g-import-not-at-top
32import tensorflow.data as data
33
34import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression
35import tensorflow.examples.get_started.regression.linear_regression_categorical as linear_regression_categorical
36import tensorflow.examples.get_started.regression.custom_regression as custom_regression
37
38from tensorflow.python.platform import googletest
39from tensorflow.python.platform import test
# pylint: enable=g-bad-import-order,g-import-not-at-top
41
42
43# pylint: disable=line-too-long
# Four newline-separated CSV records (no trailing newline) used as a tiny
# stand-in for the imports85 data file; "?" marks a missing field.
FOUR_LINES = (
    "1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500\n"
    "2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950\n"
    "2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450\n"
    "2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250"
)
50
51# pylint: enable=line-too-long
52
53
def four_lines_dataframe():
  """Parses `FOUR_LINES` into a pandas DataFrame.

  Column names and per-column dtypes are both taken from `imports85.types`,
  and "?" cells are read as missing values.

  Returns:
    The `pandas.DataFrame` produced by `pd.read_csv` over `FOUR_LINES`.
  """
  column_types = imports85.types
  csv_buffer = StringIO(FOUR_LINES)
  return pd.read_csv(csv_buffer,
                     names=column_types.keys(),
                     dtype=column_types,
                     na_values="?")
59
60
def four_lines_dataset(*args, **kwargs):
  """Drop-in replacement for `TextLineDataset` used by the tests.

  Whatever arguments the caller passes (e.g. a file name) are discarded; the
  returned dataset always yields the four records of `FOUR_LINES`, one string
  element per line.

  Returns:
    A `Dataset` built with `Dataset.from_tensor_slices` over those lines.
  """
  del args  # Unused: the real input is replaced by FOUR_LINES.
  del kwargs  # Unused for the same reason.
  records = FOUR_LINES.split("\n")
  return data.Dataset.from_tensor_slices(records)
64
65
class RegressionTest(googletest.TestCase):
  """Test the regression examples in this directory.

  Each test runs one example's `main([""])` while:
    * `tensorflow.data.TextLineDataset` is swapped for `four_lines_dataset`,
    * `imports85._get_imports85` is swapped for a function returning None, and
    * the example module's `STEPS` is forced to 1.
  All three patches are undone when the example returns.
  """

  def _run_example(self, example_module):
    """Runs `example_module.main([""])` under the shared test patches."""
    with test.mock.patch.dict(data.__dict__,
                              {"TextLineDataset": four_lines_dataset}):
      with test.mock.patch.dict(imports85.__dict__,
                                {"_get_imports85": (lambda: None)}):
        with test.mock.patch.dict(example_module.__dict__, {"STEPS": 1}):
          example_module.main([""])

  def test_linear_regression_categorical(self):
    self._run_example(linear_regression_categorical)

  def test_dnn_regression(self):
    self._run_example(dnn_regression)

  def test_custom_regression(self):
    self._run_example(custom_regression)
86
87
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's googletest runner.
  googletest.main()
90