1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20import numpy as np 21import tensorflow as tf # pylint: disable=g-bad-import-order 22from tensorflow.lite.tutorials import dataset 23flags = tf.app.flags 24 25flags.DEFINE_string('data_dir', '/tmp/data_dir', 26 'Directory where data is stored.') 27flags.DEFINE_string('model_file', '', 28 'The path to the TFLite flatbuffer model file.') 29 30 31flags = flags.FLAGS 32 33 34def test_image_generator(): 35 # Generates an iterator over images 36 with tf.Session() as sess: 37 input_data = tf.compat.v1.data.make_one_shot_iterator(dataset.test( 38 flags.data_dir)).get_next() 39 try: 40 while True: 41 yield sess.run(input_data) 42 except tf.errors.OutOfRangeError: 43 pass 44 45 46def run_eval(interpreter, input_image): 47 """Performs evaluation for input image over specified model. 48 49 Args: 50 interpreter: TFLite interpreter initialized with model to execute. 51 input_image: Image input to the model. 52 53 Returns: 54 output: output tensor of model being executed. 55 """ 56 57 # Get input and output tensors. 58 input_details = interpreter.get_input_details() 59 output_details = interpreter.get_output_details() 60 61 # Test model on the input images. 62 input_image = np.reshape(input_image, input_details[0]['shape']) 63 interpreter.set_tensor(input_details[0]['index'], input_image) 64 65 interpreter.invoke() 66 output_data = interpreter.get_tensor(output_details[0]['index']) 67 output = np.squeeze(output_data) 68 return output 69 70 71def main(_): 72 interpreter = tf.lite.Interpreter(model_path=flags.model_file) 73 interpreter.allocate_tensors() 74 num_correct, total = 0, 0 75 for input_data in test_image_generator(): 76 output = run_eval(interpreter, input_data[0]) 77 total += 1 78 if output == input_data[1]: 79 num_correct += 1 80 if total % 500 == 0: 81 print('Accuracy after %i images: %f' % 82 (total, float(num_correct) / float(total))) 83 84 85if __name__ == '__main__': 86 tf.logging.set_verbosity(tf.logging.INFO) 87 tf.app.run(main) 88