1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
18 
19 #include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 #include "tensorflow/examples/android/jni/object_tracking/image.h"
22 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 
24 namespace tf_tracking {
25 
// Gradient code returned by IntegralImage::GetCode(); holds a value in 0-3.
using Code = uint8_t;
27 
28 class IntegralImage : public Image<uint32_t> {
29  public:
IntegralImage(const Image<uint8_t> & image_base)30   explicit IntegralImage(const Image<uint8_t>& image_base)
31       : Image<uint32_t>(image_base.GetWidth(), image_base.GetHeight()) {
32     Recompute(image_base);
33   }
34 
IntegralImage(const int width,const int height)35   IntegralImage(const int width, const int height)
36       : Image<uint32_t>(width, height) {}
37 
Recompute(const Image<uint8_t> & image_base)38   void Recompute(const Image<uint8_t>& image_base) {
39     SCHECK(image_base.GetWidth() == GetWidth() &&
40           image_base.GetHeight() == GetHeight(), "Dimensions don't match!");
41 
42     // Sum along first row.
43     {
44       int x_sum = 0;
45       for (int x = 0; x < image_base.GetWidth(); ++x) {
46         x_sum += image_base[0][x];
47         (*this)[0][x] = x_sum;
48       }
49     }
50 
51     // Sum everything else.
52     for (int y = 1; y < image_base.GetHeight(); ++y) {
53       uint32_t* curr_sum = (*this)[y];
54 
55       // Previously summed pointers.
56       const uint32_t* up_one = (*this)[y - 1];
57 
58       // Current value pointer.
59       const uint8_t* curr_delta = image_base[y];
60 
61       uint32_t row_till_now = 0;
62 
63       for (int x = 0; x < GetWidth(); ++x) {
64         // Add the one above and the one to the left.
65         row_till_now += *curr_delta;
66         *curr_sum = *up_one + row_till_now;
67 
68         // Scoot everything along.
69         ++curr_sum;
70         ++up_one;
71         ++curr_delta;
72       }
73     }
74 
75     SCHECK(VerifyData(image_base), "Images did not match!");
76   }
77 
VerifyData(const Image<uint8_t> & image_base)78   bool VerifyData(const Image<uint8_t>& image_base) {
79     for (int y = 0; y < GetHeight(); ++y) {
80       for (int x = 0; x < GetWidth(); ++x) {
81         uint32_t curr_val = (*this)[y][x];
82 
83         if (x > 0) {
84           curr_val -= (*this)[y][x - 1];
85         }
86 
87         if (y > 0) {
88           curr_val -= (*this)[y - 1][x];
89         }
90 
91         if (x > 0 && y > 0) {
92           curr_val += (*this)[y - 1][x - 1];
93         }
94 
95         if (curr_val != image_base[y][x]) {
96           LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]);
97           return false;
98         }
99 
100         if (GetRegionSum(x, y, x, y) != curr_val) {
101           LOGE("Mismatch!");
102         }
103       }
104     }
105 
106     return true;
107   }
108 
109   // Returns the sum of all pixels in the specified region.
GetRegionSum(const int x1,const int y1,const int x2,const int y2)110   inline uint32_t GetRegionSum(const int x1, const int y1, const int x2,
111                                const int y2) const {
112     SCHECK(x1 >= 0 && y1 >= 0 &&
113           x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(),
114           "indices out of bounds! %d-%d / %d, %d-%d / %d, ",
115           x1, x2, GetWidth(), y1, y2, GetHeight());
116 
117     const uint32_t everything = (*this)[y2][x2];
118 
119     uint32_t sum = everything;
120     if (x1 > 0 && y1 > 0) {
121       // Most common case.
122       const uint32_t left = (*this)[y2][x1 - 1];
123       const uint32_t top = (*this)[y1 - 1][x2];
124       const uint32_t top_left = (*this)[y1 - 1][x1 - 1];
125 
126       sum = everything - left - top + top_left;
127       SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d",
128             everything, left, top, top_left, sum, x1, y1, x2, y2);
129     } else if (x1 > 0) {
130       // Flush against top of image.
131       // Subtract out the region to the left only.
132       const uint32_t top = (*this)[y2][x1 - 1];
133       sum = everything - top;
134       SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, top, sum);
135     } else if (y1 > 0) {
136       // Flush against left side of image.
137       // Subtract out the region above only.
138       const uint32_t left = (*this)[y1 - 1][x2];
139       sum = everything - left;
140       SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, left, sum);
141     }
142 
143     SCHECK(sum >= 0, "Negative sum!");
144 
145     return sum;
146   }
147 
148   // Returns the 2bit code associated with this region, which represents
149   // the overall gradient.
GetCode(const BoundingBox & bounding_box)150   inline Code GetCode(const BoundingBox& bounding_box) const {
151     return GetCode(bounding_box.left_, bounding_box.top_,
152                    bounding_box.right_, bounding_box.bottom_);
153   }
154 
GetCode(const int x1,const int y1,const int x2,const int y2)155   inline Code GetCode(const int x1, const int y1,
156                       const int x2, const int y2) const {
157     SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d",
158            x1, y1, x2, y2);
159 
160     // Gradient computed vertically.
161     const int box_height = (y2 - y1) / 2;
162     const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height);
163     const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2);
164     const bool vertical_code = top_sum > bottom_sum;
165 
166     // Gradient computed horizontally.
167     const int box_width = (x2 - x1) / 2;
168     const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2);
169     const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2);
170     const bool horizontal_code = left_sum > right_sum;
171 
172     const Code final_code = (vertical_code << 1) | horizontal_code;
173 
174     SCHECK(InRange(final_code, static_cast<Code>(0), static_cast<Code>(3)),
175           "Invalid code! %d", final_code);
176 
177     // Returns a value 0-3.
178     return final_code;
179   }
180 
181  private:
182   TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage);
183 };
184 
185 }  // namespace tf_tracking
186 
187 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
188