1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""label_image for tflite"""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import numpy as np
23
24from PIL import Image
25
26from tensorflow.lite.python import interpreter as interpreter_wrapper
27
def load_labels(filename):
  """Read a label file and return its lines as a list of strings.

  Args:
    filename: Path to a text file with one label per line. The script indexes
      this list with model output indices, so line order is assumed to match
      the model's class order.

  Returns:
    A list with one entry per line of the file, each with surrounding
    whitespace (including the trailing newline) stripped.
  """
  # The context manager guarantees the handle is closed; the previous version
  # opened the file and never closed it.
  with open(filename, 'r') as label_file:
    return [line.strip() for line in label_file]
34
def main():
  """Classify one image with a TFLite model and print the top-5 labels.

  Command-line flags select the image, the .tflite model, the label file and
  the mean/std used to normalize input for float models. For each of the five
  highest-scoring classes, one line "<score>: <label>" is printed, best first.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("-i", "--image", default="/tmp/grace_hopper.bmp",
                      help="image to be classified")
  parser.add_argument("-m", "--model_file",
                      default="/tmp/mobilenet_v1_1.0_224_quant.tflite",
                      help=".tflite model to be executed")
  parser.add_argument("-l", "--label_file", default="/tmp/labels.txt",
                      help="name of file containing labels")
  # type=float: without it a value supplied on the command line arrives as a
  # str, and the normalization arithmetic below raises TypeError.
  parser.add_argument("--input_mean", default=127.5, type=float,
                      help="input_mean")
  parser.add_argument("--input_std", default=127.5, type=float,
                      help="input standard deviation")
  args = parser.parse_args()

  interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  # A float32 input tensor is treated as a float model whose input gets
  # normalized with input_mean/input_std; any other dtype is fed raw pixels.
  floating_model = input_details[0]['dtype'] == np.float32

  # NxHxWxC, H:1, W:2
  input_shape = input_details[0]['shape']
  height = input_shape[1]
  width = input_shape[2]
  img = Image.open(args.image)
  # Grayscale, RGBA or palette images would otherwise yield an array whose
  # channel count does not match a 3-channel model input.
  if len(input_shape) == 4 and input_shape[3] == 3 and img.mode != 'RGB':
    img = img.convert('RGB')
  img = img.resize((width, height))

  # add N dim
  input_data = np.expand_dims(img, axis=0)

  if floating_model:
    input_data = (np.float32(input_data) - args.input_mean) / args.input_std

  interpreter.set_tensor(input_details[0]['index'], input_data)

  interpreter.invoke()

  output_data = interpreter.get_tensor(output_details[0]['index'])
  results = np.squeeze(output_data)

  # Indices of the five highest scores, highest first.
  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(args.label_file)
  for i in top_k:
    if floating_model:
      print('{0:08.6f}'.format(float(results[i]))+":", labels[i])
    else:
      # Presumably a quantized model emitting uint8 scores in [0, 255];
      # rescale to [0, 1] for display. TODO confirm for non-uint8 outputs.
      print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i])


if __name__ == "__main__":
  main()
87