1"""Experimentally determines a camera's rolling shutter skew.
2
3See the accompanying PDF for instructions on how to use this test.
4"""
5from __future__ import division
6from __future__ import print_function
7
8import argparse
9import glob
10import math
11import os
12import sys
13import tempfile
14
15import cv2
16import its.caps
17import its.device
18import its.image
19import its.objects
20import numpy as np
21
DEBUG = False

# Values compared against the 'android.lens.facing' camera property.
FACING_FRONT = 0
FACING_BACK = 1
FACING_EXTERNAL = 2

# Capture settings used when the command line does not override them.
FPS = 30
WIDTH = 640
HEIGHT = 480
TEST_LENGTH = 1  # seconds

# Maximum gap allowed between a circle and at least one other circle of its
# cluster.  This is stored as a fraction of the frame height (50px at the
# default 480px height); main() rescales it to pixels once the actual frame
# height is known.
CLUSTER_DISTANCE = 50.0 / HEIGHT
# Minimum fraction (0-1) of all contours that the largest cluster must contain
# for the frame to be used in the computation.
MAJORITY_THRESHOLD = 0.7

# Acceptable range for the slope of the line fitted through the cluster.
SLOPE_MIN_THRESHOLD = 0.5
SLOPE_MAX_THRESHOLD = 1.5

# Unit conversion factors.
SEC_TO_NSEC = 1e9
MSEC_TO_NSEC = 1e6
NSEC_TO_MSEC = 1.0 / MSEC_TO_NSEC
50
51
class RollingShutterArgumentParser(object):
    """Parses command line arguments for the rolling shutter test.

    Besides the argparse-style flags, the CameraITS convention of selecting a
    device with a bare ``device=<DEVICE ID>`` argument is accepted alongside
    ``--device_id``.
    """

    def __init__(self):
        self.__parser = argparse.ArgumentParser(
                description='Run rolling shutter test')
        self.__parser.add_argument(
                '-d', '--debug',
                action='store_true',
                help='print and write data useful for debugging')
        self.__parser.add_argument(
                '-f', '--fps',
                type=int,
                help='FPS to capture with during the test (defaults to 30)')
        self.__parser.add_argument(
                '-i', '--img_size',
                help=('comma-separated dimensions of captured images (defaults '
                      'to 640x480). Example: --img_size=<width>,<height>'))
        self.__parser.add_argument(
                '-l', '--led_time',
                type=float,
                required=True,
                help=('how many milliseconds each column of the LED array is '
                      'lit for'))
        self.__parser.add_argument(
                '-p', '--panel_distance',
                type=float,
                help='how far the LED panel is from the camera (in meters)')
        self.__parser.add_argument(
                '-r', '--read_dir',
                help=('read existing test data from specified directory.  If '
                      'not specified, new test data is collected from the '
                      'device\'s camera)'))
        self.__parser.add_argument(
                '--device_id',
                help=('device ID for device being tested (can also use '
                      '\'device=<DEVICE ID>\')'))
        self.__parser.add_argument(
                '-t', '--test_length',
                type=int,
                help=('how many seconds the test should run for (defaults to 1 '
                      'second)'))
        self.__parser.add_argument(
                '-o', '--debug_dir',
                help=('write debugging information in a folder in the '
                      'specified directory.  Otherwise, the system\'s default '
                      'location for temporary folders is used.  --debug must '
                      'be specified along with this argument.'))

    def parse_args(self):
        """Returns object containing parsed values from the command line.

        Arguments of the form ``device=<ID>`` are hidden from argparse and left
        in ``sys.argv``.  When ``--device_id`` is given, a matching
        ``device=<ID>`` argument is appended to ``sys.argv`` so that its.device
        can pick it up later on.
        """
        # Don't show argparse the 'device' flag, since it's in a different
        # format than the others (to maintain CameraITS conventions) and it will
        # complain.  Only drop arguments that *are* that flag: a substring test
        # would also swallow values such as '-o /tmp/device=1'.
        filtered_args = [arg for arg in sys.argv[1:]
                         if not arg.startswith('device=')]
        args = self.__parser.parse_args(filtered_args)
        if args.device_id:
            # If argparse format is used, convert it to a format its.device can
            # use later on.
            sys.argv.append('device=%s' % args.device_id)
        return args
