/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.demo.tracking;

import android.content.Context;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Cap;
import android.graphics.Paint.Join;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.text.TextUtils;
import android.util.Pair;
import android.util.TypedValue;
import android.widget.Toast;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import org.tensorflow.demo.Classifier.Recognition;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;

/**
 * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
 * objects to new detections.
 */
43 public class MultiBoxTracker {
44   private final Logger logger = new Logger();
45 
46   private static final float TEXT_SIZE_DIP = 18;
47 
48   // Maximum percentage of a box that can be overlapped by another box at detection time. Otherwise
49   // the lower scored box (new or old) will be removed.
50   private static final float MAX_OVERLAP = 0.2f;
51 
52   private static final float MIN_SIZE = 16.0f;
53 
54   // Allow replacement of the tracked box with new results if
55   // correlation has dropped below this level.
56   private static final float MARGINAL_CORRELATION = 0.75f;
57 
58   // Consider object to be lost if correlation falls below this threshold.
59   private static final float MIN_CORRELATION = 0.3f;
60 
61   private static final int[] COLORS = {
62     Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE,
63     Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"),
64     Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"),
65     Color.parseColor("#AA33AA"), Color.parseColor("#0D0068")
66   };
67 
68   private final Queue<Integer> availableColors = new LinkedList<Integer>();
69 
70   public ObjectTracker objectTracker;
71 
72   final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();
73 
74   private static class TrackedRecognition {
75     ObjectTracker.TrackedObject trackedObject;
76     RectF location;
77     float detectionConfidence;
78     int color;
79     String title;
80   }
81 
82   private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();
83 
84   private final Paint boxPaint = new Paint();
85 
86   private final float textSizePx;
87   private final BorderedText borderedText;
88 
89   private Matrix frameToCanvasMatrix;
90 
91   private int frameWidth;
92   private int frameHeight;
93 
94   private int sensorOrientation;
95   private Context context;
96 
MultiBoxTracker(final Context context)97   public MultiBoxTracker(final Context context) {
98     this.context = context;
99     for (final int color : COLORS) {
100       availableColors.add(color);
101     }
102 
103     boxPaint.setColor(Color.RED);
104     boxPaint.setStyle(Style.STROKE);
105     boxPaint.setStrokeWidth(12.0f);
106     boxPaint.setStrokeCap(Cap.ROUND);
107     boxPaint.setStrokeJoin(Join.ROUND);
108     boxPaint.setStrokeMiter(100);
109 
110     textSizePx =
111         TypedValue.applyDimension(
112             TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics());
113     borderedText = new BorderedText(textSizePx);
114   }
115 
getFrameToCanvasMatrix()116   private Matrix getFrameToCanvasMatrix() {
117     return frameToCanvasMatrix;
118   }
119 
drawDebug(final Canvas canvas)120   public synchronized void drawDebug(final Canvas canvas) {
121     final Paint textPaint = new Paint();
122     textPaint.setColor(Color.WHITE);
123     textPaint.setTextSize(60.0f);
124 
125     final Paint boxPaint = new Paint();
126     boxPaint.setColor(Color.RED);
127     boxPaint.setAlpha(200);
128     boxPaint.setStyle(Style.STROKE);
129 
130     for (final Pair<Float, RectF> detection : screenRects) {
131       final RectF rect = detection.second;
132       canvas.drawRect(rect, boxPaint);
133       canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
134       borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
135     }
136 
137     if (objectTracker == null) {
138       return;
139     }
140 
141     // Draw correlations.
142     for (final TrackedRecognition recognition : trackedObjects) {
143       final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
144 
145       final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
146 
147       if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
148         final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
149         borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
150       }
151     }
152 
153     final Matrix matrix = getFrameToCanvasMatrix();
154     objectTracker.drawDebug(canvas, matrix);
155   }
156 
trackResults( final List<Recognition> results, final byte[] frame, final long timestamp)157   public synchronized void trackResults(
158       final List<Recognition> results, final byte[] frame, final long timestamp) {
159     logger.i("Processing %d results from %d", results.size(), timestamp);
160     processResults(timestamp, results, frame);
161   }
162 
draw(final Canvas canvas)163   public synchronized void draw(final Canvas canvas) {
164     final boolean rotated = sensorOrientation % 180 == 90;
165     final float multiplier =
166         Math.min(canvas.getHeight() / (float) (rotated ? frameWidth : frameHeight),
167                  canvas.getWidth() / (float) (rotated ? frameHeight : frameWidth));
168     frameToCanvasMatrix =
169         ImageUtils.getTransformationMatrix(
170             frameWidth,
171             frameHeight,
172             (int) (multiplier * (rotated ? frameHeight : frameWidth)),
173             (int) (multiplier * (rotated ? frameWidth : frameHeight)),
174             sensorOrientation,
175             false);
176     for (final TrackedRecognition recognition : trackedObjects) {
177       final RectF trackedPos =
178           (objectTracker != null)
179               ? recognition.trackedObject.getTrackedPositionInPreviewFrame()
180               : new RectF(recognition.location);
181 
182       getFrameToCanvasMatrix().mapRect(trackedPos);
183       boxPaint.setColor(recognition.color);
184 
185       final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
186       canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);
187 
188       final String labelString =
189           !TextUtils.isEmpty(recognition.title)
190               ? String.format("%s %.2f", recognition.title, recognition.detectionConfidence)
191               : String.format("%.2f", recognition.detectionConfidence);
192       borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
193     }
194   }
195 
196   private boolean initialized = false;
197 
onFrame( final int w, final int h, final int rowStride, final int sensorOrientation, final byte[] frame, final long timestamp)198   public synchronized void onFrame(
199       final int w,
200       final int h,
201       final int rowStride,
202       final int sensorOrientation,
203       final byte[] frame,
204       final long timestamp) {
205     if (objectTracker == null && !initialized) {
206       ObjectTracker.clearInstance();
207 
208       logger.i("Initializing ObjectTracker: %dx%d", w, h);
209       objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
210       frameWidth = w;
211       frameHeight = h;
212       this.sensorOrientation = sensorOrientation;
213       initialized = true;
214 
215       if (objectTracker == null) {
216         String message =
217             "Object tracking support not found. "
218                 + "See tensorflow/tools/android/test/README.md for details.";
219         Toast.makeText(context, message, Toast.LENGTH_LONG).show();
220         logger.e(message);
221       }
222     }
223 
224     if (objectTracker == null) {
225       return;
226     }
227 
228     objectTracker.nextFrame(frame, null, timestamp, null, true);
229 
230     // Clean up any objects not worth tracking any more.
231     final LinkedList<TrackedRecognition> copyList =
232         new LinkedList<TrackedRecognition>(trackedObjects);
233     for (final TrackedRecognition recognition : copyList) {
234       final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
235       final float correlation = trackedObject.getCurrentCorrelation();
236       if (correlation < MIN_CORRELATION) {
237         logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
238         trackedObject.stopTracking();
239         trackedObjects.remove(recognition);
240 
241         availableColors.add(recognition.color);
242       }
243     }
244   }
245 
processResults( final long timestamp, final List<Recognition> results, final byte[] originalFrame)246   private void processResults(
247       final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
248     final List<Pair<Float, Recognition>> rectsToTrack = new LinkedList<Pair<Float, Recognition>>();
249 
250     screenRects.clear();
251     final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());
252 
253     for (final Recognition result : results) {
254       if (result.getLocation() == null) {
255         continue;
256       }
257       final RectF detectionFrameRect = new RectF(result.getLocation());
258 
259       final RectF detectionScreenRect = new RectF();
260       rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);
261 
262       logger.v(
263           "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect);
264 
265       screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));
266 
267       if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
268         logger.w("Degenerate rectangle! " + detectionFrameRect);
269         continue;
270       }
271 
272       rectsToTrack.add(new Pair<Float, Recognition>(result.getConfidence(), result));
273     }
274 
275     if (rectsToTrack.isEmpty()) {
276       logger.v("Nothing to track, aborting.");
277       return;
278     }
279 
280     if (objectTracker == null) {
281       trackedObjects.clear();
282       for (final Pair<Float, Recognition> potential : rectsToTrack) {
283         final TrackedRecognition trackedRecognition = new TrackedRecognition();
284         trackedRecognition.detectionConfidence = potential.first;
285         trackedRecognition.location = new RectF(potential.second.getLocation());
286         trackedRecognition.trackedObject = null;
287         trackedRecognition.title = potential.second.getTitle();
288         trackedRecognition.color = COLORS[trackedObjects.size()];
289         trackedObjects.add(trackedRecognition);
290 
291         if (trackedObjects.size() >= COLORS.length) {
292           break;
293         }
294       }
295       return;
296     }
297 
298     logger.i("%d rects to track", rectsToTrack.size());
299     for (final Pair<Float, Recognition> potential : rectsToTrack) {
300       handleDetection(originalFrame, timestamp, potential);
301     }
302   }
303 
handleDetection( final byte[] frameCopy, final long timestamp, final Pair<Float, Recognition> potential)304   private void handleDetection(
305       final byte[] frameCopy, final long timestamp, final Pair<Float, Recognition> potential) {
306     final ObjectTracker.TrackedObject potentialObject =
307         objectTracker.trackObject(potential.second.getLocation(), timestamp, frameCopy);
308 
309     final float potentialCorrelation = potentialObject.getCurrentCorrelation();
310     logger.v(
311         "Tracked object went from %s to %s with correlation %.2f",
312         potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);
313 
314     if (potentialCorrelation < MARGINAL_CORRELATION) {
315       logger.v("Correlation too low to begin tracking %s.", potentialObject);
316       potentialObject.stopTracking();
317       return;
318     }
319 
320     final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();
321 
322     float maxIntersect = 0.0f;
323 
324     // This is the current tracked object whose color we will take. If left null we'll take the
325     // first one from the color queue.
326     TrackedRecognition recogToReplace = null;
327 
328     // Look for intersections that will be overridden by this object or an intersection that would
329     // prevent this one from being placed.
330     for (final TrackedRecognition trackedRecognition : trackedObjects) {
331       final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
332       final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
333       final RectF intersection = new RectF();
334       final boolean intersects = intersection.setIntersect(a, b);
335 
336       final float intersectArea = intersection.width() * intersection.height();
337       final float totalArea = a.width() * a.height() + b.width() * b.height() - intersectArea;
338       final float intersectOverUnion = intersectArea / totalArea;
339 
340       // If there is an intersection with this currently tracked box above the maximum overlap
341       // percentage allowed, either the new recognition needs to be dismissed or the old
342       // recognition needs to be removed and possibly replaced with the new one.
343       if (intersects && intersectOverUnion > MAX_OVERLAP) {
344         if (potential.first < trackedRecognition.detectionConfidence
345             && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
346           // If track for the existing object is still going strong and the detection score was
347           // good, reject this new object.
348           potentialObject.stopTracking();
349           return;
350         } else {
351           removeList.add(trackedRecognition);
352 
353           // Let the previously tracked object with max intersection amount donate its color to
354           // the new object.
355           if (intersectOverUnion > maxIntersect) {
356             maxIntersect = intersectOverUnion;
357             recogToReplace = trackedRecognition;
358           }
359         }
360       }
361     }
362 
363     // If we're already tracking the max object and no intersections were found to bump off,
364     // pick the worst current tracked object to remove, if it's also worse than this candidate
365     // object.
366     if (availableColors.isEmpty() && removeList.isEmpty()) {
367       for (final TrackedRecognition candidate : trackedObjects) {
368         if (candidate.detectionConfidence < potential.first) {
369           if (recogToReplace == null
370               || candidate.detectionConfidence < recogToReplace.detectionConfidence) {
371             // Save it so that we use this color for the new object.
372             recogToReplace = candidate;
373           }
374         }
375       }
376       if (recogToReplace != null) {
377         logger.v("Found non-intersecting object to remove.");
378         removeList.add(recogToReplace);
379       } else {
380         logger.v("No non-intersecting object found to remove");
381       }
382     }
383 
384     // Remove everything that got intersected.
385     for (final TrackedRecognition trackedRecognition : removeList) {
386       logger.v(
387           "Removing tracked object %s with detection confidence %.2f, correlation %.2f",
388           trackedRecognition.trackedObject,
389           trackedRecognition.detectionConfidence,
390           trackedRecognition.trackedObject.getCurrentCorrelation());
391       trackedRecognition.trackedObject.stopTracking();
392       trackedObjects.remove(trackedRecognition);
393       if (trackedRecognition != recogToReplace) {
394         availableColors.add(trackedRecognition.color);
395       }
396     }
397 
398     if (recogToReplace == null && availableColors.isEmpty()) {
399       logger.e("No room to track this object, aborting.");
400       potentialObject.stopTracking();
401       return;
402     }
403 
404     // Finally safe to say we can track this object.
405     logger.v(
406         "Tracking object %s (%s) with detection confidence %.2f at position %s",
407         potentialObject,
408         potential.second.getTitle(),
409         potential.first,
410         potential.second.getLocation());
411     final TrackedRecognition trackedRecognition = new TrackedRecognition();
412     trackedRecognition.detectionConfidence = potential.first;
413     trackedRecognition.trackedObject = potentialObject;
414     trackedRecognition.title = potential.second.getTitle();
415 
416     // Use the color from a replaced object before taking one from the color queue.
417     trackedRecognition.color =
418         recogToReplace != null ? recogToReplace.color : availableColors.poll();
419     trackedObjects.add(trackedRecognition);
420   }
421 }