1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
18 
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"

#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
34 
35 namespace tf_tracking {
36 
37 typedef std::map<const std::string, TrackedObject*> TrackedObjectMap;
38 
39 inline std::ostream& operator<<(std::ostream& stream,
40                                 const TrackedObjectMap& map) {
41   for (TrackedObjectMap::const_iterator iter = map.begin();
42       iter != map.end(); ++iter) {
43     const TrackedObject& tracked_object = *iter->second;
44     const std::string& key = iter->first;
45     stream << key << ": " << tracked_object;
46   }
47   return stream;
48 }
49 
50 
51 // ObjectTracker is the highest-level class in the tracking/detection framework.
52 // It handles basic image processing, keypoint detection, keypoint tracking,
53 // object tracking, and object detection/relocalization.
54 class ObjectTracker {
55  public:
56   ObjectTracker(const TrackerConfig* const config,
57                 ObjectDetectorBase* const detector);
58   virtual ~ObjectTracker();
59 
NextFrame(const uint8_t * const new_frame,const int64_t timestamp,const float * const alignment_matrix_2x3)60   virtual void NextFrame(const uint8_t* const new_frame,
61                          const int64_t timestamp,
62                          const float* const alignment_matrix_2x3) {
63     NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
64   }
65 
66   // Called upon the arrival of a new frame of raw data.
67   // Does all image processing, keypoint detection, and object
68   // tracking/detection for registered objects.
69   // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
70   // represents the main transformation that has happened between the last
71   // and the current frame.
72   // Argument align_level is the pyramid level (where 0 == finest) that
73   // the matrix is valid for.
74   virtual void NextFrame(const uint8_t* const new_frame,
75                          const uint8_t* const uv_frame, const int64_t timestamp,
76                          const float* const alignment_matrix_2x3);
77 
78   virtual void RegisterNewObjectWithAppearance(const std::string& id,
79                                                const uint8_t* const new_frame,
80                                                const BoundingBox& bounding_box);
81 
82   // Updates the position of a tracked object, given that it was known to be at
83   // a certain position at some point in the past.
84   virtual void SetPreviousPositionOfObject(const std::string& id,
85                                            const BoundingBox& bounding_box,
86                                            const int64_t timestamp);
87 
88   // Sets the current position of the object in the most recent frame provided.
89   virtual void SetCurrentPositionOfObject(const std::string& id,
90                                           const BoundingBox& bounding_box);
91 
92   // Tells the ObjectTracker to stop tracking a target.
93   void ForgetTarget(const std::string& id);
94 
95   // Fills the given out_data buffer with the latest detected keypoint
96   // correspondences, first scaled by scale_factor (to adjust for downsampling
97   // that may have occurred elsewhere), then packed in a fixed-point format.
98   int GetKeypointsPacked(uint16_t* const out_data,
99                          const float scale_factor) const;
100 
101   // Copy the keypoint arrays after computeFlow is called.
102   // out_data should be at least kMaxKeypoints * kKeypointStep long.
103   // Currently, its format is [x1 y1 found x2 y2 score] repeated N times,
104   // where N is the number of keypoints tracked.  N is returned as the result.
105   int GetKeypoints(const bool only_found, float* const out_data) const;
106 
107   // Returns the current position of a box, given that it was at a certain
108   // position at the given time.
109   BoundingBox TrackBox(const BoundingBox& region,
110                        const int64_t timestamp) const;
111 
112   // Returns the number of frames that have been passed to NextFrame().
GetNumFrames()113   inline int GetNumFrames() const {
114     return num_frames_;
115   }
116 
HaveObject(const std::string & id)117   inline bool HaveObject(const std::string& id) const {
118     return objects_.find(id) != objects_.end();
119   }
120 
121   // Returns the TrackedObject associated with the given id.
GetObject(const std::string & id)122   inline const TrackedObject* GetObject(const std::string& id) const {
123     TrackedObjectMap::const_iterator iter = objects_.find(id);
124     CHECK_ALWAYS(iter != objects_.end(),
125                  "Unknown object key! \"%s\"", id.c_str());
126     TrackedObject* const object = iter->second;
127     return object;
128   }
129 
130   // Returns the TrackedObject associated with the given id.
GetObject(const std::string & id)131   inline TrackedObject* GetObject(const std::string& id) {
132     TrackedObjectMap::iterator iter = objects_.find(id);
133     CHECK_ALWAYS(iter != objects_.end(),
134                  "Unknown object key! \"%s\"", id.c_str());
135     TrackedObject* const object = iter->second;
136     return object;
137   }
138 
IsObjectVisible(const std::string & id)139   bool IsObjectVisible(const std::string& id) const {
140     SCHECK(HaveObject(id), "Don't have this object.");
141 
142     const TrackedObject* object = GetObject(id);
143     return object->IsVisible();
144   }
145 
146   virtual void Draw(const int canvas_width, const int canvas_height,
147                     const float* const frame_to_canvas) const;
148 
149  protected:
150   // Creates a new tracked object at the given position.
151   // If an object model is provided, then that model will be associated with the
152   // object. If not, a new model may be created from the appearance at the
153   // initial position and registered with the object detector.
154   virtual TrackedObject* MaybeAddObject(const std::string& id,
155                                         const Image<uint8_t>& image,
156                                         const BoundingBox& bounding_box,
157                                         const ObjectModelBase* object_model);
158 
159   // Find the keypoints in the frame before the current frame.
160   // If only one frame exists, keypoints will be found in that frame.
161   void ComputeKeypoints(const bool cached_ok = false);
162 
163   // Finds the correspondences for all the points in the current pair of frames.
164   // Stores the results in the given FramePair.
165   void FindCorrespondences(FramePair* const curr_change) const;
166 
GetNthIndexFromEnd(const int offset)167   inline int GetNthIndexFromEnd(const int offset) const {
168     return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
169   }
170 
171   BoundingBox TrackBox(const BoundingBox& region,
172                        const FramePair& frame_pair) const;
173 
IncrementFrameIndex()174   inline void IncrementFrameIndex() {
175     // Move the current framechange index up.
176     ++num_frames_;
177     ++curr_num_frame_pairs_;
178 
179     // If we've got too many, push up the start of the queue.
180     if (curr_num_frame_pairs_ > kNumFrames) {
181       first_frame_index_ = GetNthIndexFromStart(1);
182       --curr_num_frame_pairs_;
183     }
184   }
185 
GetNthIndexFromStart(const int offset)186   inline int GetNthIndexFromStart(const int offset) const {
187     SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
188           "Offset out of range!  %d out of %d.", offset, curr_num_frame_pairs_);
189     return (first_frame_index_ + offset) % kNumFrames;
190   }
191 
192   void TrackObjects();
193 
194   const std::unique_ptr<const TrackerConfig> config_;
195 
196   const int frame_width_;
197   const int frame_height_;
198 
199   int64_t curr_time_;
200 
201   int num_frames_;
202 
203   TrackedObjectMap objects_;
204 
205   FlowCache flow_cache_;
206 
207   KeypointDetector keypoint_detector_;
208 
209   int curr_num_frame_pairs_;
210   int first_frame_index_;
211 
212   std::unique_ptr<ImageData> frame1_;
213   std::unique_ptr<ImageData> frame2_;
214 
215   FramePair frame_pairs_[kNumFrames];
216 
217   std::unique_ptr<ObjectDetectorBase> detector_;
218 
219   int num_detected_;
220 
221  private:
222   void TrackTarget(TrackedObject* const object);
223 
224   bool GetBestObjectForDetection(
225       const Detection& detection, TrackedObject** match) const;
226 
227   void ProcessDetections(std::vector<Detection>* const detections);
228 
229   void DetectTargets();
230 
231   // Temp object used in ObjectTracker::CreateNewExample.
232   mutable std::vector<BoundingSquare> squares;
233 
234   friend std::ostream& operator<<(std::ostream& stream,
235                                   const ObjectTracker& tracker);
236 
237   TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
238 };
239 
240 inline std::ostream& operator<<(std::ostream& stream,
241                                 const ObjectTracker& tracker) {
242   stream << "Frame size: " << tracker.frame_width_ << "x"
243          << tracker.frame_height_ << std::endl;
244 
245   stream << "Num frames: " << tracker.num_frames_ << std::endl;
246 
247   stream << "Curr time: " << tracker.curr_time_ << std::endl;
248 
249   const int first_frame_index = tracker.GetNthIndexFromStart(0);
250   const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];
251 
252   const int last_frame_index = tracker.GetNthIndexFromEnd(0);
253   const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];
254 
255   stream << "first frame: " << first_frame_index << ","
256          << first_frame_pair.end_time_ << "    "
257          << "last frame: " << last_frame_index << ","
258          << last_frame_pair.end_time_ << "   diff: "
259          << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
260          << std::endl;
261 
262   stream << "Tracked targets:";
263   stream << tracker.objects_;
264 
265   return stream;
266 }
267 
268 }  // namespace tf_tracking
269 
270 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
271