/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.demo.tracking;

import android.content.Context;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Cap;
import android.graphics.Paint.Join;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.text.TextUtils;
import android.util.Pair;
import android.util.TypedValue;
import android.widget.Toast;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import org.tensorflow.demo.Classifier.Recognition;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;

/**
 * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
 * objects to new detections.
 */
public class MultiBoxTracker {
  private final Logger logger = new Logger();

  private static final float TEXT_SIZE_DIP = 18;

  // Maximum percentage of a box that can be overlapped by another box at detection time. Otherwise
  // the lower scored box (new or old) will be removed.
  private static final float MAX_OVERLAP = 0.2f;

  // Detections whose frame-space rectangle is smaller than this (in either dimension) are
  // considered degenerate and are not tracked.
  private static final float MIN_SIZE = 16.0f;

  // Allow replacement of the tracked box with new results if
  // correlation has dropped below this level.
  private static final float MARGINAL_CORRELATION = 0.75f;

  // Consider object to be lost if correlation falls below this threshold.
  private static final float MIN_CORRELATION = 0.3f;

  // Fixed palette of box colors; the pool size also caps the number of simultaneously
  // tracked objects.
  private static final int[] COLORS = {
    Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE,
    Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"),
    Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"),
    Color.parseColor("#AA33AA"), Color.parseColor("#0D0068")
  };

  // Colors not currently assigned to a tracked object. Colors are returned here when an
  // object is dropped (see onFrame/handleDetection) and taken via poll() when one is added.
  private final Queue<Integer> availableColors = new LinkedList<Integer>();

  // Null until the first frame arrives, and may stay null if native tracking support is
  // unavailable; all drawing/tracking paths must handle that case.
  public ObjectTracker objectTracker;

  // Most recent detections mapped to screen coordinates, kept only for drawDebug().
  final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();

  // Bookkeeping for one on-screen box: the native tracker handle (null when tracking is
  // unavailable), last known location, detection score, assigned color and label.
  private static class TrackedRecognition {
    ObjectTracker.TrackedObject trackedObject;
    RectF location;
    float detectionConfidence;
    int color;
    String title;
  }

  private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();

  private final Paint boxPaint = new Paint();

  private final float textSizePx;
  private final BorderedText borderedText;

  // Maps frame (preview) coordinates to canvas coordinates; (re)computed on every draw().
  private Matrix frameToCanvasMatrix;

  private int frameWidth;
  private int frameHeight;

  private int sensorOrientation;
  private final Context context;

  public MultiBoxTracker(final Context context) {
    this.context = context;
    for (final int color : COLORS) {
      availableColors.add(color);
    }

    boxPaint.setColor(Color.RED);
    boxPaint.setStyle(Style.STROKE);
    boxPaint.setStrokeWidth(12.0f);
    boxPaint.setStrokeCap(Cap.ROUND);
    boxPaint.setStrokeJoin(Join.ROUND);
    boxPaint.setStrokeMiter(100);

    textSizePx =
        TypedValue.applyDimension(
            TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics());
    borderedText = new BorderedText(textSizePx);
  }

  private Matrix getFrameToCanvasMatrix() {
    return frameToCanvasMatrix;
  }

  /** Renders raw detection rectangles, correlations and tracker internals for debugging. */
  public synchronized void drawDebug(final Canvas canvas) {
    final Paint textPaint = new Paint();
    textPaint.setColor(Color.WHITE);
    textPaint.setTextSize(60.0f);

    // Named distinctly so it does not shadow the boxPaint field used by draw().
    final Paint debugBoxPaint = new Paint();
    debugBoxPaint.setColor(Color.RED);
    debugBoxPaint.setAlpha(200);
    debugBoxPaint.setStyle(Style.STROKE);

    for (final Pair<Float, RectF> detection : screenRects) {
      final RectF rect = detection.second;
      canvas.drawRect(rect, debugBoxPaint);
      canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
      borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
    }

    if (objectTracker == null) {
      return;
    }

    // Draw correlations.
    for (final TrackedRecognition recognition : trackedObjects) {
      final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;

      final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();

      // NOTE(review): mapRect()'s boolean only says whether the matrix maps rects to rects;
      // the rect is mapped either way, so gating the label on it looks suspicious — confirm.
      if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
        final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
        borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
      }
    }

