1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""label_image for tflite."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import time
23
24import numpy as np
25from PIL import Image
26import tensorflow as tf # TF2
27
28
def load_labels(filename):
  """Read a label file containing one class name per line.

  Args:
    filename: path to a text file; line i holds the label for output class i.

  Returns:
    A list of labels in file order with surrounding whitespace stripped.
    Blank lines are kept (as empty strings) so list indices stay aligned with
    the model's output class indices.
  """
  # Explicit UTF-8 so non-ASCII labels decode identically on every platform
  # instead of depending on the locale's default encoding.
  with open(filename, 'r', encoding='utf-8') as f:
    return [line.strip() for line in f]
32
33
def load_image(path, width, height, channels=None):
  """Open an image file and resize it to the model's spatial input size.

  Args:
    path: path to an image file readable by Pillow.
    width: target width in pixels.
    height: target height in pixels.
    channels: number of channels the model input expects, if known. When it
      is 3, images in any other mode (RGBA, palette, grayscale, ...) are
      converted to RGB so the resulting array has exactly three channels.

  Returns:
    The resized `PIL.Image.Image`.
  """
  with Image.open(path) as img:
    if channels == 3 and img.mode != 'RGB':
      img = img.convert('RGB')
    return img.resize((width, height))


def prepare_input(pixels, input_detail, input_mean=127.5, input_std=127.5):
  """Convert a uint8 pixel batch into the tensor the model's input expects.

  Args:
    pixels: uint8 numpy array holding the image batch (NxHxWxC).
    input_detail: one entry of `Interpreter.get_input_details()`; its
      'dtype' and 'quantization' (scale, zero_point) fields are read.
    input_mean: value subtracted from each pixel before scaling.
    input_std: divisor applied after the mean subtraction.

  Returns:
    * float32 inputs: `(pixels - input_mean) / input_std` as float32.
    * uint8 inputs: `pixels` unchanged (raw pixels are fed directly).
    * other integer inputs (e.g. int8): the normalized values quantized with
      the tensor's own scale and zero point, rounded and clipped to the
      dtype's range.

  Raises:
    ValueError: for a non-uint8 integer input that reports no quantization
      scale, since there is then no way to map pixels onto it.
  """
  dtype = input_detail['dtype']
  if dtype == np.float32:
    return (np.float32(pixels) - input_mean) / input_std
  if dtype == np.uint8:
    return pixels

  scale, zero_point = input_detail['quantization']
  if not scale:
    raise ValueError(
        'Input tensor of type {} has no quantization scale.'.format(dtype))
  normalized = (np.float32(pixels) - input_mean) / input_std
  quantized = np.round(normalized / scale + zero_point)
  limits = np.iinfo(dtype)
  return np.clip(quantized, limits.min, limits.max).astype(dtype)


def dequantize_output(values, quantization):
  """Convert a quantized output tensor into float32 scores.

  Args:
    values: numpy array of raw integer output values.
    quantization: the `(scale, zero_point)` pair from the output tensor's
      entry in `Interpreter.get_output_details()`.

  Returns:
    `(values - zero_point) * scale` as float32. If the tensor reports a scale
    of 0 (no quantization info), falls back to `values / 255`, the fixed
    uint8 mapping this script previously always used.
  """
  scale, zero_point = quantization
  as_float = np.asarray(values, dtype=np.float32)
  if not scale:
    return as_float / 255.0
  return (as_float - zero_point) * scale


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '-i',
      '--image',
      default='/tmp/grace_hopper.bmp',
      help='image to be classified')
  parser.add_argument(
      '-m',
      '--model_file',
      default='/tmp/mobilenet_v1_1.0_224_quant.tflite',
      help='.tflite model to be executed')
  parser.add_argument(
      '-l',
      '--label_file',
      default='/tmp/labels.txt',
      help='name of file containing labels')
  parser.add_argument(
      '--input_mean',
      default=127.5, type=float,
      help='input_mean')
  parser.add_argument(
      '--input_std',
      default=127.5, type=float,
      help='input standard deviation')
  parser.add_argument(
      '--num_threads', default=None, type=int, help='number of threads')
  args = parser.parse_args()

  interpreter = tf.lite.Interpreter(
      model_path=args.model_file, num_threads=args.num_threads)
  interpreter.allocate_tensors()

  input_detail = interpreter.get_input_details()[0]
  output_detail = interpreter.get_output_details()[0]

  # Input is NxHxWxC: H at index 1, W at index 2, C (when present) at 3.
  input_shape = input_detail['shape']
  height = input_shape[1]
  width = input_shape[2]
  channels = input_shape[3] if len(input_shape) > 3 else None
  img = load_image(args.image, width, height, channels)

  # Add the batch (N) dimension, then normalize/quantize for this model.
  input_data = prepare_input(
      np.expand_dims(img, axis=0), input_detail, args.input_mean,
      args.input_std)
  interpreter.set_tensor(input_detail['index'], input_data)

  # perf_counter is monotonic and high-resolution, unlike time.time().
  start_time = time.perf_counter()
  interpreter.invoke()
  stop_time = time.perf_counter()

  results = np.squeeze(interpreter.get_tensor(output_detail['index']))
  if np.issubdtype(output_detail['dtype'], np.floating):
    scores = results
  else:
    # Use the output tensor's own quantization params instead of assuming
    # uint8 with a 1/255 scale; int8 outputs have a non-zero zero point.
    scores = dequantize_output(results, output_detail['quantization'])

  # Indices of the 5 highest scores, best first.
  top_k = scores.argsort()[-5:][::-1]
  labels = load_labels(args.label_file)
  for i in top_k:
    print('{:08.6f}: {}'.format(float(scores[i]), labels[i]))

  print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
102