1# Copyright 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import its.image
16import its.device
17import its.objects
18import its.caps
19import time
20import math
21from matplotlib import pylab
22import os.path
23import matplotlib
24import matplotlib.pyplot
25import json
26from PIL import Image
27import numpy
28import cv2
29import bisect
30import scipy.spatial
31import sys
32
33NAME = os.path.basename(__file__).split(".")[0]
34
35# Capture 210 VGA frames (which is 7s at 30fps)
36N = 210
37W,H = 640,480
38FEATURE_MARGIN = H * 0.20 / 2 # Only take feature points from the center 20%
39                              # so that the rotation measured have much less
40                              # of rolling shutter effect
41
42MIN_FEATURE_PTS = 30          # Minimum number of feature points required to
43                              # perform rotation analysis
44
45MAX_CAM_FRM_RANGE_SEC = 9.0   # Maximum allowed camera frame range. When this
46                              # number is significantly larger than 7 seconds,
47                              # usually system is in some busy/bad states.
48
49MIN_GYRO_SMP_RATE = 100.0     # Minimum gyro sample rate
50
51FEATURE_PARAMS = dict( maxCorners = 240,
52                       qualityLevel = 0.3,
53                       minDistance = 7,
54                       blockSize = 7 )
55
56LK_PARAMS = dict( winSize  = (15, 15),
57                  maxLevel = 2,
58                  criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
59                        10, 0.03))
60
61# Constants to convert between different time units (for clarity).
62SEC_TO_NSEC = 1000*1000*1000.0
63SEC_TO_MSEC = 1000.0
64MSEC_TO_NSEC = 1000*1000.0
65MSEC_TO_SEC = 1/1000.0
66NSEC_TO_SEC = 1/(1000*1000*1000.0)
67NSEC_TO_MSEC = 1/(1000*1000.0)
68
69# Pass/fail thresholds.
70THRESH_MAX_CORR_DIST = 0.005
71THRESH_MAX_SHIFT_MS = 1
72THRESH_MIN_ROT = 0.001
73
74# lens facing
75FACING_FRONT = 0
76FACING_BACK = 1
77FACING_EXTERNAL = 2
78
79def main():
80    """Test if image and motion sensor events are well synchronized.
81
82    The instructions for running this test are in the SensorFusion.pdf file in
83    the same directory as this test.
84
85    The command-line argument "replay" may be optionally provided. Without this
86    argument, the test will collect a new set of camera+gyro data from the
87    device and then analyze it (and it will also dump this data to files in the
88    current directory). If the "replay" argument is provided, then the script
89    will instead load the dumped data from a previous run and analyze that
90    instead. This can be helpful for developers who are digging for additional
91    information on their measurements.
92    """
93
94    # Collect or load the camera+gyro data. All gyro events as well as camera
95    # timestamps are in the "events" dictionary, and "frames" is a list of
96    # RGB images as numpy arrays.
97    if "replay" not in sys.argv:
98        events, frames = collect_data()
99    else:
100        events, frames = load_data()
101
102    # Sanity check camera timestamps are enclosed by sensor timestamps
103    # This will catch bugs where camera and gyro timestamps go completely out
104    # of sync
105    cam_times = get_cam_times(events["cam"])
106    min_cam_time = min(cam_times) * NSEC_TO_SEC
107    max_cam_time = max(cam_times) * NSEC_TO_SEC
108    gyro_times = [e["time"] for e in events["gyro"]]
109    min_gyro_time = min(gyro_times) * NSEC_TO_SEC
110    max_gyro_time = max(gyro_times) * NSEC_TO_SEC
111    if not (min_cam_time > min_gyro_time and max_cam_time < max_gyro_time):
112        print "Test failed: camera timestamps [%f,%f] " \
113              "are not enclosed by gyro timestamps [%f, %f]" % (
114            min_cam_time, max_cam_time, min_gyro_time, max_gyro_time)
115        assert(0)
116
117    cam_frame_range = max_cam_time - min_cam_time
118    gyro_time_range = max_gyro_time - min_gyro_time
119    gyro_smp_per_sec = len(gyro_times) / gyro_time_range
120    print "Camera frame range", max_cam_time - min_cam_time
121    print "Gyro samples per second", gyro_smp_per_sec
122    assert(cam_frame_range < MAX_CAM_FRM_RANGE_SEC)
123    assert(gyro_smp_per_sec > MIN_GYRO_SMP_RATE)
124
125    # Compute the camera rotation displacements (rad) between each pair of
126    # adjacent frames.
127    cam_rots = get_cam_rotations(frames, events["facing"])
128    if max(abs(cam_rots)) < THRESH_MIN_ROT:
129        print "Device wasn't moved enough"
130        assert(0)
131
132    # Find the best offset (time-shift) to align the gyro and camera motion
133    # traces; this function integrates the shifted gyro data between camera
134    # samples for a range of candidate shift values, and returns the shift that
135    # result in the best correlation.
136    offset = get_best_alignment_offset(cam_times, cam_rots, events["gyro"])
137
138    # Plot the camera and gyro traces after applying the best shift.
139    cam_times = cam_times + offset*SEC_TO_NSEC
140    gyro_rots = get_gyro_rotations(events["gyro"], cam_times)
141    plot_rotations(cam_rots, gyro_rots)
142
143    # Pass/fail based on the offset and also the correlation distance.
144    dist = scipy.spatial.distance.correlation(cam_rots, gyro_rots)
145    print "Best correlation of %f at shift of %.2fms"%(dist, offset*SEC_TO_MSEC)
146    assert(dist < THRESH_MAX_CORR_DIST)
147    assert(abs(offset) < THRESH_MAX_SHIFT_MS*MSEC_TO_SEC)
148
149def get_best_alignment_offset(cam_times, cam_rots, gyro_events):
150    """Find the best offset to align the camera and gyro traces.
151
152    Uses a correlation distance metric between the curves, where a smaller
153    value means that the curves are better-correlated.
154
155    Args:
156        cam_times: Array of N camera times, one for each frame.
157        cam_rots: Array of N-1 camera rotation displacements (rad).
158        gyro_events: List of gyro event objects.
159
160    Returns:
161        Offset (seconds) of the best alignment.
162    """
163    # Measure the corr. dist. over a shift of up to +/- 50ms (0.5ms step size).
164    # Get the shift corresponding to the best (lowest) score.
165    candidates = numpy.arange(-50,50.5,0.5).tolist()
166    dists = []
167    for shift in candidates:
168        times = cam_times + shift*MSEC_TO_NSEC
169        gyro_rots = get_gyro_rotations(gyro_events, times)
170        dists.append(scipy.spatial.distance.correlation(cam_rots, gyro_rots))
171    best_corr_dist = min(dists)
172    best_shift = candidates[dists.index(best_corr_dist)]
173
174    print "Best shift without fitting is ", best_shift, "ms"
175
176    # Fit a curve to the corr. dist. data to measure the minima more
177    # accurately, by looking at the correlation distances within a range of
178    # +/- 10ms from the measured best score; note that this will use fewer
179    # than the full +/- 10 range for the curve fit if the measured score
180    # (which is used as the center of the fit) is within 10ms of the edge of
181    # the +/- 50ms candidate range.
182    i = dists.index(best_corr_dist)
183    candidates = candidates[i-20:i+21]
184    dists = dists[i-20:i+21]
185    a,b,c = numpy.polyfit(candidates, dists, 2)
186    exact_best_shift = -b/(2*a)
187    if abs(best_shift - exact_best_shift) > 2.0 or a <= 0 or c <= 0:
188        print "Test failed; bad fit to time-shift curve"
189        print "best_shift %f, exact_best_shift %f, a %f, c %f" % (best_shift,
190                exact_best_shift, a, c)
191        assert(0)
192
193    xfit = numpy.arange(candidates[0], candidates[-1], 0.05).tolist()
194    yfit = [a*x*x+b*x+c for x in xfit]
195    fig = matplotlib.pyplot.figure()
196    pylab.plot(candidates, dists, 'r', label="data")
197    pylab.plot(xfit, yfit, 'b', label="fit")
198    pylab.plot([exact_best_shift+x for x in [-0.1,0,0.1]], [0,0.01,0], 'b')
199    pylab.xlabel("Relative horizontal shift between curves (ms)")
200    pylab.ylabel("Correlation distance")
201    pylab.legend()
202    matplotlib.pyplot.savefig("%s_plot_shifts.png" % (NAME))
203
204    return exact_best_shift * MSEC_TO_SEC
205
206def plot_rotations(cam_rots, gyro_rots):
207    """Save a plot of the camera vs. gyro rotational measurements.
208
209    Args:
210        cam_rots: Array of N-1 camera rotation measurements (rad).
211        gyro_rots: Array of N-1 gyro rotation measurements (rad).
212    """
213    # For the plot, scale the rotations to be in degrees.
214    scale = 360/(2*math.pi)
215    fig = matplotlib.pyplot.figure()
216    cam_rots = cam_rots * scale
217    gyro_rots = gyro_rots * scale
218    pylab.plot(range(len(cam_rots)), cam_rots, 'r', label="camera")
219    pylab.plot(range(len(gyro_rots)), gyro_rots, 'b', label="gyro")
220    pylab.legend()
221    pylab.xlabel("Camera frame number")
222    pylab.ylabel("Angular displacement between adjacent camera frames (deg)")
223    pylab.xlim([0, len(cam_rots)])
224    matplotlib.pyplot.savefig("%s_plot.png" % (NAME))
225
226def get_gyro_rotations(gyro_events, cam_times):
227    """Get the rotation values of the gyro.
228
229    Integrates the gyro data between each camera frame to compute an angular
230    displacement.
231
232    Args:
233        gyro_events: List of gyro event objects.
234        cam_times: Array of N camera times, one for each frame.
235
236    Returns:
237        Array of N-1 gyro rotation measurements (rad).
238    """
239    all_times = numpy.array([e["time"] for e in gyro_events])
240    all_rots = numpy.array([e["z"] for e in gyro_events])
241    gyro_rots = []
242    # Integrate the gyro data between each pair of camera frame times.
243    for icam in range(len(cam_times)-1):
244        # Get the window of gyro samples within the current pair of frames.
245        tcam0 = cam_times[icam]
246        tcam1 = cam_times[icam+1]
247        igyrowindow0 = bisect.bisect(all_times, tcam0)
248        igyrowindow1 = bisect.bisect(all_times, tcam1)
249        sgyro = 0
250        # Integrate samples within the window.
251        for igyro in range(igyrowindow0, igyrowindow1):
252            vgyro = all_rots[igyro+1]
253            tgyro0 = all_times[igyro]
254            tgyro1 = all_times[igyro+1]
255            deltatgyro = (tgyro1 - tgyro0) * NSEC_TO_SEC
256            sgyro += vgyro * deltatgyro
257        # Handle the fractional intervals at the sides of the window.
258        for side,igyro in enumerate([igyrowindow0-1, igyrowindow1]):
259            vgyro = all_rots[igyro+1]
260            tgyro0 = all_times[igyro]
261            tgyro1 = all_times[igyro+1]
262            deltatgyro = (tgyro1 - tgyro0) * NSEC_TO_SEC
263            if side == 0:
264                f = (tcam0 - tgyro0) / (tgyro1 - tgyro0)
265                sgyro += vgyro * deltatgyro * (1.0 - f)
266            else:
267                f = (tcam1 - tgyro0) / (tgyro1 - tgyro0)
268                sgyro += vgyro * deltatgyro * f
269        gyro_rots.append(sgyro)
270    gyro_rots = numpy.array(gyro_rots)
271    return gyro_rots
272
273def get_cam_rotations(frames, facing):
274    """Get the rotations of the camera between each pair of frames.
275
276    Takes N frames and returns N-1 angular displacements corresponding to the
277    rotations between adjacent pairs of frames, in radians.
278
279    Args:
280        frames: List of N images (as RGB numpy arrays).
281
282    Returns:
283        Array of N-1 camera rotation measurements (rad).
284    """
285    gframes = []
286    for frame in frames:
287        frame = (frame * 255.0).astype(numpy.uint8)
288        gframes.append(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY))
289    rots = []
290    ymin = H/2 - FEATURE_MARGIN
291    ymax = H/2 + FEATURE_MARGIN
292    for i in range(1,len(gframes)):
293        gframe0 = gframes[i-1]
294        gframe1 = gframes[i]
295        p0 = cv2.goodFeaturesToTrack(gframe0, mask=None, **FEATURE_PARAMS)
296        # p0's shape is N * 1 * 2
297        mask = (p0[:,0,1] >= ymin) & (p0[:,0,1] <= ymax)
298        p0_filtered = p0[mask]
299        if len(p0_filtered) < MIN_FEATURE_PTS:
300            print "Not enough feature points in frame", i
301            print "Need at least %d features, got %d" % (
302                    MIN_FEATURE_PTS, len(p0_filtered))
303            assert(0)
304        p1,st,_ = cv2.calcOpticalFlowPyrLK(gframe0, gframe1, p0_filtered, None,
305                **LK_PARAMS)
306        tform = procrustes_rotation(p0_filtered[st==1], p1[st==1])
307        if facing == FACING_BACK:
308            rot = -math.atan2(tform[0, 1], tform[0, 0])
309        elif facing == FACING_FRONT:
310            rot = math.atan2(tform[0, 1], tform[0, 0])
311        else:
312            print "Unknown lens facing", facing
313            assert(0)
314        rots.append(rot)
315        if i == 1:
316            # Save a debug visualization of the features that are being
317            # tracked in the first frame.
318            frame = frames[i]
319            for x,y in p0_filtered[st==1]:
320                cv2.circle(frame, (x,y), 3, (100,100,255), -1)
321            its.image.write_image(frame, "%s_features.png"%(NAME))
322    return numpy.array(rots)
323
324def get_cam_times(cam_events):
325    """Get the camera frame times.
326
327    Args:
328        cam_events: List of (start_exposure, exposure_time, readout_duration)
329            tuples, one per captured frame, with times in nanoseconds.
330
331    Returns:
332        frame_times: Array of N times, one corresponding to the "middle" of
333            the exposure of each frame.
334    """
335    # Assign a time to each frame that assumes that the image is instantly
336    # captured in the middle of its exposure.
337    starts = numpy.array([start for start,exptime,readout in cam_events])
338    exptimes = numpy.array([exptime for start,exptime,readout in cam_events])
339    readouts = numpy.array([readout for start,exptime,readout in cam_events])
340    frame_times = starts + (exptimes + readouts) / 2.0
341    return frame_times
342
343def load_data():
344    """Load a set of previously captured data.
345
346    Returns:
347        events: Dictionary containing all gyro events and cam timestamps.
348        frames: List of RGB images as numpy arrays.
349    """
350    with open("%s_events.txt"%(NAME), "r") as f:
351        events = json.loads(f.read())
352    n = len(events["cam"])
353    frames = []
354    for i in range(n):
355        img = Image.open("%s_frame%03d.png"%(NAME,i))
356        w,h = img.size[0:2]
357        frames.append(numpy.array(img).reshape(h,w,3) / 255.0)
358    return events, frames
359
360def collect_data():
361    """Capture a new set of data from the device.
362
363    Captures both motion data and camera frames, while the user is moving
364    the device in a proscribed manner.
365
366    Returns:
367        events: Dictionary containing all gyro events and cam timestamps.
368        frames: List of RGB images as numpy arrays.
369    """
370    with its.device.ItsSession() as cam:
371        props = cam.get_camera_properties()
372        its.caps.skip_unless(its.caps.sensor_fusion(props) and
373                             its.caps.manual_sensor(props) and
374                             props['android.lens.facing'] != FACING_EXTERNAL)
375
376        print "Starting sensor event collection"
377        cam.start_sensor_events()
378
379        # Sleep a while for gyro events to stabilize.
380        time.sleep(0.5)
381
382        # Capture the frames. OIS is disabled for manual captures.
383        facing = props['android.lens.facing']
384        if facing != FACING_FRONT and facing != FACING_BACK:
385            print "Unknown lens facing", facing
386            assert(0)
387
388        fmt = {"format":"yuv", "width":W, "height":H}
389        s,e,_,_,_ = cam.do_3a(get_results=True, do_af=False)
390        req = its.objects.manual_capture_request(s, e)
391        fps = 30
392        req["android.control.aeTargetFpsRange"] = [fps, fps]
393        print "Capturing %dx%d with sens. %d, exp. time %.1fms" % (
394                W, H, s, e*NSEC_TO_MSEC)
395        caps = cam.do_capture([req]*N, fmt)
396
397        # Get the gyro events.
398        print "Reading out sensor events"
399        gyro = cam.get_sensor_events()["gyro"]
400        print "Number of gyro samples", len(gyro)
401
402        # Combine the events into a single structure.
403        print "Dumping event data"
404        starts = [c["metadata"]["android.sensor.timestamp"] for c in caps]
405        exptimes = [c["metadata"]["android.sensor.exposureTime"] for c in caps]
406        readouts = [c["metadata"]["android.sensor.rollingShutterSkew"]
407                    for c in caps]
408        events = {"gyro": gyro, "cam": zip(starts,exptimes,readouts),
409                  "facing": facing}
410        with open("%s_events.txt"%(NAME), "w") as f:
411            f.write(json.dumps(events))
412
413        # Convert the frames to RGB.
414        print "Dumping frames"
415        frames = []
416        for i,c in enumerate(caps):
417            img = its.image.convert_capture_to_rgb_image(c)
418            frames.append(img)
419            its.image.write_image(img, "%s_frame%03d.png"%(NAME,i))
420
421        return events, frames
422
423def procrustes_rotation(X, Y):
424    """
425    Procrustes analysis determines a linear transformation (translation,
426    reflection, orthogonal rotation and scaling) of the points in Y to best
427    conform them to the points in matrix X, using the sum of squared errors
428    as the goodness of fit criterion.
429
430    Args:
431        X, Y: Matrices of target and input coordinates.
432
433    Returns:
434        The rotation component of the transformation that maps X to Y.
435    """
436    X0 = (X-X.mean(0)) / numpy.sqrt(((X-X.mean(0))**2.0).sum())
437    Y0 = (Y-Y.mean(0)) / numpy.sqrt(((Y-Y.mean(0))**2.0).sum())
438    U,s,Vt = numpy.linalg.svd(numpy.dot(X0.T, Y0),full_matrices=False)
439    return numpy.dot(Vt.T, U.T)
440
441if __name__ == '__main__':
442    main()
443