1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
18 
19 #include <stdint.h>
20 #include <memory>
21 
22 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
23 #include "tensorflow/examples/android/jni/object_tracking/image.h"
24 #include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
25 #include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
26 #include "tensorflow/examples/android/jni/object_tracking/time_log.h"
27 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
28 
29 #include "tensorflow/examples/android/jni/object_tracking/config.h"
30 
31 namespace tf_tracking {
32 
33 // Class that encapsulates all bulky processed data for a frame.
34 class ImageData {
35  public:
ImageData(const int width,const int height)36   explicit ImageData(const int width, const int height)
37       : uv_frame_width_(width << 1),
38         uv_frame_height_(height << 1),
39         timestamp_(0),
40         image_(width, height) {
41     InitPyramid(width, height);
42     ResetComputationCache();
43   }
44 
45  private:
ResetComputationCache()46   void ResetComputationCache() {
47     uv_data_computed_ = false;
48     integral_image_computed_ = false;
49     for (int i = 0; i < kNumPyramidLevels; ++i) {
50       spatial_x_computed_[i] = false;
51       spatial_y_computed_[i] = false;
52       pyramid_sqrt2_computed_[i * 2] = false;
53       pyramid_sqrt2_computed_[i * 2 + 1] = false;
54     }
55   }
56 
InitPyramid(const int width,const int height)57   void InitPyramid(const int width, const int height) {
58     int level_width = width;
59     int level_height = height;
60 
61     for (int i = 0; i < kNumPyramidLevels; ++i) {
62       pyramid_sqrt2_[i * 2] = NULL;
63       pyramid_sqrt2_[i * 2 + 1] = NULL;
64       spatial_x_[i] = NULL;
65       spatial_y_[i] = NULL;
66 
67       level_width /= 2;
68       level_height /= 2;
69     }
70 
71     // Alias the first pyramid level to image_.
72     pyramid_sqrt2_[0] = &image_;
73   }
74 
75  public:
~ImageData()76   ~ImageData() {
77     // The first pyramid level is actually an alias to image_,
78     // so make sure it doesn't get deleted here.
79     pyramid_sqrt2_[0] = NULL;
80 
81     for (int i = 0; i < kNumPyramidLevels; ++i) {
82       SAFE_DELETE(pyramid_sqrt2_[i * 2]);
83       SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]);
84       SAFE_DELETE(spatial_x_[i]);
85       SAFE_DELETE(spatial_y_[i]);
86     }
87   }
88 
SetData(const uint8_t * const new_frame,const int stride,const int64_t timestamp,const int downsample_factor)89   void SetData(const uint8_t* const new_frame, const int stride,
90                const int64_t timestamp, const int downsample_factor) {
91     SetData(new_frame, NULL, stride, timestamp, downsample_factor);
92   }
93 
SetData(const uint8_t * const new_frame,const uint8_t * const uv_frame,const int stride,const int64_t timestamp,const int downsample_factor)94   void SetData(const uint8_t* const new_frame, const uint8_t* const uv_frame,
95                const int stride, const int64_t timestamp,
96                const int downsample_factor) {
97     ResetComputationCache();
98 
99     timestamp_ = timestamp;
100 
101     TimeLog("SetData!");
102 
103     pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor);
104     pyramid_sqrt2_computed_[0] = true;
105     TimeLog("Downsampled image");
106 
107     if (uv_frame != NULL) {
108       if (u_data_.get() == NULL) {
109         u_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
110         v_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
111       }
112 
113       GetUV(uv_frame, u_data_.get(), v_data_.get());
114       uv_data_computed_ = true;
115       TimeLog("Copied UV data");
116     } else {
117       LOGV("No uv data!");
118     }
119 
120 #ifdef LOG_TIME
121     // If profiling is enabled, precompute here to make it easier to distinguish
122     // total costs.
123     Precompute();
124 #endif
125   }
126 
GetTimestamp()127   inline const uint64_t GetTimestamp() const { return timestamp_; }
128 
GetImage()129   inline const Image<uint8_t>* GetImage() const {
130     SCHECK(pyramid_sqrt2_computed_[0], "image not set!");
131     return pyramid_sqrt2_[0];
132   }
133 
GetPyramidSqrt2Level(const int level)134   const Image<uint8_t>* GetPyramidSqrt2Level(const int level) const {
135     if (!pyramid_sqrt2_computed_[level]) {
136       SCHECK(level != 0, "Level equals 0!");
137       if (level == 1) {
138         const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(0);
139         if (pyramid_sqrt2_[level] == NULL) {
140           const int new_width =
141               (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2;
142           const int new_height =
143               (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 *
144               2;
145 
146           pyramid_sqrt2_[level] = new Image<uint8_t>(new_width, new_height);
147         }
148         pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level);
149       } else {
150         const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(level - 2);
151         if (pyramid_sqrt2_[level] == NULL) {
152           pyramid_sqrt2_[level] = new Image<uint8_t>(
153               upper_level.GetWidth() / 2, upper_level.GetHeight() / 2);
154         }
155         pyramid_sqrt2_[level]->DownsampleAveraged(
156             upper_level.data(), upper_level.stride(), 2);
157       }
158       pyramid_sqrt2_computed_[level] = true;
159     }
160     return pyramid_sqrt2_[level];
161   }
162 
GetSpatialX(const int level)163   inline const Image<int32_t>* GetSpatialX(const int level) const {
164     if (!spatial_x_computed_[level]) {
165       const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
166       if (spatial_x_[level] == NULL) {
167         spatial_x_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
168       }
169       spatial_x_[level]->DerivativeX(src);
170       spatial_x_computed_[level] = true;
171     }
172     return spatial_x_[level];
173   }
174 
GetSpatialY(const int level)175   inline const Image<int32_t>* GetSpatialY(const int level) const {
176     if (!spatial_y_computed_[level]) {
177       const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
178       if (spatial_y_[level] == NULL) {
179         spatial_y_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
180       }
181       spatial_y_[level]->DerivativeY(src);
182       spatial_y_computed_[level] = true;
183     }
184     return spatial_y_[level];
185   }
186 
187   // The integral image is currently only used for object detection, so lazily
188   // initialize it on request.
GetIntegralImage()189   inline const IntegralImage* GetIntegralImage() const {
190     if (integral_image_.get() == NULL) {
191       integral_image_.reset(new IntegralImage(image_));
192     } else if (!integral_image_computed_) {
193       integral_image_->Recompute(image_);
194     }
195     integral_image_computed_ = true;
196     return integral_image_.get();
197   }
198 
GetU()199   inline const Image<uint8_t>* GetU() const {
200     SCHECK(uv_data_computed_, "UV data not provided!");
201     return u_data_.get();
202   }
203 
GetV()204   inline const Image<uint8_t>* GetV() const {
205     SCHECK(uv_data_computed_, "UV data not provided!");
206     return v_data_.get();
207   }
208 
209  private:
Precompute()210   void Precompute() {
211     // Create the smoothed pyramids.
212     for (int i = 0; i < kNumPyramidLevels * 2; i += 2) {
213       (void) GetPyramidSqrt2Level(i);
214     }
215     TimeLog("Created smoothed pyramids");
216 
217     // Create the smoothed pyramids.
218     for (int i = 1; i < kNumPyramidLevels * 2; i += 2) {
219       (void) GetPyramidSqrt2Level(i);
220     }
221     TimeLog("Created smoothed sqrt pyramids");
222 
223     // Create the spatial derivatives for frame 1.
224     for (int i = 0; i < kNumPyramidLevels; ++i) {
225       (void) GetSpatialX(i);
226       (void) GetSpatialY(i);
227     }
228     TimeLog("Created spatial derivatives");
229 
230     (void) GetIntegralImage();
231     TimeLog("Got integral image!");
232   }
233 
234   const int uv_frame_width_;
235   const int uv_frame_height_;
236 
237   int64_t timestamp_;
238 
239   Image<uint8_t> image_;
240 
241   bool uv_data_computed_;
242   std::unique_ptr<Image<uint8_t> > u_data_;
243   std::unique_ptr<Image<uint8_t> > v_data_;
244 
245   mutable bool spatial_x_computed_[kNumPyramidLevels];
246   mutable Image<int32_t>* spatial_x_[kNumPyramidLevels];
247 
248   mutable bool spatial_y_computed_[kNumPyramidLevels];
249   mutable Image<int32_t>* spatial_y_[kNumPyramidLevels];
250 
251   // Mutable so the lazy initialization can work when this class is const.
252   // Whether or not the integral image has been computed for the current image.
253   mutable bool integral_image_computed_;
254   mutable std::unique_ptr<IntegralImage> integral_image_;
255 
256   mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2];
257   mutable Image<uint8_t>* pyramid_sqrt2_[kNumPyramidLevels * 2];
258 
259   TF_DISALLOW_COPY_AND_ASSIGN(ImageData);
260 };
261 
262 }  // namespace tf_tracking
263 
264 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
265