1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
18 
#include <stdint.h>
#include <memory>
#include <vector>

#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
26 
27 namespace tf_tracking {
28 
29 struct Keypoint;
30 
31 class KeypointDetector {
32  public:
KeypointDetector(const KeypointDetectorConfig * const config)33   explicit KeypointDetector(const KeypointDetectorConfig* const config)
34       : config_(config),
35         keypoint_scratch_(new Image<uint8_t>(config_->image_size)),
36         interest_map_(new Image<bool>(config_->image_size)),
37         fast_quadrant_(0) {
38     interest_map_->Clear(false);
39   }
40 
~KeypointDetector()41   ~KeypointDetector() {}
42 
43   // Finds a new set of keypoints for the current frame, picked from the current
44   // set of keypoints and also from a set discovered via a keypoint detector.
45   // Special attention is applied to make sure that keypoints are distributed
46   // within the supplied ROIs.
47   void FindKeypoints(const ImageData& image_data,
48                      const std::vector<BoundingBox>& rois,
49                      const FramePair& prev_change,
50                      FramePair* const curr_change);
51 
52  private:
53   // Compute the corneriness of a point in the image.
54   float HarrisFilter(const Image<int32_t>& I_x, const Image<int32_t>& I_y,
55                      const float x, const float y) const;
56 
57   // Adds a grid of candidate keypoints to the given box, up to
58   // max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower.
59   int AddExtraCandidatesForBoxes(
60       const std::vector<BoundingBox>& boxes,
61       const int max_num_keypoints,
62       Keypoint* const keypoints) const;
63 
64   // Scan the frame for potential keypoints using the FAST keypoint detector.
65   // Quadrant is an argument 0-3 which refers to the quadrant of the image in
66   // which to detect keypoints.
67   int FindFastKeypoints(const Image<uint8_t>& frame, const int quadrant,
68                         const int downsample_factor,
69                         const int max_num_keypoints, Keypoint* const keypoints);
70 
71   int FindFastKeypoints(const ImageData& image_data,
72                         const int max_num_keypoints,
73                         Keypoint* const keypoints);
74 
75   // Score a bunch of candidate keypoints.  Assigns the scores to the input
76   // candidate_keypoints array entries.
77   void ScoreKeypoints(const ImageData& image_data,
78                       const int num_candidates,
79                       Keypoint* const candidate_keypoints);
80 
81   void SortKeypoints(const int num_candidates,
82                     Keypoint* const candidate_keypoints) const;
83 
84   // Selects a set of keypoints falling within the supplied box such that the
85   // most highly rated keypoints are picked first, and so that none of them are
86   // too close together.
87   int SelectKeypointsInBox(
88       const BoundingBox& box,
89       const Keypoint* const candidate_keypoints,
90       const int num_candidates,
91       const int max_keypoints,
92       const int num_existing_keypoints,
93       const Keypoint* const existing_keypoints,
94       Keypoint* const final_keypoints) const;
95 
96   // Selects from the supplied sorted keypoint pool a set of keypoints that will
97   // best cover the given set of boxes, such that each box is covered at a
98   // resolution proportional to its size.
99   void SelectKeypoints(
100       const std::vector<BoundingBox>& boxes,
101       const Keypoint* const candidate_keypoints,
102       const int num_candidates,
103       FramePair* const frame_change) const;
104 
105   // Copies and compacts the found keypoints in the second frame of prev_change
106   // into the array at new_keypoints.
107   static int CopyKeypoints(const FramePair& prev_change,
108                           Keypoint* const new_keypoints);
109 
110   const KeypointDetectorConfig* const config_;
111 
112   // Scratch memory for keypoint candidacy detection and non-max suppression.
113   std::unique_ptr<Image<uint8_t> > keypoint_scratch_;
114 
115   // Regions of the image to pay special attention to.
116   std::unique_ptr<Image<bool> > interest_map_;
117 
118   // The current quadrant of the image to detect FAST keypoints in.
119   // Keypoint detection is staggered for performance reasons. Every four frames
120   // a full scan of the frame will have been performed.
121   int fast_quadrant_;
122 
123   Keypoint tmp_keypoints_[kMaxTempKeypoints];
124 };
125 
126 }  // namespace tf_tracking
127 
128 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
129