1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A dataset loader for imports85.data.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23import numpy as np 24import tensorflow as tf 25 26try: 27 import pandas as pd # pylint: disable=g-import-not-at-top 28except ImportError: 29 pass 30 31 32URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data" 33 34# Order is important for the csv-readers, so we use an OrderedDict here. 
# Column names and per-column default values, in file order. Order is
# important for the csv-readers, so we use an OrderedDict here. The default's
# python type (int, float or str) doubles as the column's type declaration.
defaults = collections.OrderedDict([
    ("symboling", [0]),
    ("normalized-losses", [0.0]),
    ("make", [""]),
    ("fuel-type", [""]),
    ("aspiration", [""]),
    ("num-of-doors", [""]),
    ("body-style", [""]),
    ("drive-wheels", [""]),
    ("engine-location", [""]),
    ("wheel-base", [0.0]),
    ("length", [0.0]),
    ("width", [0.0]),
    ("height", [0.0]),
    ("curb-weight", [0.0]),
    ("engine-type", [""]),
    ("num-of-cylinders", [""]),
    ("engine-size", [0.0]),
    ("fuel-system", [""]),
    ("bore", [0.0]),
    ("stroke", [0.0]),
    ("compression-ratio", [0.0]),
    ("horsepower", [0.0]),
    ("peak-rpm", [0.0]),
    ("city-mpg", [0.0]),
    ("highway-mpg", [0.0]),
    ("price", [0.0])
])  # pyformat: disable


# Map each column name to the python type of its default value; used as the
# `dtype` specification when loading with pandas in `raw_dataframe`.
types = collections.OrderedDict((key, type(value[0]))
                                for key, value in defaults.items())


def _get_imports85():
  """Download the imports85 csv (or reuse the cached copy); return its path."""
  path = tf.contrib.keras.utils.get_file(URL.split("/")[-1], URL)
  return path


def dataset(y_name="price", train_fraction=0.7):
  """Load the imports85 data as a (train,test) pair of `Dataset`.

  Each dataset generates (features_dict, label) pairs.

  Args:
    y_name: The name of the column to use as the label.
    train_fraction: A float, the fraction of data to use for training. The
      remainder will be used for evaluation.
  Returns:
    A (train,test) pair of `Datasets`
  """
  # Download and cache the data
  path = _get_imports85()

  # Define how the lines of the file should be parsed
  def decode_line(line):
    """Convert a csv line into a (features_dict,label) pair."""
    # Decode the line to a tuple of items based on the types of
    # csv_header.values().
    items = tf.decode_csv(line, list(defaults.values()))

    # Convert the keys and items to a dict.
    pairs = zip(defaults.keys(), items)
    features_dict = dict(pairs)

    # Remove the label from the features_dict
    label = features_dict.pop(y_name)

    return features_dict, label

  def has_no_question_marks(line):
    """Returns True if the line of text has no question marks."""
    # split the line into an array of characters
    chars = tf.string_split(line[tf.newaxis], "").values
    # for each character check if it is a question mark
    is_question = tf.equal(chars, "?")
    any_question = tf.reduce_any(is_question)
    no_question = ~any_question

    return no_question

  def in_training_set(line):
    """Returns a boolean tensor, true if the line is in the training set."""
    # If you randomly split the dataset you won't get the same split in both
    # sessions if you stop and restart training later. Also a simple
    # random split won't work with a dataset that's too big to `.cache()` as
    # we are doing here.
    num_buckets = 1000000
    bucket_id = tf.string_to_hash_bucket_fast(line, num_buckets)
    # Use the hash bucket id as a random number that's deterministic per example
    return bucket_id < int(train_fraction * num_buckets)

  def in_test_set(line):
    """Returns a boolean tensor, true if the line is in the test set."""
    # Items not in the training set are in the test set.
    # This line must use `~` instead of `not` because `not` only works on python
    # booleans but we are dealing with symbolic tensors.
    return ~in_training_set(line)

  base_dataset = (
      tf.data
      # Get the lines from the file.
      .TextLineDataset(path)
      # drop lines with question marks.
      .filter(has_no_question_marks))

  train = (base_dataset
           # Take only the training-set lines.
           .filter(in_training_set)
           # Decode each line into a (features_dict, label) pair.
           .map(decode_line)
           # Cache data so you only decode the file once.
           .cache())

  # Do the same for the test-set.
  test = (base_dataset
          # Take only the test-set lines.
          .filter(in_test_set)
          # Decode each line into a (features_dict, label) pair.
          .map(decode_line)
          # Cache AFTER decoding (matching the train pipeline above) so the
          # csv parsing runs only once instead of every epoch.
          .cache())

  return train, test


def raw_dataframe():
  """Load the imports85 data as a pd.DataFrame."""
  # Download and cache the data
  path = _get_imports85()

  # Load it into a pandas dataframe. `names` must be a list-like of labels,
  # so materialize the dict view; "?" marks missing values in this file.
  df = pd.read_csv(path, names=list(types.keys()), dtype=types, na_values="?")

  return df


def load_data(y_name="price", train_fraction=0.7, seed=None):
  """Get the imports85 data set.

  A description of the data is available at:
    https://archive.ics.uci.edu/ml/datasets/automobile

  The data itself can be found at:
    https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data

  Args:
    y_name: the column to return as the label.
    train_fraction: the fraction of the dataset to use for training.
    seed: The random seed to use when shuffling the data. `None` generates a
      unique shuffle every run.
  Returns:
    a pair of pairs where the first pair is the training data, and the second
    is the test data:
    `(x_train, y_train), (x_test, y_test) = get_imports85_dataset(...)`
    `x` contains a pandas DataFrame of features, while `y` contains the label
    array.
  """
  # Load the raw data columns.
  data = raw_dataframe()

  # Delete rows with unknowns
  data = data.dropna()

  # Seed numpy's global RNG for reproducibility of any numpy-based shuffling;
  # note the train/test split below is already made deterministic by passing
  # `random_state=seed` directly to `DataFrame.sample`.
  np.random.seed(seed)

  # Split the data into train/test subsets.
  x_train = data.sample(frac=train_fraction, random_state=seed)
  x_test = data.drop(x_train.index)

  # Extract the label from the features dataframe.
  y_train = x_train.pop(y_name)
  y_test = x_test.pop(y_name)

  return (x_train, y_train), (x_test, y_test)