1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Example of DNNClassifier for Iris plant dataset. 15 16This example uses APIs in Tensorflow 1.4 or above. 17""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import os 24import urllib 25 26import tensorflow as tf 27 28# Data sets 29IRIS_TRAINING = 'iris_training.csv' 30IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv' 31 32IRIS_TEST = 'iris_test.csv' 33IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv' 34 35FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'] 36 37 38def maybe_download_iris_data(file_name, download_url): 39 """Downloads the file and returns the number of data.""" 40 if not os.path.exists(file_name): 41 raw = urllib.urlopen(download_url).read() 42 with open(file_name, 'w') as f: 43 f.write(raw) 44 45 # The first line is a comma-separated string. The first one is the number of 46 # total data in the file. 47 with open(file_name, 'r') as f: 48 first_line = f.readline() 49 num_elements = first_line.split(',')[0] 50 return int(num_elements) 51 52 53def input_fn(file_name, num_data, batch_size, is_training): 54 """Creates an input_fn required by Estimator train/evaluate.""" 55 # If the data sets aren't stored locally, download them. 56 57 def _parse_csv(rows_string_tensor): 58 """Takes the string input tensor and returns tuple of (features, labels).""" 59 # Last dim is the label. 60 num_features = len(FEATURE_KEYS) 61 num_columns = num_features + 1 62 columns = tf.decode_csv(rows_string_tensor, 63 record_defaults=[[]] * num_columns) 64 features = dict(zip(FEATURE_KEYS, columns[:num_features])) 65 labels = tf.cast(columns[num_features], tf.int32) 66 return features, labels 67 68 def _input_fn(): 69 """The input_fn.""" 70 dataset = tf.data.TextLineDataset([file_name]) 71 # Skip the first line (which does not have data). 72 dataset = dataset.skip(1) 73 dataset = dataset.map(_parse_csv) 74 75 if is_training: 76 # For this small dataset, which can fit into memory, to achieve true 77 # randomness, the shuffle buffer size is set as the total number of 78 # elements in the dataset. 79 dataset = dataset.shuffle(num_data) 80 dataset = dataset.repeat() 81 82 dataset = dataset.batch(batch_size) 83 iterator = dataset.make_one_shot_iterator() 84 features, labels = iterator.get_next() 85 return features, labels 86 87 return _input_fn 88 89 90def main(unused_argv): 91 tf.logging.set_verbosity(tf.logging.INFO) 92 93 num_training_data = maybe_download_iris_data( 94 IRIS_TRAINING, IRIS_TRAINING_URL) 95 num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL) 96 97 # Build 3 layer DNN with 10, 20, 10 units respectively. 98 feature_columns = [ 99 tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS] 100 classifier = tf.estimator.DNNClassifier( 101 feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3) 102 103 # Train. 104 train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32, 105 is_training=True) 106 classifier.train(input_fn=train_input_fn, steps=400) 107 108 # Eval. 109 test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32, 110 is_training=False) 111 scores = classifier.evaluate(input_fn=test_input_fn) 112 print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) 113 114 115if __name__ == '__main__': 116 tf.app.run() 117