1# Lint as: python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16# pylint: disable=redefined-outer-name 17# pylint: disable=g-bad-import-order 18 19"""Build and train neural networks.""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import argparse 26import datetime 27import os 28from data_load import DataLoader 29 30import numpy as np 31import tensorflow as tf 32 33logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 34tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir) 35 36 37def reshape_function(data, label): 38 reshaped_data = tf.reshape(data, [-1, 3, 1]) 39 return reshaped_data, label 40 41 42def calculate_model_size(model): 43 print(model.summary()) 44 var_sizes = [ 45 np.product(list(map(int, v.shape))) * v.dtype.size 46 for v in model.trainable_variables 47 ] 48 print("Model size:", sum(var_sizes) / 1024, "KB") 49 50 51def build_cnn(seq_length): 52 """Builds a convolutional neural network in Keras.""" 53 model = tf.keras.Sequential([ 54 tf.keras.layers.Conv2D( 55 8, (4, 3), 56 padding="same", 57 activation="relu", 58 input_shape=(seq_length, 3, 1)), # output_shape=(batch, 128, 3, 8) 59 tf.keras.layers.MaxPool2D((3, 3)), # (batch, 42, 1, 8) 60 tf.keras.layers.Dropout(0.1), # (batch, 42, 1, 8) 61 tf.keras.layers.Conv2D(16, (4, 1), padding="same", 62 activation="relu"), # (batch, 42, 1, 16) 63 tf.keras.layers.MaxPool2D((3, 1), padding="same"), # (batch, 14, 1, 16) 64 tf.keras.layers.Dropout(0.1), # (batch, 14, 1, 16) 65 tf.keras.layers.Flatten(), # (batch, 224) 66 tf.keras.layers.Dense(16, activation="relu"), # (batch, 16) 67 tf.keras.layers.Dropout(0.1), # (batch, 16) 68 tf.keras.layers.Dense(4, activation="softmax") # (batch, 4) 69 ]) 70 model_path = os.path.join("./netmodels", "CNN") 71 print("Built CNN.") 72 if not os.path.exists(model_path): 73 os.makedirs(model_path) 74 model.load_weights("./netmodels/CNN/weights.h5") 75 return model, model_path 76 77 78def build_lstm(seq_length): 79 """Builds an LSTM in Keras.""" 80 model = tf.keras.Sequential([ 81 tf.keras.layers.Bidirectional( 82 tf.keras.layers.LSTM(22), 83 input_shape=(seq_length, 3)), # output_shape=(batch, 44) 84 tf.keras.layers.Dense(4, activation="sigmoid") # (batch, 4) 85 ]) 86 model_path = os.path.join("./netmodels", "LSTM") 87 print("Built LSTM.") 88 if not os.path.exists(model_path): 89 os.makedirs(model_path) 90 return model, model_path 91 92 93def load_data(train_data_path, valid_data_path, test_data_path, seq_length): 94 data_loader = DataLoader( 95 train_data_path, valid_data_path, test_data_path, seq_length=seq_length) 96 data_loader.format() 97 return data_loader.train_len, data_loader.train_data, data_loader.valid_len, \ 98 data_loader.valid_data, data_loader.test_len, data_loader.test_data 99 100 101def build_net(args, seq_length): 102 if args.model == "CNN": 103 model, model_path = build_cnn(seq_length) 104 elif args.model == "LSTM": 105 model, model_path = build_lstm(seq_length) 106 else: 107 print("Please input correct model name.(CNN LSTM)") 108 return model, model_path 109 110 111def train_net( 112 model, 113 model_path, # pylint: disable=unused-argument 114 train_len, # pylint: disable=unused-argument 115 train_data, 116 valid_len, 117 valid_data, # pylint: disable=unused-argument 118 test_len, 119 test_data, 120 kind): 121 """Trains the model.""" 122 calculate_model_size(model) 123 epochs = 50 124 batch_size = 64 125 model.compile( 126 optimizer="adam", 127 loss="sparse_categorical_crossentropy", 128 metrics=["accuracy"]) 129 if kind == "CNN": 130 train_data = train_data.map(reshape_function) 131 test_data = test_data.map(reshape_function) 132 valid_data = valid_data.map(reshape_function) 133 test_labels = np.zeros(test_len) 134 idx = 0 135 for data, label in test_data: # pylint: disable=unused-variable 136 test_labels[idx] = label.numpy() 137 idx += 1 138 train_data = train_data.batch(batch_size).repeat() 139 valid_data = valid_data.batch(batch_size) 140 test_data = test_data.batch(batch_size) 141 model.fit( 142 train_data, 143 epochs=epochs, 144 validation_data=valid_data, 145 steps_per_epoch=1000, 146 validation_steps=int((valid_len - 1) / batch_size + 1), 147 callbacks=[tensorboard_callback]) 148 loss, acc = model.evaluate(test_data) 149 pred = np.argmax(model.predict(test_data), axis=1) 150 confusion = tf.math.confusion_matrix( 151 labels=tf.constant(test_labels), 152 predictions=tf.constant(pred), 153 num_classes=4) 154 print(confusion) 155 print("Loss {}, Accuracy {}".format(loss, acc)) 156 # Convert the model to the TensorFlow Lite format without quantization 157 converter = tf.lite.TFLiteConverter.from_keras_model(model) 158 tflite_model = converter.convert() 159 160 # Save the model to disk 161 open("model.tflite", "wb").write(tflite_model) 162 163 # Convert the model to the TensorFlow Lite format with quantization 164 converter = tf.lite.TFLiteConverter.from_keras_model(model) 165 converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] 166 tflite_model = converter.convert() 167 168 # Save the model to disk 169 open("model_quantized.tflite", "wb").write(tflite_model) 170 171 basic_model_size = os.path.getsize("model.tflite") 172 print("Basic model is %d bytes" % basic_model_size) 173 quantized_model_size = os.path.getsize("model_quantized.tflite") 174 print("Quantized model is %d bytes" % quantized_model_size) 175 difference = basic_model_size - quantized_model_size 176 print("Difference is %d bytes" % difference) 177 178 179if __name__ == "__main__": 180 parser = argparse.ArgumentParser() 181 parser.add_argument("--model", "-m") 182 parser.add_argument("--person", "-p") 183 args = parser.parse_args() 184 185 seq_length = 128 186 187 print("Start to load data...") 188 if args.person == "true": 189 train_len, train_data, valid_len, valid_data, test_len, test_data = \ 190 load_data("./person_split/train", "./person_split/valid", 191 "./person_split/test", seq_length) 192 else: 193 train_len, train_data, valid_len, valid_data, test_len, test_data = \ 194 load_data("./data/train", "./data/valid", "./data/test", seq_length) 195 196 print("Start to build net...") 197 model, model_path = build_net(args, seq_length) 198 199 print("Start training...") 200 train_net(model, model_path, train_len, train_data, valid_len, valid_data, 201 test_len, test_data, args.model) 202 203 print("Training finished!") 204