113
114
def main():
    """Runs the rolling shutter test from the command line.

    Collects new frames from the device (or loads them with --read_dir),
    measures the rolling shutter skew in them, and prints the device-reported
    skew next to the measured average.  Exits with status 1 on a usage error
    or when there are no frames to analyze.
    """
    global DEBUG
    global CLUSTER_DISTANCE

    parser = RollingShutterArgumentParser()
    args = parser.parse_args()

    DEBUG = args.debug
    if not DEBUG and args.debug_dir:
        print('argument --debug_dir requires --debug', file=sys.stderr)
        # A bare sys.exit() reports success (status 0); signal the error.
        sys.exit(1)

    if args.read_dir is None:
        # Collect new data.
        raw_caps, reported_skew = collect_data(args)
        frames = [its.image.convert_capture_to_rgb_image(c) for c in raw_caps]
    else:
        # Load existing data.
        frames, reported_skew = load_data(args.read_dir)

    if not frames:
        # e.g. --read_dir points at a directory without any .png frames.
        print('No frames to analyze.', file=sys.stderr)
        sys.exit(1)

    # Make the cluster distance relative to the height of the image.
    (frame_h, _, _) = frames[0].shape
    CLUSTER_DISTANCE = frame_h * CLUSTER_DISTANCE
    debug_print('Setting cluster distance to %spx.' % CLUSTER_DISTANCE)

    if DEBUG:
        debug_dir = setup_debug_dir(args.debug_dir)
        # Write raw frames.
        for i, img in enumerate(frames):
            its.image.write_image(img, '%s/raw/%03d.png' % (debug_dir, i))
    else:
        debug_dir = None

    avg_shutter_skew, num_frames_used = find_average_shutter_skew(
            frames, args.led_time, debug_dir)
    if debug_dir:
        # Write the reported skew with the raw images, so the directory can also
        # be used to read from.
        with open(debug_dir + '/raw/reported_skew.txt', 'w') as f:
            f.write('%sms\n' % reported_skew)

    if avg_shutter_skew is None:
        print('Could not find usable frames.')
    else:
        print('Device reported shutter skew of %sms.' % reported_skew)
        print('Measured shutter skew is %sms (averaged over %s frames).' %
              (avg_shutter_skew, num_frames_used))
162
163
def collect_data(args):
    """Capture a new set of frames from the device's camera.

    Args:
        args: Parsed command line arguments.

    Returns:
        A tuple of (the raw captures returned by ``cam.do_capture``, the
        rolling shutter skew reported by the device, in milliseconds).

    Raises:
        AssertionError: if the lens facing is neither front nor back, or if
            the captured frames report differing rolling shutter skews.
    """
    fps = args.fps if args.fps else FPS
    if args.img_size:
        # Accept both the documented '<width>,<height>' form and the
        # '<width>x<height>' form used in the flag's help text ('640x480').
        w, h = map(int, args.img_size.lower().replace('x', ',').split(','))
    else:
        w, h = WIDTH, HEIGHT
    test_length = args.test_length if args.test_length else TEST_LENGTH

    with its.device.ItsSession() as cam:
        props = cam.get_camera_properties()
        its.caps.skip_unless(its.caps.manual_sensor(props))
        facing = props['android.lens.facing']
        if facing != FACING_FRONT and facing != FACING_BACK:
            print('Unknown lens facing %s' % facing)
            assert 0

        fmt = {'format': 'yuv', 'width': w, 'height': h}
        s, e, _, _, _ = cam.do_3a(get_results=True, do_af=False)
        req = its.objects.manual_capture_request(s, e)
        req['android.control.aeTargetFpsRange'] = [fps, fps]

        # Convert from milliseconds to nanoseconds.  We only want enough
        # exposure time to saturate approximately one column.  The sensor
        # timing fields are integer nanosecond counts, so truncate to int just
        # like frameDuration below.
        exposure_time = int((args.led_time / 2.0) * MSEC_TO_NSEC)
        print('Using exposure time of %sns.' % exposure_time)
        req['android.sensor.exposureTime'] = exposure_time
        req['android.sensor.frameDuration'] = int(SEC_TO_NSEC / fps)

        if args.panel_distance is not None:
            # Convert meters to diopters and use that for the focus distance.
            req['android.lens.focusDistance'] = 1 / args.panel_distance
        print('Starting capture')
        raw_caps = cam.do_capture([req] * fps * test_length, fmt)
        print('Finished capture')

        # Convert from nanoseconds to milliseconds.
        shutter_skews = {c['metadata']['android.sensor.rollingShutterSkew'] *
                         NSEC_TO_MSEC for c in raw_caps}
        # All frames should have same rolling shutter skew.
        assert len(shutter_skews) == 1
        shutter_skew = list(shutter_skews)[0]

        return raw_caps, shutter_skew
