1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A dataset loader for imports85.data."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23import numpy as np
24import tensorflow as tf
25
26try:
27  import pandas as pd  # pylint: disable=g-import-not-at-top
28except ImportError:
29  pass
30
31
# Remote location of the raw imports-85 CSV file (UCI "Automobile" data set).
URL = ("https://archive.ics.uci.edu/ml/machine-learning-databases/autos/"
       "imports-85.data")

# Column name -> single-element list holding that column's default value.
# These lists are handed to `tf.decode_csv` as record defaults, and the code
# parses every line of the file positionally (no header row is skipped), so the
# insertion order here must match the field order in the file -- hence the
# OrderedDict.
defaults = collections.OrderedDict([
    ("symboling", [0]),
    ("normalized-losses", [0.0]),
    ("make", [""]),
    ("fuel-type", [""]),
    ("aspiration", [""]),
    ("num-of-doors", [""]),
    ("body-style", [""]),
    ("drive-wheels", [""]),
    ("engine-location", [""]),
    ("wheel-base", [0.0]),
    ("length", [0.0]),
    ("width", [0.0]),
    ("height", [0.0]),
    ("curb-weight", [0.0]),
    ("engine-type", [""]),
    ("num-of-cylinders", [""]),
    ("engine-size", [0.0]),
    ("fuel-system", [""]),
    ("bore", [0.0]),
    ("stroke", [0.0]),
    ("compression-ratio", [0.0]),
    ("horsepower", [0.0]),
    ("peak-rpm", [0.0]),
    ("city-mpg", [0.0]),
    ("highway-mpg", [0.0]),
    ("price", [0.0])
])  # pyformat: disable


# Column name -> Python type (int, float or str) of that column's default,
# in the same column order as `defaults`.
types = collections.OrderedDict(
    [(column, type(default_list[0])) for column, default_list in defaults.items()])
67
68
def _get_imports85():
  """Download the imports-85 data file (caching it locally) and return its path.

  Returns:
    The local filesystem path returned by `tf.keras.utils.get_file` for `URL`.
    The file is stored under the last path component of `URL`.
  """
  # `tf.contrib.keras` is deprecated in favour of `tf.keras`, which provides the
  # same `get_file` utility in every TF release that also has `tf.data`.
  return tf.keras.utils.get_file(URL.split("/")[-1], URL)
72
73
def dataset(y_name="price", train_fraction=0.7):
  """Load the imports85 data as a (train,test) pair of `Dataset`.

  Each dataset generates (features_dict, label) pairs. Lines containing a "?"
  are dropped. Each remaining line is assigned to the train or the test split
  by hashing its text, so the split is deterministic across runs and roughly
  `train_fraction` of the lines end up in the training set.

  Args:
    y_name: The name of the column to use as the label.
    train_fraction: A float, the fraction of data to use for training. The
        remainder will be used for evaluation.
  Returns:
    A (train,test) pair of `Datasets`
  """
  # Download and cache the data
  path = _get_imports85()

  # Define how the lines of the file should be parsed
  def decode_line(line):
    """Convert a csv line into a (features_dict,label) pair."""
    # Decode the line to a tuple of items; the type of each default in
    # `defaults` selects the dtype of the corresponding column.
    items = tf.decode_csv(line, list(defaults.values()))

    # Convert the keys and items to a dict.
    features_dict = dict(zip(defaults.keys(), items))

    # Remove the label from the features_dict
    label = features_dict.pop(y_name)

    return features_dict, label

  def has_no_question_marks(line):
    """Returns True if the line of text has no question marks."""
    # split the line into an array of characters (empty delimiter)
    chars = tf.string_split(line[tf.newaxis], "").values
    # for each character check if it is a question mark
    is_question = tf.equal(chars, "?")
    any_question = tf.reduce_any(is_question)
    no_question = ~any_question

    return no_question

  def in_training_set(line):
    """Returns a boolean tensor, true if the line is in the training set."""
    # If you randomly split the dataset you won't get the same split in both
    # sessions if you stop and restart training later. Also a simple
    # random split won't work with a dataset that's too big to `.cache()` as
    # we are doing here.
    num_buckets = 1000000
    bucket_id = tf.string_to_hash_bucket_fast(line, num_buckets)
    # Use the hash bucket id as a random number that's deterministic per example
    return bucket_id < int(train_fraction * num_buckets)

  def in_test_set(line):
    """Returns a boolean tensor, true if the line is in the test set."""
    # Items not in the training set are in the test set.
    # This line must use `~` instead of `not` because `not` only works on python
    # booleans but we are dealing with symbolic tensors.
    return ~in_training_set(line)

  base_dataset = (
      tf.data
      # Get the lines from the file.
      .TextLineDataset(path)
      # drop lines with question marks.
      .filter(has_no_question_marks))

  def split(predicate):
    """Select the lines matching `predicate`, decode them, and cache them."""
    return (base_dataset
            # Keep only the lines belonging to this split.
            .filter(predicate)
            # Decode each line into a (features_dict, label) pair.
            .map(decode_line)
            # Cache *after* decoding so the CSV is parsed once, not per epoch.
            .cache())

  # Both splits go through the identical pipeline.
  train = split(in_training_set)
  test = split(in_test_set)

  return train, test
153
154
def raw_dataframe():
  """Load the imports85 data as a pd.DataFrame.

  Columns are named and typed from `types`; "?" entries are read as NaN.

  Returns:
    A pandas DataFrame with one row per line of the data file.

  Raises:
    ImportError: if pandas is not installed.
  """
  # The module-level pandas import is optional (its ImportError is swallowed),
  # so import here to fail with a clear ImportError rather than a NameError on
  # `pd`, and to do so before downloading anything.
  import pandas as pd  # pylint: disable=g-import-not-at-top,redefined-outer-name

  # Download and cache the data
  path = _get_imports85()

  # Load it into a pandas dataframe; the file has no header row, so the column
  # names come from `types` (in file order).
  df = pd.read_csv(path, names=list(types), dtype=types, na_values="?")

  return df
164
165
def load_data(y_name="price", train_fraction=0.7, seed=None):
  """Get the imports85 data set.

  A description of the data is available at:
    https://archive.ics.uci.edu/ml/datasets/automobile

  The data itself can be found at:
    https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data

  Rows containing any missing ("?") value are dropped before splitting.

  Args:
    y_name: the column to return as the label.
    train_fraction: the fraction of the dataset to use for training.
    seed: The random seed used to pick the training rows. `None` draws from
      NumPy's global random state, so the split generally differs per run.
  Returns:
    a pair of pairs where the first pair is the training data, and the second
    is the test data:
    `(x_train, y_train), (x_test, y_test) = load_data(...)`
    `x` contains a pandas DataFrame of features, while `y` contains the label
    column as a pandas Series.
  """
  # Load the raw data columns and delete rows with unknowns.
  data = raw_dataframe().dropna()

  # Randomly select the training rows. The seed goes straight to pandas, so
  # the caller's global NumPy RNG state is left untouched.
  x_train = data.sample(frac=train_fraction, random_state=seed)
  # Every row not selected for training forms the test set.
  x_test = data.drop(x_train.index)

  # Extract the label from the features dataframe.
  y_train = x_train.pop(y_name)
  y_test = x_test.pop(y_name)

  return (x_train, y_train), (x_test, y_test)
205