1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20import numpy as np
21import tensorflow as tf  # pylint: disable=g-bad-import-order
22from tensorflow.lite.tutorials import dataset
# Command-line flags read by the functions below.
tf.app.flags.DEFINE_string(
    'data_dir', '/tmp/data_dir', 'Directory where data is stored.')
tf.app.flags.DEFINE_string(
    'model_file', '', 'The path to the TFLite flatbuffer model file.')

# NOTE: `flags` is bound to the FLAGS object itself (not the flags module), so
# values are read as flags.data_dir / flags.model_file elsewhere in this file.
flags = tf.app.flags.FLAGS
32
33
def test_image_generator():
  """Yields MNIST test examples one at a time.

  Builds the input pipeline from `dataset.test(flags.data_dir)` in a private
  graph and evaluates it in a dedicated session until the dataset raises
  OutOfRangeError (i.e. is exhausted).

  Yields:
    The evaluated result of the iterator's `get_next()` for one example;
    `main` indexes it as (image, label).
  """
  # A private graph keeps the pipeline ops out of the global default graph,
  # so repeated calls do not keep adding ops to it.
  graph = tf.Graph()
  with graph.as_default():
    next_example = tf.compat.v1.data.make_one_shot_iterator(
        dataset.test(flags.data_dir)).get_next()

  # The session is closed explicitly rather than used as a context manager:
  # entering a Session with `with` makes it (and its graph) the thread's
  # default, and that state would stay pushed across every `yield`, leaking
  # into the caller and risking nesting errors if the generator is closed
  # early.
  sess = tf.compat.v1.Session(graph=graph)
  try:
    while True:
      try:
        example = sess.run(next_example)
      except tf.errors.OutOfRangeError:
        return  # Dataset exhausted.
      # Yield outside the inner try so errors thrown into the generator by
      # the consumer are not mistaken for end-of-data.
      yield example
  finally:
    sess.close()
44
45
def run_eval(interpreter, input_image):
  """Runs one image through the TFLite model and returns its squeezed output.

  Args:
      interpreter: TFLite interpreter initialized (tensors allocated) with the
        model to execute.
      input_image: Image data; it is reshaped to the shape of the model's
        first input tensor before being fed in.

  Returns:
      The model's first output tensor with all size-1 dimensions removed
      (via np.squeeze).
  """
  in_specs = interpreter.get_input_details()
  out_specs = interpreter.get_output_details()

  # Only the first input tensor is fed; the image must match its exact shape.
  first_input = in_specs[0]
  shaped_image = np.reshape(input_image, first_input['shape'])
  interpreter.set_tensor(first_input['index'], shaped_image)

  interpreter.invoke()

  # Only the first output tensor is read back.
  raw_output = interpreter.get_tensor(out_specs[0]['index'])
  return np.squeeze(raw_output)
69
70
def _predicted_label(output):
  """Converts the squeezed output of `run_eval` into one predicted label.

  A 0-d output is returned unchanged: it is taken to already be the class
  label (e.g. a model whose graph ends in an argmax). Any other output is
  reduced with argmax.
  """
  output = np.asarray(output)
  if output.ndim == 0:
    return output
  # NOTE(review): assumes a non-scalar output is a flat vector of per-class
  # scores/logits. Comparing such a vector directly to the label would raise
  # "truth value of an array ... is ambiguous".
  return np.argmax(output)


def _print_accuracy(num_correct, total):
  """Prints the running accuracy over the first `total` evaluated images."""
  print('Accuracy after %i images: %f' %
        (total, float(num_correct) / float(total)))


def main(_):
  """Evaluates the TFLite model given by --model_file on the MNIST test set.

  Prints the running accuracy every `report_every` images, and once more at
  the end if the total image count is not a multiple of `report_every`.

  Raises:
    ValueError: if --model_file is empty (its default).
  """
  if not flags.model_file:
    raise ValueError('--model_file must be set to the path of a TFLite '
                     'flatbuffer model.')
  interpreter = tf.lite.Interpreter(model_path=flags.model_file)
  interpreter.allocate_tensors()

  report_every = 500  # Print the running accuracy every this many images.
  num_correct, total = 0, 0
  for input_data in test_image_generator():
    # Each yielded example is indexed as (image, label).
    output = run_eval(interpreter, input_data[0])
    total += 1
    if _predicted_label(output) == input_data[1]:
      num_correct += 1
    if total % report_every == 0:
      _print_accuracy(num_correct, total)

  # Always report the final accuracy, without repeating the last periodic
  # line when the count is an exact multiple (total == 0 also prints nothing).
  if total % report_every:
    _print_accuracy(num_correct, total)
83
84
if __name__ == '__main__':
  # Enable INFO-level TensorFlow logging, then hand control to tf.app.run,
  # which invokes main().
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main=main)
88