215
216
def load_data(dir_name):
    """Reads camera frame data from an existing directory.

    The directory is expected to look like the ``<debug_dir>/raw`` folder that
    main() writes: PNG frames plus a ``reported_skew.txt`` file holding a
    line such as ``12.3ms``.

    Args:
        dir_name: Name of the directory to read data from.

    Returns:
        A tuple of (list of RGB images as numpy arrays, in file-name order;
        the reported shutter skew in milliseconds as a string, without the
        'ms' suffix).
    """
    frame_files = sorted(glob.glob(os.path.join(dir_name, '*.png')))
    frames = [its.image.load_rgb_image(frame_file) for frame_file in frame_files]
    with open(os.path.join(dir_name, 'reported_skew.txt'), 'r') as f:
        # Drop the line terminator first; slicing two characters off the raw
        # line would remove 's\n' and leave a dangling 'm'.
        reported_skew = f.readline().strip()
    if reported_skew.endswith('ms'):
        reported_skew = reported_skew[:-len('ms')]
    return frames, reported_skew
233
234
def find_average_shutter_skew(frames, led_time, debug_dir=None):
    """Computes a confidence-weighted mean shutter skew over many frames.

    Every frame is measured with calculate_shutter_skew().  Frames for which no
    measurement is possible are left out of the mean.  The weighted mean slope
    of the fitted lines is compared against SLOPE_MIN_THRESHOLD and
    SLOPE_MAX_THRESHOLD, and a hint is printed to stderr when it falls outside
    that range.

    Args:
        frames:    List of RGB images from the camera being tested.
        led_time:  How long a single LED column is lit for (in milliseconds).
        debug_dir: (optional) Directory to write debugging information to.

    Returns:
        A (mean shutter skew in milliseconds, number of frames used) pair, or
        (None, None) when no frame produced a measurement.
    """
    skew_total = 0.0
    slope_total = 0.0
    total_confidence = 0.0
    frames_used = 0

    for index, frame in enumerate(frames):
        debug_print('------------------------')
        debug_print('| PROCESSING FRAME %03d |' % index)
        debug_print('------------------------')
        skew, confidence, slope = calculate_shutter_skew(
                frame, led_time, index, debug_dir=debug_dir)
        if skew is None:
            debug_print('Skipped frame.')
            continue
        debug_print('Shutter skew is %sms (confidence: %s).' %
                    (skew, confidence))
        # Each measurement counts in proportion to its confidence.
        skew_total += skew * confidence
        slope_total += slope * confidence
        total_confidence += confidence
        frames_used += 1

    debug_print('\n')
    if frames_used == 0:
        return None, None

    mean_skew = skew_total / total_confidence
    mean_slope = slope_total / total_confidence
    template = ('The average slope of the fitted line was too %s '
                'to get an accurate measurement (slope was %s).  '
                'Try making the LED panel %s.')
    if mean_slope < SLOPE_MIN_THRESHOLD:
        print(template % ('flat', mean_slope, 'slower'), file=sys.stderr)
    elif mean_slope > SLOPE_MAX_THRESHOLD:
        print(template % ('steep', mean_slope, 'faster'), file=sys.stderr)
    return mean_skew, frames_used
