1# Lint as: python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16# pylint: disable=redefined-outer-name
17# pylint: disable=g-bad-import-order
18
19"""Build and train neural networks."""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import argparse
26import datetime
27import os
28from data_load import DataLoader
29
30import numpy as np
31import tensorflow as tf
32
# Scalars are logged under a directory named after the time this module was
# loaded; the callback below is handed to `model.fit` in `train_net`.
_run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "logs/scalars/" + _run_timestamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
35
36
def reshape_function(data, label):
  """Reshapes `data` to shape [-1, 3, 1] and passes `label` through unchanged.

  `train_net` maps this over the datasets when the CNN is selected, whose
  input shape is (seq_length, 3, 1).

  Args:
    data: Tensor to reshape.
    label: Value returned as-is alongside the reshaped data.

  Returns:
    A `(reshaped_data, label)` tuple.
  """
  return tf.reshape(data, [-1, 3, 1]), label
40
41
def calculate_model_size(model):
  """Prints the model summary and the size of its trainable variables.

  Args:
    model: Object exposing `summary()` and `trainable_variables`; each
      variable must provide `shape` and `dtype.size` (bytes per element).

  Returns:
    Total size of the trainable variables in KB (bytes / 1024).
  """
  # `summary()` prints on its own and returns None, so wrapping it in print()
  # would emit a stray "None" line after the summary.
  model.summary()
  # `np.prod` replaces the `np.product` alias, which NumPy 2.0 removed.
  total_bytes = sum(
      np.prod([int(dim) for dim in var.shape]) * var.dtype.size
      for var in model.trainable_variables)
  size_kb = total_bytes / 1024
  print("Model size:", size_kb, "KB")
  return size_kb