    final Matrix matrix = getFrameToCanvasMatrix();
    objectTracker.drawDebug(canvas, matrix);
  }

  /**
   * Feeds a new batch of detections (with the frame they were detected on) into the tracker.
   *
   * @param results detections for this frame; entries without a location are ignored downstream.
   * @param frame the luminance frame data the detections came from.
   * @param timestamp frame timestamp, used to associate detections with tracker frames.
   */
  public synchronized void trackResults(
      final List<Recognition> results, final byte[] frame, final long timestamp) {
    logger.i("Processing %d results from %d", results.size(), timestamp);
    processResults(timestamp, results, frame);
  }

  /** Draws every currently tracked box, with rounded corners and a confidence label. */
  public synchronized void draw(final Canvas canvas) {
    // When the sensor is rotated 90/270 degrees, frame width/height are swapped on screen.
    final boolean rotated = sensorOrientation % 180 == 90;
    final float multiplier =
        Math.min(canvas.getHeight() / (float) (rotated ? frameWidth : frameHeight),
                 canvas.getWidth() / (float) (rotated ? frameHeight : frameWidth));
    frameToCanvasMatrix =
        ImageUtils.getTransformationMatrix(
            frameWidth,
            frameHeight,
            (int) (multiplier * (rotated ? frameHeight : frameWidth)),
            (int) (multiplier * (rotated ? frameWidth : frameHeight)),
            sensorOrientation,
            false);
    for (final TrackedRecognition recognition : trackedObjects) {
      // With native tracking, use the live tracked position; otherwise fall back to the
      // static location recorded at detection time.
      final RectF trackedPos =
          (objectTracker != null)
              ? recognition.trackedObject.getTrackedPositionInPreviewFrame()
              : new RectF(recognition.location);

      getFrameToCanvasMatrix().mapRect(trackedPos);
      boxPaint.setColor(recognition.color);

      final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
      canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);

      final String labelString =
          !TextUtils.isEmpty(recognition.title)
              ? String.format("%s %.2f", recognition.title, recognition.detectionConfidence)
              : String.format("%.2f", recognition.detectionConfidence);
      borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
    }
  }

  private boolean initialized = false;

  /**
   * Pushes a camera frame to the native tracker, lazily initializing it on the first frame,
   * and drops tracked objects whose correlation has fallen below MIN_CORRELATION.
   */
  public synchronized void onFrame(
      final int w,
      final int h,
      final int rowStride,
      final int sensorOrientation,
      final byte[] frame,
      final long timestamp) {
    if (objectTracker == null && !initialized) {
      ObjectTracker.clearInstance();

      logger.i("Initializing ObjectTracker: %dx%d", w, h);
      objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
      frameWidth = w;
      frameHeight = h;
      this.sensorOrientation = sensorOrientation;
      initialized = true;

      // getInstance() returns null when the native tracking library is missing; the app
      // still works in detection-only mode (see processResults).
      if (objectTracker == null) {
        String message =
            "Object tracking support not found. "
                + "See tensorflow/tools/android/test/README.md for details.";
        Toast.makeText(context, message, Toast.LENGTH_LONG).show();
        logger.e(message);
      }
    }

    if (objectTracker == null) {
      return;
    }

    objectTracker.nextFrame(frame, null, timestamp, null, true);

    // Clean up any objects not worth tracking any more.
    // Iterate over a copy since we remove from trackedObjects inside the loop.
    final LinkedList<TrackedRecognition> copyList =
        new LinkedList<TrackedRecognition>(trackedObjects);
    for (final TrackedRecognition recognition : copyList) {
      final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
      final float correlation = trackedObject.getCurrentCorrelation();
      if (correlation < MIN_CORRELATION) {
        logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
        trackedObject.stopTracking();
        trackedObjects.remove(recognition);

        // Return the color to the pool for future objects.
        availableColors.add(recognition.color);
      }
    }
  }

  /**
   * Filters a frame's detections (no location, or degenerate size) and either hands them to the
   * native tracker or, when tracking is unavailable, replaces the displayed set wholesale.
   */
  private void processResults(
      final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
    final List<Pair<Float, Recognition>> rectsToTrack = new LinkedList<Pair<Float, Recognition>>();

    screenRects.clear();
    // NOTE(review): frameToCanvasMatrix is null until draw() has run once; Matrix(null)
    // appears to be tolerated (identity), but confirm against the call order in the activity.
    final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());

    for (final Recognition result : results) {
      if (result.getLocation() == null) {
        continue;
      }
      final RectF detectionFrameRect = new RectF(result.getLocation());

      final RectF detectionScreenRect = new RectF();
      rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);

      logger.v(
          "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect);

      screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));

      if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
        logger.w("Degenerate rectangle! " + detectionFrameRect);
        continue;
      }

      rectsToTrack.add(new Pair<Float, Recognition>(result.getConfidence(), result));
    }

    if (rectsToTrack.isEmpty()) {
      logger.v("Nothing to track, aborting.");
      return;
    }

    // Detection-only fallback: no tracker, so just show this frame's detections directly,
    // capped at the number of available colors.
    if (objectTracker == null) {
      trackedObjects.clear();
      for (final Pair<Float, Recognition> potential : rectsToTrack) {
        final TrackedRecognition trackedRecognition = new TrackedRecognition();
        trackedRecognition.detectionConfidence = potential.first;
        trackedRecognition.location = new RectF(potential.second.getLocation());
        trackedRecognition.trackedObject = null;
        trackedRecognition.title = potential.second.getTitle();
        trackedRecognition.color = COLORS[trackedObjects.size()];
        trackedObjects.add(trackedRecognition);

        if (trackedObjects.size() >= COLORS.length) {
          break;
        }
      }
      return;
    }

    logger.i("%d rects to track", rectsToTrack.size());
    for (final Pair<Float, Recognition> potential : rectsToTrack) {
      handleDetection(originalFrame, timestamp, potential);
    }
  }

  /**
   * Tries to start tracking one new detection, performing non-max suppression against the
   * objects already tracked: the detection is rejected if a confidently tracked, higher-scored
   * object overlaps it; otherwise overlapped (or, when full, strictly worse) objects are
   * evicted, and the best-overlapping evictee donates its color to the newcomer.
   */
  private void handleDetection(
      final byte[] frameCopy, final long timestamp, final Pair<Float, Recognition> potential) {
    final ObjectTracker.TrackedObject potentialObject =
        objectTracker.trackObject(potential.second.getLocation(), timestamp, frameCopy);

    final float potentialCorrelation = potentialObject.getCurrentCorrelation();
    logger.v(
        "Tracked object went from %s to %s with correlation %.2f",
        potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);

    if (potentialCorrelation < MARGINAL_CORRELATION) {
      logger.v("Correlation too low to begin tracking %s.", potentialObject);
      potentialObject.stopTracking();
      return;
    }

    final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();

    float maxIntersect = 0.0f;

    // This is the current tracked object whose color we will take. If left null we'll take the
    // first one from the color queue.
    TrackedRecognition recogToReplace = null;

    // Look for intersections that will be overridden by this object or an intersection that would
    // prevent this one from being placed.
    for (final TrackedRecognition trackedRecognition : trackedObjects) {
      final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
      final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
      final RectF intersection = new RectF();
      final boolean intersects = intersection.setIntersect(a, b);

      final float intersectArea = intersection.width() * intersection.height();
      final float totalArea = a.width() * a.height() + b.width() * b.height() - intersectArea;
      final float intersectOverUnion = intersectArea / totalArea;

      // If there is an intersection with this currently tracked box above the maximum overlap
      // percentage allowed, either the new recognition needs to be dismissed or the old
      // recognition needs to be removed and possibly replaced with the new one.
      if (intersects && intersectOverUnion > MAX_OVERLAP) {
        if (potential.first < trackedRecognition.detectionConfidence
            && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
          // If track for the existing object is still going strong and the detection score was
          // good, reject this new object.
          potentialObject.stopTracking();
          return;
        } else {
          removeList.add(trackedRecognition);

          // Let the previously tracked object with max intersection amount donate its color to
          // the new object.
          if (intersectOverUnion > maxIntersect) {
            maxIntersect = intersectOverUnion;
            recogToReplace = trackedRecognition;
          }
        }
      }
    }

    // If we're already tracking the max object and no intersections were found to bump off,
    // pick the worst current tracked object to remove, if it's also worse than this candidate
    // object.
    if (availableColors.isEmpty() && removeList.isEmpty()) {
      for (final TrackedRecognition candidate : trackedObjects) {
        if (candidate.detectionConfidence < potential.first) {
          if (recogToReplace == null
              || candidate.detectionConfidence < recogToReplace.detectionConfidence) {
            // Save it so that we use this color for the new object.
            recogToReplace = candidate;
          }
        }
      }
      if (recogToReplace != null) {
        logger.v("Found non-intersecting object to remove.");
        removeList.add(recogToReplace);
      } else {
        logger.v("No non-intersecting object found to remove");
      }
    }

    // Remove everything that got intersected.
    for (final TrackedRecognition trackedRecognition : removeList) {
      logger.v(
          "Removing tracked object %s with detection confidence %.2f, correlation %.2f",
          trackedRecognition.trackedObject,
          trackedRecognition.detectionConfidence,
          trackedRecognition.trackedObject.getCurrentCorrelation());
      trackedRecognition.trackedObject.stopTracking();
      trackedObjects.remove(trackedRecognition);
      if (trackedRecognition != recogToReplace) {
        availableColors.add(trackedRecognition.color);
      }
    }

    if (recogToReplace == null && availableColors.isEmpty()) {
      logger.e("No room to track this object, aborting.");
      potentialObject.stopTracking();
      return;
    }

    // Finally safe to say we can track this object.
    logger.v(
        "Tracking object %s (%s) with detection confidence %.2f at position %s",
        potentialObject,
        potential.second.getTitle(),
        potential.first,
        potential.second.getLocation());
    final TrackedRecognition trackedRecognition = new TrackedRecognition();
    trackedRecognition.detectionConfidence = potential.first;
    trackedRecognition.trackedObject = potentialObject;
    trackedRecognition.title = potential.second.getTitle();

    // Use the color from a replaced object before taking one from the color queue.
    trackedRecognition.color =
        recogToReplace != null ? recogToReplace.color : availableColors.poll();
    trackedObjects.add(trackedRecognition);
  }
}