288
289
def calculate_shutter_skew(frame, led_time, frame_num=None, debug_dir=None):
    """Calculates the shutter skew of the camera being used for this test.

    A line is fitted through the centres of the largest cluster of LED blobs.
    The horizontal distance that line covers over the full frame height is
    converted into time using how long the lit column took to sweep across the
    cluster's bounding box (led_time per LED column spanned).

    Args:
        frame:     A single RGB image captured by the camera being tested.
        led_time:  How long a single LED column is lit for (in milliseconds).
        frame_num: (optional) Number of the given frame.
        debug_dir: (optional) Directory to write debugging information to.

    Returns:
        The shutter skew (in milliseconds), the confidence in the accuracy of
        the measurement (useful for weighting averages), and the slope of the
        fitted line.  All three are None when the frame cannot be measured: no
        majority cluster, a cluster of a single circle, or a fitted line that
        is exactly horizontal or vertical.
    """
    contours, scratch_img, contour_img, mono_img = find_contours(frame.copy())
    if debug_dir is not None:
        cv2.imwrite('%s/contour/%03d.png' % (debug_dir, frame_num), contour_img)
        cv2.imwrite('%s/mono/%03d.png' % (debug_dir, frame_num), mono_img)

    largest_cluster, cluster_percentage = find_largest_cluster(contours,
                                                               scratch_img)
    if largest_cluster is None:
        debug_print('No majority cluster found.')
        return None, None, None
    elif len(largest_cluster) <= 1:
        debug_print('Majority cluster was too small.')
        return None, None, None
    debug_print('%s points in the largest cluster.' % len(largest_cluster))

    np_cluster = np.array([[c.x, c.y] for c in largest_cluster])
    # The cv2.cv submodule (home of CV_DIST_L2) only exists in OpenCV 2.x;
    # newer versions expose the constant as cv2.DIST_L2.
    if hasattr(cv2, 'DIST_L2'):
        dist_type = cv2.DIST_L2
    else:
        dist_type = cv2.cv.CV_DIST_L2
    [vx], [vy], [x0], [y0] = cv2.fitLine(np_cluster, dist_type,
                                         0, 0.01, 0.01)
    if vx == 0 or vy == 0:
        # A vertical line has an infinite slope and a horizontal one leads to
        # division by zero below; either would corrupt the weighted average.
        debug_print('Fitted line is axis-aligned; cannot compute skew.')
        return None, None, None
    slope = vy / vx
    debug_print('Slope is %s.' % slope)
    (frame_h, frame_w, _) = frame.shape
    # Draw line onto scratch frame.
    pt1 = tuple(map(int, (x0 - vx * 1000, y0 - vy * 1000)))
    pt2 = tuple(map(int, (x0 + vx * 1000, y0 + vy * 1000)))
    cv2.line(scratch_img, pt1, pt2, (0, 255, 255), thickness=3)

    # We only need the width of the cluster.
    _, _, cluster_w, _ = find_cluster_bounding_rect(largest_cluster,
                                                    scratch_img)

    num_columns = find_num_columns_spanned(largest_cluster)
    debug_print('%s columns spanned by cluster.' % num_columns)
    # How long it takes for a column to move from the left of the bounding
    # rectangle to the right.
    left_to_right_time = led_time * num_columns
    milliseconds_per_x_pixel = left_to_right_time / cluster_w
    # The distance between the line's intersection at the top of the frame and
    # the intersection at the bottom.
    x_range = frame_h / slope
    shutter_skew = milliseconds_per_x_pixel * x_range
    # If the aspect ratio is different from 4:3 (the aspect ratio of the actual
    # sensor), we need to correct, because it will be cropped.
    shutter_skew *= (float(frame_w) / float(frame_h)) / (4.0 / 3.0)

    if debug_dir is not None:
        cv2.imwrite('%s/scratch/%03d.png' % (debug_dir, frame_num),
                    scratch_img)

    return shutter_skew, cluster_percentage, slope
