1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Regression using the DNNRegressor Estimator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tensorflow as tf
22
23import imports85  # pylint: disable=g-bad-import-order
24
# Divisor applied to the price labels before training; the RMSE reported by
# `main` is multiplied back by this value so it is expressed in dollars.
PRICE_NORM_FACTOR = 1000
# Number of training steps passed to `model.train` in `main`.
STEPS = 5000
27
28
def main(argv):
  """Builds, trains, and evaluates a DNNRegressor on the imports85 data.

  Loads the train/test splits from `imports85.dataset()`, divides the price
  labels by `PRICE_NORM_FACTOR`, trains a DNNRegressor with two 20-unit
  hidden layers for `STEPS` steps, evaluates it on the test split, and prints
  the test-set root-mean-square error rescaled back to dollars.

  Args:
    argv: Command-line arguments remaining after `tf.app.run` has parsed
      flags. Only the program name (argv[0]) is expected.

  Raises:
    ValueError: If any positional arguments beyond the program name are
      present.
  """
  # Validate explicitly instead of with `assert`, which Python strips when
  # run with -O.
  if len(argv) != 1:
    raise ValueError(
        "Unexpected command-line arguments: {}".format(argv[1:]))

  (train, test) = imports85.dataset()

  # Switch the labels to units of thousands for better convergence.
  def normalize_price(features, labels):
    return features, labels / PRICE_NORM_FACTOR

  train = train.map(normalize_price)
  test = test.map(normalize_price)

  # Build the training input_fn.
  def input_train():
    return (
        # Shuffling with a buffer larger than the data set ensures
        # that the examples are well mixed.
        train.shuffle(1000).batch(128)
        # Repeat forever; training length is bounded by `steps` below.
        .repeat())

  # Build the evaluation input_fn. Evaluation makes a single pass over the
  # test set, and averaged metrics such as `average_loss` do not depend on
  # example order, so the test data is not shuffled.
  def input_test():
    return test.batch(128)

  # Categorical features are encoded in two different ways below.
  #
  # `body-style` lists its vocabulary explicitly, so each listed category gets
  # its own index (values outside the list receive a weight of zero). A
  # vocabulary can also be read from a file with
  # `categorical_column_with_vocabulary_file`, and features that are already
  # non-negative integers in a known range can use
  # `categorical_column_with_identity`.
  body_style_vocab = ["hardtop", "wagon", "sedan", "hatchback", "convertible"]
  body_style = tf.feature_column.categorical_column_with_vocabulary_list(
      key="body-style", vocabulary_list=body_style_vocab)
  # `make` needs no vocabulary: its values are hashed into 50 buckets, so
  # distinct makes may share a bucket.
  make = tf.feature_column.categorical_column_with_hash_bucket(
      key="make", hash_bucket_size=50)

  feature_columns = [
      tf.feature_column.numeric_column(key="curb-weight"),
      tf.feature_column.numeric_column(key="highway-mpg"),
      # Since this is a DNN model, convert categorical columns from sparse
      # to dense.
      # Wrap them in an `indicator_column` to create a
      # one-hot vector from the input.
      tf.feature_column.indicator_column(body_style),
      # Or use an `embedding_column` to create a trainable vector for each
      # index.
      tf.feature_column.embedding_column(make, dimension=3),
  ]

  # Build a DNNRegressor, with 2x20-unit hidden layers, with the feature columns
  # defined above as input.
  model = tf.estimator.DNNRegressor(
      hidden_units=[20, 20], feature_columns=feature_columns)

  # Train the model.
  model.train(input_fn=input_train, steps=STEPS)

  # Evaluate how the model performs on data it has not yet seen.
  eval_result = model.evaluate(input_fn=input_test)

  # The evaluation returns a Python dictionary. The "average_loss" key holds the
  # Mean Squared Error (MSE), measured in the normalized (thousands) label units.
  average_loss = eval_result["average_loss"]

  # Convert MSE to Root Mean Square Error (RMSE) and undo the label scaling so
  # the error is reported in dollars.
  print("\n" + 80 * "*")
  print("\nRMS error for the test set: ${:.0f}"
        .format(PRICE_NORM_FACTOR * average_loss**0.5))

  print()
100
101
if __name__ == "__main__":
  # The Estimator reports training and evaluation progress through INFO-level
  # log messages; lower the logging threshold so those messages are printed.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Hand control to tf.app.run, which parses command-line flags and then
  # calls `main` with the remaining arguments.
  tf.app.run(main=main)
106