# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear regression with categorical features."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import imports85  # pylint: disable=g-bad-import-order

# Number of training steps passed to `model.train`.
STEPS = 1000
# Labels are divided by this factor before training and the reported error is
# multiplied by it again, so the printed figure is back in the original units.
PRICE_NORM_FACTOR = 1000


def main(argv):
  """Builds, trains, and evaluates a linear price model.

  Loads the train/test data sets from `imports85`, rescales the labels, fits a
  `tf.estimator.LinearRegressor` on two numeric and two categorical feature
  columns, evaluates it on the test set and prints the RMS error.

  Args:
    argv: Argument list handed over by `tf.app.run`. It must contain exactly
      one entry; this is checked with an `assert`.
  """
  assert len(argv) == 1
  train_set, test_set = imports85.dataset()

  # Express the labels in thousands for better convergence.
  def scale_labels(features, labels):
    return features, labels / PRICE_NORM_FACTOR

  train_set = train_set.map(scale_labels)
  test_set = test_set.map(scale_labels)

  # The shuffle buffer is meant to be larger than the data set so that the
  # examples are well mixed.
  shuffle_buffer = 1000
  batch_size = 128

  def input_train():
    """Training input_fn: shuffled batches, repeated forever."""
    shuffled = train_set.shuffle(shuffle_buffer)
    batched = shuffled.batch(batch_size)
    return batched.repeat()

  def input_test():
    """Validation input_fn: a single shuffled, batched pass."""
    shuffled = test_set.shuffle(shuffle_buffer)
    return shuffled.batch(batch_size)

  # The columns below demonstrate two of the ways that `feature_columns` can
  # be used to build a model with categorical inputs.

  # Numeric inputs: the same two features used by `linear_regressor.py`.
  curb_weight_column = tf.feature_column.numeric_column(key="curb-weight")
  highway_mpg_column = tf.feature_column.numeric_column(key="highway-mpg")

  # First way: assign a unique weight to each category. This requires the
  # category's vocabulary to be specified (values outside this specification
  # will receive a weight of zero).
  # Alternatively, the vocabulary can be defined in a file (by calling
  # `categorical_column_with_vocabulary_file`) or as a range of integer ids
  # (by calling `categorical_column_with_identity`).
  body_style_vocab = ["hardtop", "wagon", "sedan", "hatchback", "convertible"]
  body_style_column = tf.feature_column.categorical_column_with_vocabulary_list(
      key="body-style", vocabulary_list=body_style_vocab)

  # Second way, appropriate for an unspecified vocabulary: a hashed column.
  # It creates a fixed-length list of weights and automatically assigns each
  # input category to one of them. Because the assignment is pseudo-random,
  # some weights may be shared between categories while others stay unused.
  make_column = tf.feature_column.categorical_column_with_hash_bucket(
      key="make", hash_bucket_size=50)

  # The two categorical columns adjust the price based on "make" and
  # "body-style".
  feature_columns = [
      curb_weight_column,
      highway_mpg_column,
      body_style_column,
      make_column,
  ]

  # Build the Estimator.
  model = tf.estimator.LinearRegressor(feature_columns=feature_columns)

  # Train the model.
  # By default, the Estimators log output every 100 steps.
  model.train(input_fn=input_train, steps=STEPS)

  # Evaluate how the model performs on data it has not yet seen.
  eval_result = model.evaluate(input_fn=input_test)

  # The evaluation returns a Python dictionary. The "average_loss" key holds
  # the Mean Squared Error (MSE).
  mse = eval_result["average_loss"]

  # Convert the MSE to a Root Mean Square Error (RMSE) and undo the label
  # scaling so the error is reported in the original price units.
  rmse_dollars = PRICE_NORM_FACTOR * mse**0.5

  print("\n" + 80 * "*")
  print("\nRMS error for the test set: ${:.0f}".format(rmse_dollars))

  print()


if __name__ == "__main__":
  # The Estimator periodically generates "INFO" logs; make these logs visible.
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main=main)