1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""label_image for tflite.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import time 23 24import numpy as np 25from PIL import Image 26import tensorflow as tf # TF2 27 28 29def load_labels(filename): 30 with open(filename, 'r') as f: 31 return [line.strip() for line in f.readlines()] 32 33 34if __name__ == '__main__': 35 parser = argparse.ArgumentParser() 36 parser.add_argument( 37 '-i', 38 '--image', 39 default='/tmp/grace_hopper.bmp', 40 help='image to be classified') 41 parser.add_argument( 42 '-m', 43 '--model_file', 44 default='/tmp/mobilenet_v1_1.0_224_quant.tflite', 45 help='.tflite model to be executed') 46 parser.add_argument( 47 '-l', 48 '--label_file', 49 default='/tmp/labels.txt', 50 help='name of file containing labels') 51 parser.add_argument( 52 '--input_mean', 53 default=127.5, type=float, 54 help='input_mean') 55 parser.add_argument( 56 '--input_std', 57 default=127.5, type=float, 58 help='input standard deviation') 59 parser.add_argument( 60 '--num_threads', default=None, type=int, help='number of threads') 61 args = parser.parse_args() 62 63 interpreter = tf.lite.Interpreter( 64 model_path=args.model_file, num_threads=args.num_threads) 65 interpreter.allocate_tensors() 66 67 input_details = interpreter.get_input_details() 68 output_details = interpreter.get_output_details() 69 70 # check the type of the input tensor 71 floating_model = input_details[0]['dtype'] == np.float32 72 73 # NxHxWxC, H:1, W:2 74 height = input_details[0]['shape'][1] 75 width = input_details[0]['shape'][2] 76 img = Image.open(args.image).resize((width, height)) 77 78 # add N dim 79 input_data = np.expand_dims(img, axis=0) 80 81 if floating_model: 82 input_data = (np.float32(input_data) - args.input_mean) / args.input_std 83 84 interpreter.set_tensor(input_details[0]['index'], input_data) 85 86 start_time = time.time() 87 interpreter.invoke() 88 stop_time = time.time() 89 90 output_data = interpreter.get_tensor(output_details[0]['index']) 91 results = np.squeeze(output_data) 92 93 top_k = results.argsort()[-5:][::-1] 94 labels = load_labels(args.label_file) 95 for i in top_k: 96 if floating_model: 97 print('{:08.6f}: {}'.format(float(results[i]), labels[i])) 98 else: 99 print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i])) 100 101 print('time: {:.3f}ms'.format((stop_time - start_time) * 1000)) 102