49
50
def build_cnn(seq_length):
  """Builds a convolutional neural network in Keras.

  Args:
    seq_length: Number of time steps per sample; the network input shape is
      (seq_length, 3, 1).

  Returns:
    A `(model, model_path)` tuple, where `model_path` is the "./netmodels/CNN"
    directory (created here if it does not exist yet).
  """
  # The shape comments below assume seq_length == 128.
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          8, (4, 3),
          padding="same",
          activation="relu",
          input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
      tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
      tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
      tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                             activation="relu"),  # (batch, 42, 1, 16)
      tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
      tf.keras.layers.Flatten(),  # (batch, 224)
      tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 16)
      tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
  ])
  model_path = os.path.join("./netmodels", "CNN")
  print("Built CNN.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  # Restore saved weights only when they are actually present. Loading
  # unconditionally crashed on a fresh checkout, where the directory has just
  # been created and no weights file exists yet.
  weights_path = os.path.join(model_path, "weights.h5")
  if os.path.isfile(weights_path):
    model.load_weights(weights_path)
  else:
    print("No saved weights at {}; starting from fresh weights.".format(
        weights_path))
  return model, model_path
76
77
def build_lstm(seq_length):
  """Builds a bidirectional LSTM classifier in Keras.

  Args:
    seq_length: Number of time steps per sample; the network input shape is
      (seq_length, 3).

  Returns:
    A `(model, model_path)` tuple, where `model_path` is the
    "./netmodels/LSTM" directory (created here if it does not exist yet).
  """
  recurrent = tf.keras.layers.Bidirectional(
      tf.keras.layers.LSTM(22),
      input_shape=(seq_length, 3))  # output_shape=(batch, 44)
  # NOTE(review): this head uses sigmoid while build_cnn uses softmax for the
  # same 4 outputs -- confirm that the difference is intended.
  classifier = tf.keras.layers.Dense(4, activation="sigmoid")  # (batch, 4)
  model = tf.keras.Sequential([recurrent, classifier])

  model_path = os.path.join("./netmodels", "LSTM")
  print("Built LSTM.")
  already_there = os.path.exists(model_path)
  if not already_there:
    os.makedirs(model_path)
  return model, model_path
91
92
def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
  """Loads and formats the train/valid/test splits through `DataLoader`.

  Args:
    train_data_path: Path handed to `DataLoader` for the training split.
    valid_data_path: Path handed to `DataLoader` for the validation split.
    test_data_path: Path handed to `DataLoader` for the test split.
    seq_length: Forwarded to `DataLoader` as its `seq_length` keyword.

  Returns:
    A 6-tuple `(train_len, train_data, valid_len, valid_data, test_len,
    test_data)` read from the loader after `format()` has been called.
  """
  loader = DataLoader(
      train_data_path,
      valid_data_path,
      test_data_path,
      seq_length=seq_length)
  loader.format()
  return (
      loader.train_len,
      loader.train_data,
      loader.valid_len,
      loader.valid_data,
      loader.test_len,
      loader.test_data,
  )
99
100
def build_net(args, seq_length):
  """Builds the network selected by `args.model`.

  Args:
    args: Object with a `model` attribute, expected to be "CNN" or "LSTM".
    seq_length: Passed through to the selected builder.

  Returns:
    The `(model, model_path)` tuple produced by `build_cnn` or `build_lstm`.

  Raises:
    ValueError: If `args.model` is neither "CNN" nor "LSTM".
  """
  if args.model == "CNN":
    return build_cnn(seq_length)
  if args.model == "LSTM":
    return build_lstm(seq_length)
  # An unknown name used to print a hint and then die with UnboundLocalError
  # on the return statement; fail explicitly with the same hint instead.
  raise ValueError(
      "Please input correct model name.(CNN  LSTM) Got: {!r}".format(
          args.model))
109
110
def _export_tflite(model, output_path, optimizations=None):
  """Converts `model` to TensorFlow Lite and writes it to `output_path`.

  Args:
    model: Keras model to convert.
    output_path: Destination file for the converted model bytes.
    optimizations: Optional list of `tf.lite.Optimize` values; when omitted
      the converter's own defaults are left untouched.

  Returns:
    Size of the written file in bytes.
  """
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  if optimizations:
    converter.optimizations = optimizations
  tflite_model = converter.convert()
  # The context manager closes the file before its size is read; the bare
  # open(...).write(...) form left the handle to the garbage collector.
  with open(output_path, "wb") as model_file:
    model_file.write(tflite_model)
  return os.path.getsize(output_path)


def train_net(
    model,
    model_path,  # pylint: disable=unused-argument
    train_len,  # pylint: disable=unused-argument
    train_data,
    valid_len,
    valid_data,
    test_len,
    test_data,
    kind):
  """Trains the model, evaluates it, and exports TensorFlow Lite files.

  Args:
    model: Keras model to compile and train.
    model_path: Unused; kept for interface compatibility.
    train_len: Unused; kept for interface compatibility.
    train_data: Unbatched dataset of `(data, label)` pairs for training.
    valid_len: Number of validation samples, used to derive
      `validation_steps`.
    valid_data: Unbatched dataset of `(data, label)` pairs for validation.
    test_len: Number of test samples; sizes the label buffer used for the
      confusion matrix.
    test_data: Unbatched dataset of `(data, label)` pairs for testing.
    kind: Model kind; when "CNN" every dataset is mapped through
      `reshape_function` first.

  Side effects:
    Writes "model.tflite" and "model_quantized.tflite" to the current working
    directory and prints the confusion matrix, loss/accuracy and file sizes.
  """
  calculate_model_size(model)
  epochs = 50
  batch_size = 64
  model.compile(
      optimizer="adam",
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])
  if kind == "CNN":
    train_data = train_data.map(reshape_function)
    test_data = test_data.map(reshape_function)
    valid_data = valid_data.map(reshape_function)

  # Collect the ground-truth labels while the test set is still unbatched so
  # each iteration yields exactly one label.
  test_labels = np.zeros(test_len)
  for idx, (_, label) in enumerate(test_data):
    test_labels[idx] = label.numpy()

  train_data = train_data.batch(batch_size).repeat()
  valid_data = valid_data.batch(batch_size)
  test_data = test_data.batch(batch_size)
  model.fit(
      train_data,
      epochs=epochs,
      validation_data=valid_data,
      steps_per_epoch=1000,
      # Number of batches needed to cover the whole validation set.
      validation_steps=int((valid_len - 1) / batch_size + 1),
      callbacks=[tensorboard_callback])
  loss, acc = model.evaluate(test_data)
  pred = np.argmax(model.predict(test_data), axis=1)
  confusion = tf.math.confusion_matrix(
      labels=tf.constant(test_labels),
      predictions=tf.constant(pred),
      num_classes=4)
  print(confusion)
  print("Loss {}, Accuracy {}".format(loss, acc))

  # Export once without and once with quantization.
  basic_model_size = _export_tflite(model, "model.tflite")
  # `Optimize.DEFAULT` supersedes the deprecated `OPTIMIZE_FOR_SIZE`, which is
  # documented as doing the same thing.
  quantized_model_size = _export_tflite(
      model,
      "model_quantized.tflite",
      optimizations=[tf.lite.Optimize.DEFAULT])

  print("Basic model is %d bytes" % basic_model_size)
  print("Quantized model is %d bytes" % quantized_model_size)
  difference = basic_model_size - quantized_model_size
  print("Difference is %d bytes" % difference)
177
178
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  # Validate the model name up front: previously a typo or a missing flag was
  # only noticed after the data had already been loaded.
  parser.add_argument(
      "--model",
      "-m",
      required=True,
      choices=["CNN", "LSTM"],
      help="Network architecture to build and train.")
  parser.add_argument(
      "--person",
      "-p",
      help="Pass 'true' to read the splits from ./person_split; any other "
      "value (or omitting the flag) reads them from ./data.")
  args = parser.parse_args()

  seq_length = 128

  print("Start to load data...")
  data_root = "./person_split" if args.person == "true" else "./data"
  train_len, train_data, valid_len, valid_data, test_len, test_data = \
      load_data(data_root + "/train", data_root + "/valid",
                data_root + "/test", seq_length)

  print("Start to build net...")
  model, model_path = build_net(args, seq_length)

  print("Start training...")
  train_net(model, model_path, train_len, train_data, valid_len, valid_data,
            test_len, test_data, args.model)

  print("Training finished!")
204