353
354
def find_contours(img):
    """Finds contours in the given image.

    Args:
        img: Image in Android camera RGB format (presumably floats in [0, 1],
             since they are scaled by 255 into bytes -- TODO confirm).  The
             array is not modified.

    Returns:
        OpenCV-formatted contours, the original image in OpenCV format, a
        thresholded image with the contours drawn on (None unless DEBUG is
        set), and the red channel of the image, used as its grayscale version.
    """
    # Convert to format OpenCV can work with (BGR ordering with byte-ranged
    # values).  Scale into a new array rather than in place so the caller's
    # image is left untouched.
    img = (img * 255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Since the LED colors for the panel we're using are red, we can get better
    # contours for the LEDs if we ignore the green and blue channels.  This also
    # makes it so we don't pick up the blue control screen of the LED panel.
    red_img = img[:, :, 2]
    _, thresh = cv2.threshold(red_img, 0, 255, cv2.THRESH_BINARY +
                              cv2.THRESH_OTSU)

    # Remove noise before finding contours by eroding the thresholded image and
    # then re-dilating it.  The size of the kernel represents how many
    # neighboring pixels to consider for the result of a single pixel.
    kernel = np.ones((3, 3), np.uint8)
    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)

    if DEBUG:
        # Need to convert it back to BGR if we want to draw colored contours.
        contour_img = cv2.cvtColor(opening, cv2.COLOR_GRAY2BGR)
    else:
        contour_img = None

    # The legacy cv2.cv submodule (home of the CV_* constants) only exists in
    # OpenCV 2.x, so prefer the top-level names that newer versions provide.
    if hasattr(cv2, 'RETR_EXTERNAL'):
        retrieval_mode = cv2.RETR_EXTERNAL
        approx_method = cv2.CHAIN_APPROX_NONE
    else:
        retrieval_mode = cv2.cv.CV_RETR_EXTERNAL
        approx_method = cv2.cv.CV_CHAIN_APPROX_NONE
    # findContours returns (contours, hierarchy) in OpenCV 2.x and 4.x but
    # (image, contours, hierarchy) in 3.x; the contours are always the
    # second-to-last element.
    contours = cv2.findContours(opening, retrieval_mode, approx_method)[-2]
    if DEBUG:
        cv2.drawContours(contour_img, contours, -1, (0, 0, 255), thickness=2)
    return contours, img, contour_img, red_img
400
401
def convert_to_circles(contours):
    """Approximates every OpenCV contour by a circle.

    Args:
        contours: Contours generated by OpenCV.

    Returns:
        A list with one circle object per contour, in the same order.  Each
        circle exposes integer centre coordinates ``x``/``y``, a float radius
        ``r``, and the ``distance_to``/``intersects`` helpers.
    """

    class Circle(object):
        """A circle fitted to one contour: centre (x, y) and radius r."""

        def __init__(self, contour):
            xs = contour[:, 0, 0]
            ys = contour[:, 0, 1]
            self.x = int(np.mean(xs))
            self.y = int(np.mean(ys))
            # Half the extent along each axis, averaged over both axes, gives
            # the approximate radius of the contour.
            half_width = (np.max(xs) - np.min(xs)) / 2.0
            half_height = (np.max(ys) - np.min(ys)) / 2.0
            self.r = (half_width + half_height) / 2.0
            assert self.r > 0.0

        def distance_to(self, other):
            # Gap between the two circumferences (negative when they overlap).
            centre_gap = math.sqrt((other.x - self.x)**2 +
                                   (other.y - self.y)**2)
            return centre_gap - self.r - other.r

        def intersects(self, other):
            gap = self.distance_to(other)
            return gap <= 0.0

    return [Circle(contour) for contour in contours]
434
435
def find_largest_cluster(contours, frame):
    """Picks the cluster holding the most contours, if it is a clear majority.

    Args:
        contours: Contours generated by OpenCV.
        frame:    For drawing debugging information onto.

    Returns:
        The biggest cluster and the fraction of all contours it contains, or
        (None, None) when there are no clusters or the biggest one holds less
        than MAJORITY_THRESHOLD of the contours.
    """
    clusters = proximity_clusters(contours)
    if not clusters:
        # Nothing was found to group.
        return None, None

    biggest = max(clusters, key=len)
    share = len(biggest) / len(contours)
    if share < MAJORITY_THRESHOLD:
        return None, None

    if DEBUG:
        # Outline the winning cluster on the scratch frame.
        green = (0, 255, 0)
        for circle in biggest:
            centre = (int(circle.x), int(circle.y))
            cv2.circle(frame, centre, int(circle.r), green, thickness=2)

    return biggest, share
