• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.demo;
17 
18 import android.content.res.AssetManager;
19 import android.graphics.Bitmap;
20 import android.os.Trace;
21 import android.util.Log;
22 import java.io.BufferedReader;
23 import java.io.IOException;
24 import java.io.InputStreamReader;
25 import java.util.ArrayList;
26 import java.util.Comparator;
27 import java.util.List;
28 import java.util.PriorityQueue;
29 import java.util.Vector;
30 import org.tensorflow.Operation;
31 import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
32 
33 /** A classifier specialized to label images using TensorFlow. */
34 public class TensorFlowImageClassifier implements Classifier {
35   private static final String TAG = "TensorFlowImageClassifier";
36 
37   // Only return this many results with at least this confidence.
38   private static final int MAX_RESULTS = 3;
39   private static final float THRESHOLD = 0.1f;
40 
41   // Config values.
42   private String inputName;
43   private String outputName;
44   private int inputSize;
45   private int imageMean;
46   private float imageStd;
47 
48   // Pre-allocated buffers.
49   private Vector<String> labels = new Vector<String>();
50   private int[] intValues;
51   private float[] floatValues;
52   private float[] outputs;
53   private String[] outputNames;
54 
55   private boolean logStats = false;
56 
57   private TensorFlowInferenceInterface inferenceInterface;
58 
TensorFlowImageClassifier()59   private TensorFlowImageClassifier() {}
60 
61   /**
62    * Initializes a native TensorFlow session for classifying images.
63    *
64    * @param assetManager The asset manager to be used to load assets.
65    * @param modelFilename The filepath of the model GraphDef protocol buffer.
66    * @param labelFilename The filepath of label file for classes.
67    * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
68    * @param imageMean The assumed mean of the image values.
69    * @param imageStd The assumed std of the image values.
70    * @param inputName The label of the image input node.
71    * @param outputName The label of the output node.
72    * @throws IOException
73    */
create( AssetManager assetManager, String modelFilename, String labelFilename, int inputSize, int imageMean, float imageStd, String inputName, String outputName)74   public static Classifier create(
75       AssetManager assetManager,
76       String modelFilename,
77       String labelFilename,
78       int inputSize,
79       int imageMean,
80       float imageStd,
81       String inputName,
82       String outputName) {
83     TensorFlowImageClassifier c = new TensorFlowImageClassifier();
84     c.inputName = inputName;
85     c.outputName = outputName;
86 
87     // Read the label names into memory.
88     // TODO(andrewharp): make this handle non-assets.
89     String actualFilename = labelFilename.split("file:///android_asset/")[1];
90     Log.i(TAG, "Reading labels from: " + actualFilename);
91     BufferedReader br = null;
92     try {
93       br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
94       String line;
95       while ((line = br.readLine()) != null) {
96         c.labels.add(line);
97       }
98       br.close();
99     } catch (IOException e) {
100       throw new RuntimeException("Problem reading label file!" , e);
101     }
102 
103     c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
104 
105     // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
106     final Operation operation = c.inferenceInterface.graphOperation(outputName);
107     final int numClasses = (int) operation.output(0).shape().size(1);
108     Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
109 
110     // Ideally, inputSize could have been retrieved from the shape of the input operation.  Alas,
111     // the placeholder node for input in the graphdef typically used does not specify a shape, so it
112     // must be passed in as a parameter.
113     c.inputSize = inputSize;
114     c.imageMean = imageMean;
115     c.imageStd = imageStd;
116 
117     // Pre-allocate buffers.
118     c.outputNames = new String[] {outputName};
119     c.intValues = new int[inputSize * inputSize];
120     c.floatValues = new float[inputSize * inputSize * 3];
121     c.outputs = new float[numClasses];
122 
123     return c;
124   }
125 
126   @Override
recognizeImage(final Bitmap bitmap)127   public List<Recognition> recognizeImage(final Bitmap bitmap) {
128     // Log this method so that it can be analyzed with systrace.
129     Trace.beginSection("recognizeImage");
130 
131     Trace.beginSection("preprocessBitmap");
132     // Preprocess the image data from 0-255 int to normalized float based
133     // on the provided parameters.
134     bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
135     for (int i = 0; i < intValues.length; ++i) {
136       final int val = intValues[i];
137       floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
138       floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
139       floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
140     }
141     Trace.endSection();
142 
143     // Copy the input data into TensorFlow.
144     Trace.beginSection("feed");
145     inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
146     Trace.endSection();
147 
148     // Run the inference call.
149     Trace.beginSection("run");
150     inferenceInterface.run(outputNames, logStats);
151     Trace.endSection();
152 
153     // Copy the output Tensor back into the output array.
154     Trace.beginSection("fetch");
155     inferenceInterface.fetch(outputName, outputs);
156     Trace.endSection();
157 
158     // Find the best classifications.
159     PriorityQueue<Recognition> pq =
160         new PriorityQueue<Recognition>(
161             3,
162             new Comparator<Recognition>() {
163               @Override
164               public int compare(Recognition lhs, Recognition rhs) {
165                 // Intentionally reversed to put high confidence at the head of the queue.
166                 return Float.compare(rhs.getConfidence(), lhs.getConfidence());
167               }
168             });
169     for (int i = 0; i < outputs.length; ++i) {
170       if (outputs[i] > THRESHOLD) {
171         pq.add(
172             new Recognition(
173                 "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
174       }
175     }
176     final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
177     int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
178     for (int i = 0; i < recognitionsSize; ++i) {
179       recognitions.add(pq.poll());
180     }
181     Trace.endSection(); // "recognizeImage"
182     return recognitions;
183   }
184 
185   @Override
enableStatLogging(boolean logStats)186   public void enableStatLogging(boolean logStats) {
187     this.logStats = logStats;
188   }
189 
190   @Override
getStatString()191   public String getStatString() {
192     return inferenceInterface.getStatString();
193   }
194 
195   @Override
close()196   public void close() {
197     inferenceInterface.close();
198   }
199 }
200