465
466
def proximity_clusters(contours):
    """Sorts the given contours into groups by distance.

    Converts every given contour to a circle and clusters by adding a circle to
    a cluster only if it is close to at least one other circle in the cluster.
    Two circles count as close when ``distance_to`` between them is below the
    module-level CLUSTER_DISTANCE.

    TODO: Make algorithm faster (currently O(n**2)).

    Args:
        contours: Contours generated by OpenCV.

    Returns:
        The clusters as ``dict.values()`` of an internal dict keyed by each
        cluster's root index (a list under Python 2, a view under Python 3).
        Each cluster is a list of the circles contained in the cluster.
    """
    circles = convert_to_circles(contours)

    # Use disjoint-set data structure to store assignments.  Start every point
    # in their own cluster.
    # Encoding: a negative entry marks a root and holds minus the size of its
    # cluster; a non-negative entry is the index of the parent circle.
    cluster_assignments = [-1 for i in range(len(circles))]

    def get_canonical_index(i):
        # Returns the index of the root of the cluster containing circle i.
        if cluster_assignments[i] >= 0:
            index = get_canonical_index(cluster_assignments[i])
            # Collapse tree for better runtime.
            cluster_assignments[i] = index
            return index
        else:
            return i

    def get_cluster_size(i):
        # Number of circles in the cluster containing circle i.
        return -cluster_assignments[get_canonical_index(i)]

    for i, curr in enumerate(circles):
        # Indices of every other circle closer than CLUSTER_DISTANCE to curr.
        close_circles = [j for j, p in enumerate(circles) if i != j and
                         curr.distance_to(p) < CLUSTER_DISTANCE]
        if close_circles:
            # Note: largest_cluster is an index into cluster_assignments.
            # NOTE(review): despite the name, min() selects the neighbor whose
            # cluster is the *smallest*, so the first branch below is taken
            # only when every neighboring cluster is strictly larger than ours.
            largest_cluster = min(close_circles, key=get_cluster_size)
            largest_size = get_cluster_size(largest_cluster)
            curr_index = get_canonical_index(i)
            curr_size = get_cluster_size(i)
            if largest_size > curr_size:
                # largest_cluster is larger than us.
                target_index = get_canonical_index(largest_cluster)
                # Add our cluster size to the bigger one.
                cluster_assignments[target_index] -= curr_size
                # Reroute our group to the bigger one.
                cluster_assignments[curr_index] = target_index
            else:
                # We're the largest (or equal to the largest) cluster.  Reroute
                # all groups to us.
                for j in close_circles:
                    smaller_size = get_cluster_size(j)
                    smaller_index = get_canonical_index(j)
                    if smaller_index != curr_index:
                        # We only want to modify clusters that aren't already in
                        # the current one.

                        # Add the smaller cluster's size to ours.
                        cluster_assignments[curr_index] -= smaller_size
                        # Reroute their group to us.
                        cluster_assignments[smaller_index] = curr_index

    # Convert assignments list into list of clusters.
    # Circles are grouped by the root index of the cluster they belong to.
    clusters_dict = {}
    for i in range(len(cluster_assignments)):
        canonical_index = get_canonical_index(i)
        if canonical_index not in clusters_dict:
            clusters_dict[canonical_index] = []
        clusters_dict[canonical_index].append(circles[i])
    return clusters_dict.values()
539
540
def find_cluster_bounding_rect(cluster, scratch_frame):
    """Computes an axis-aligned rectangle enclosing every circle of a cluster.

    The tight bounds of the circles are padded on every side by the average
    nearest-neighbor distance within the cluster.

    Args:
        cluster:       Cluster being used to find the bounding rectangle.
        scratch_frame: Image that rectangle is drawn onto for debugging
                       purposes.

    Returns:
        (x, y, w, h): the left and top coordinates of the rectangle, followed
        by its width and height.
    """
    padding = find_average_neighbor_distance(cluster)
    debug_print('Average distance between points in largest cluster is %s '
                'pixels.' % padding)

    left_circle = min(cluster, key=lambda c: c.x - c.r)
    top_circle = min(cluster, key=lambda c: c.y - c.r)
    right_circle = max(cluster, key=lambda c: c.x + c.r)
    bottom_circle = max(cluster, key=lambda c: c.y + c.r)

    left = left_circle.x - left_circle.r - padding
    top = top_circle.y - top_circle.r - padding
    right = right_circle.x + right_circle.r + padding
    bottom = bottom_circle.y + bottom_circle.r + padding
    width = right - left
    height = bottom - top

    if DEBUG:
        corners = np.array([[left, top],
                            [left + width, top],
                            [left + width, top + height],
                            [left, top + height]],
                           np.int32)
        cv2.polylines(scratch_frame, [corners], True, (255, 0, 0), thickness=2)

    return left, top, width, height
575
576
def find_average_neighbor_distance(cluster):
    """Finds the average distance between every circle and its closest neighbor.

    Distances are taken from each circle's ``distance_to`` method.

    Args:
        cluster: List of circles

    Returns:
        The mean, over all circles, of the distance to the nearest other
        circle.  A cluster with fewer than two circles has no neighbor
        distances, so 0.0 is returned for it (no padding) instead of failing.
    """
    if len(cluster) < 2:
        return 0.0
    nearest = [min(a.distance_to(b) for b in cluster if b is not a)
               for a in cluster]
    return sum(nearest) / len(cluster)
600
601
def find_num_columns_spanned(circles):
    """Counts the LED panel columns covered by the given circles.

    Circles are visited from left to right.  A circle whose horizontal extent
    overlaps that of the circle that opened the current column belongs to that
    column; otherwise it opens a new one.

    Args:
        circles: List of circles (assumed to be from the LED panel).

    Returns:
        The number of columns spanned (0 for no circles).
    """
    if not circles:
        return 0

    ordered = sorted(circles, key=lambda c: c.x)
    anchor = ordered[0]
    count = 1
    for candidate in ordered[1:]:
        overlaps = abs(candidate.x - anchor.x) < (candidate.r + anchor.r)
        if overlaps:
            continue
        anchor = candidate
        count += 1
    return count
626
627
def setup_debug_dir(dir_name=None):
    """Prepares the directory tree that debugging output is written to.

    One subdirectory is made per processing stage: 'raw' (original captured
    images), 'mono' (monochrome images), 'contour' (contours generated from the
    monochrome images) and 'scratch' (post-contour debugging information).
    Each of them is created with force_mkdir(..., clean=True).

    Args:
        dir_name: The directory to use.  When omitted, a fresh temporary
                  directory is created instead.

    Returns:
        The name of the directory that is used.
    """
    if dir_name is not None:
        force_mkdir(dir_name)
    else:
        dir_name = tempfile.mkdtemp()
    print('Saving debugging files to "%s"' % dir_name)
    for stage in ('raw', 'mono', 'contour', 'scratch'):
        force_mkdir(dir_name + '/' + stage, clean=True)
    return dir_name
654
655
def force_mkdir(dir_name, clean=False):
    """Makes sure a directory exists, optionally emptying it of PNG files.

    Args:
        dir_name: Name of the directory to create (parents are created too).
        clean:    (optional) When true and the directory already exists, every
                  '*.png' file directly inside it is deleted.  Other files are
                  left alone.
    """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    elif clean:
        for png_path in glob.glob('%s/*.png' % dir_name):
            os.remove(png_path)
670
671
def debug_print(s, *args, **kwargs):
    """Forwards all arguments to print(), but only while DEBUG is enabled."""
    if not DEBUG:
        return
    print(s, *args, **kwargs)
676
677
if __name__ == '__main__':
    # Run the test only when this file is executed as a script, not on import.
    main()
680