1"""Experimentally determines a camera's rolling shutter skew.
2
3See the accompanying PDF for instructions on how to use this test.
4"""
5from __future__ import division
6from __future__ import print_function
7
8import argparse
9import glob
10import math
11import os
12import sys
13import tempfile
14
15import cv2
16import its.caps
17import its.device
18import its.image
19import its.objects
20import numpy as np
21
# Global debug switch.  main() overwrites it with the value of --debug.
DEBUG = False

# Constants for which direction the camera is facing.  collect_data() compares
# these against the 'android.lens.facing' camera property.
FACING_FRONT = 0
FACING_BACK = 1
FACING_EXTERNAL = 2

# Camera capture defaults, used when the matching command line flag is not
# given.
FPS = 30
WIDTH = 640
HEIGHT = 480
TEST_LENGTH = 1  # In seconds.

# Each circle in a cluster must be within this distance of some other circle
# in the cluster.  The value is stored as a fraction of the frame height (50
# pixels at the default HEIGHT); main() multiplies it by the height of the
# first frame to turn it into pixels.
CLUSTER_DISTANCE = 50.0 / HEIGHT
# A cluster must contain at least this fraction of the total contours for
# it to be allowed into the computation.
MAJORITY_THRESHOLD = 0.7

# Constants to make sure the slope of the fitted line is reasonable.
SLOPE_MIN_THRESHOLD = 0.5
SLOPE_MAX_THRESHOLD = 1.5

# To improve readability of unit conversions.
SEC_TO_NSEC = float(10**9)
MSEC_TO_NSEC = float(10**6)
NSEC_TO_MSEC = 1.0 / float(10**6)
50
51
class RollingShutterArgumentParser(object):
  """Parses command line arguments for the rolling shutter test.

  Besides the argparse-style flags, a CameraITS-style 'device=<DEVICE ID>'
  argument is tolerated on the command line: it is hidden from argparse and
  left untouched in sys.argv.
  """

  def __init__(self):
    self.__parser = argparse.ArgumentParser(
        description='Run rolling shutter test')
    self.__parser.add_argument(
        '-d', '--debug',
        action='store_true',
        help='print and write data useful for debugging')
    self.__parser.add_argument(
        '-f', '--fps',
        type=int,
        help='FPS to capture with during the test (defaults to 30)')
    self.__parser.add_argument(
        '-i', '--img_size',
        help=('comma-separated dimensions of captured images (defaults '
              'to 640x480). Example: --img_size=<width>,<height>'))
    self.__parser.add_argument(
        '-l', '--led_time',
        type=float,
        required=True,
        help=('how many milliseconds each column of the LED array is '
              'lit for'))
    self.__parser.add_argument(
        '-p', '--panel_distance',
        type=float,
        help='how far the LED panel is from the camera (in meters)')
    self.__parser.add_argument(
        '-r', '--read_dir',
        help=('read existing test data from specified directory.  If '
              'not specified, new test data is collected from the '
              'device\'s camera'))
    self.__parser.add_argument(
        '--device_id',
        help=('device ID for device being tested (can also use '
              '\'device=<DEVICE ID>\')'))
    self.__parser.add_argument(
        '-t', '--test_length',
        type=int,
        help=('how many seconds the test should run for (defaults to 1 '
              'second)'))
    self.__parser.add_argument(
        '-o', '--debug_dir',
        help=('write debugging information in a folder in the '
              'specified directory.  Otherwise, the system\'s default '
              'location for temporary folders is used.  --debug must '
              'be specified along with this argument.'))

  def parse_args(self):
    """Returns object containing parsed values from the command line.

    Side effect: when --device_id is given, an equivalent
    'device=<DEVICE ID>' entry is appended to sys.argv (at most once) so
    that its.device can use it later on.
    """
    # Don't show argparse the 'device' flag, since it's in a different
    # format than the others (to maintain CameraITS conventions) and it will
    # complain.  Only arguments that *start* with 'device=' are hidden, so a
    # value that merely contains that text (e.g. a path) is still parsed.
    filtered_args = [arg for arg in sys.argv[1:]
                     if not arg.startswith('device=')]
    args = self.__parser.parse_args(filtered_args)
    if args.device_id:
      # If argparse format is used, convert it to a format its.device can
      # use later on.  Skip the append if it is already there so repeated
      # calls do not pile up duplicates.
      device_arg = 'device=%s' % args.device_id
      if device_arg not in sys.argv:
        sys.argv.append(device_arg)
    return args
113
114
def main():
  """Runs the rolling shutter skew test from the command line.

  Parses the arguments, obtains frames (either captured from the device or
  read from --read_dir), measures the average shutter skew and prints it
  next to the skew reported by the device.  In debug mode the raw frames and
  the reported skew are also written to the debug directory.

  Exits with a non-zero status if --debug_dir is given without --debug or if
  there are no frames to process.
  """
  global DEBUG
  global CLUSTER_DISTANCE

  parser = RollingShutterArgumentParser()
  args = parser.parse_args()

  DEBUG = args.debug
  if not DEBUG and args.debug_dir:
    print('argument --debug_dir requires --debug', file=sys.stderr)
    # Report the misuse through the exit status as well as the message.
    sys.exit(1)

  if args.read_dir is None:
    # Collect new data.
    raw_caps, reported_skew = collect_data(args)
    frames = [its.image.convert_capture_to_rgb_image(c) for c in raw_caps]
  else:
    # Load existing data.
    frames, reported_skew = load_data(args.read_dir)

  if not frames:
    # Without at least one frame the frame height below cannot be read.
    print('No frames available to process.', file=sys.stderr)
    sys.exit(1)

  # Make the cluster distance relative to the height of the image.
  (frame_h, _, _) = frames[0].shape
  CLUSTER_DISTANCE = frame_h * CLUSTER_DISTANCE
  debug_print('Setting cluster distance to %spx.' % CLUSTER_DISTANCE)

  if DEBUG:
    debug_dir = setup_debug_dir(args.debug_dir)
    # Write raw frames.
    for i, img in enumerate(frames):
      its.image.write_image(img, '%s/raw/%03d.png' % (debug_dir, i))
  else:
    debug_dir = None

  avg_shutter_skew, num_frames_used = find_average_shutter_skew(
      frames, args.led_time, debug_dir)
  if debug_dir:
    # Write the reported skew with the raw images, so the directory can also
    # be used to read from.
    with open(debug_dir + '/raw/reported_skew.txt', 'w') as f:
      f.write('%sms\n' % reported_skew)

  if avg_shutter_skew is None:
    print('Could not find usable frames.')
  else:
    print('Device reported shutter skew of %sms.' % reported_skew)
    print('Measured shutter skew is %sms (averaged over %s frames).' %
          (avg_shutter_skew, num_frames_used))
162
163
def collect_data(args):
  """Captures a new set of frames from the device's camera.

  Args:
      args: Parsed command line arguments.  The fps, img_size, test_length,
            led_time and panel_distance attributes are read.

  Returns:
      A tuple (raw_caps, shutter_skew): the raw captures returned by the ITS
      session, and the rolling shutter skew reported in the capture metadata
      converted to milliseconds.

  Raises:
      AssertionError: If the lens facing is neither front nor back, or if
                      the captures do not all report the same rolling
                      shutter skew.
  """
  fps = args.fps if args.fps else FPS
  if args.img_size:
    w, h = map(int, args.img_size.split(','))
  else:
    w, h = WIDTH, HEIGHT
  test_length = args.test_length if args.test_length else TEST_LENGTH

  with its.device.ItsSession() as cam:
    props = cam.get_camera_properties()
    its.caps.skip_unless(its.caps.manual_sensor(props))
    facing = props['android.lens.facing']
    if facing not in (FACING_FRONT, FACING_BACK):
      print('Unknown lens facing %s' % facing)
      # Raised explicitly (rather than 'assert 0') so the check still runs
      # when Python is started with -O.
      raise AssertionError('Unknown lens facing %s' % facing)

    fmt = {'format': 'yuv', 'width': w, 'height': h}
    s, e, _, _, _ = cam.do_3a(get_results=True, do_af=False)
    req = its.objects.manual_capture_request(s, e)
    req['android.control.aeTargetFpsRange'] = [fps, fps]

    # Convert from milliseconds to nanoseconds.  We only want enough
    # exposure time to saturate approximately one column.
    exposure_time = (args.led_time / 2.0) * MSEC_TO_NSEC
    print('Using exposure time of %sns.' % exposure_time)
    req['android.sensor.exposureTime'] = exposure_time
    req['android.sensor.frameDuration'] = int(SEC_TO_NSEC / fps)

    if args.panel_distance is not None:
      # Convert meters to diopters and use that for the focus distance.
      req['android.lens.focusDistance'] = 1 / args.panel_distance
    print('Starting capture')
    # One request per frame: fps frames per second for test_length seconds.
    raw_caps = cam.do_capture([req]*fps*test_length, fmt)
    print('Finished capture')

    # Convert from nanoseconds to milliseconds.  A set is used so that
    # identical values collapse into a single entry.
    shutter_skews = {c['metadata']['android.sensor.rollingShutterSkew'] *
                     NSEC_TO_MSEC for c in raw_caps}
    # All frames should have same rolling shutter skew.
    if len(shutter_skews) != 1:
      raise AssertionError(
          'Expected exactly one rolling shutter skew, got: %s' %
          sorted(shutter_skews))
    shutter_skew = list(shutter_skews)[0]

    return raw_caps, shutter_skew
215
216
def load_data(dir_name):
  """Reads camera frame data from an existing directory.

  The directory is expected to contain the frames as '*.png' files and a
  'reported_skew.txt' file whose first line is the reported skew followed by
  an 'ms' suffix (the layout main() writes to '<debug_dir>/raw').

  Args:
    dir_name: Name of the directory to read data from.

  Returns:
    A tuple (frames, reported_skew): the frames loaded in sorted file name
    order, and the reported skew as a string with the 'ms' suffix removed.
  """
  frame_files = sorted(glob.glob('%s/*.png' % dir_name))
  frames = [its.image.load_rgb_image(frame_file)
            for frame_file in frame_files]
  with open('%s/reported_skew.txt' % dir_name, 'r') as f:
    # Drop the line terminator first; slicing a fixed two characters off
    # 'NNms\n' would remove 's\n' and leave a stray 'm' behind.
    reported_skew = f.readline().strip()
  if reported_skew.endswith('ms'):
    reported_skew = reported_skew[:-len('ms')]
  return frames, reported_skew
233
234
def find_average_shutter_skew(frames, led_time, debug_dir=None):
  """Computes a confidence-weighted average shutter skew over all frames.

  Every frame is measured on its own; frames for which no skew could be
  determined are left out of the average.  A warning is printed to stderr
  when the weighted average slope of the fitted lines falls outside
  [SLOPE_MIN_THRESHOLD, SLOPE_MAX_THRESHOLD].

  Args:
      frames:    List of RGB images from the camera being tested.
      led_time:  How long a single LED column is lit for (in milliseconds).
      debug_dir: (optional) Directory to write debugging information to.

  Returns:
      The weighted average shutter skew and the number of frames that
      contributed to it, or (None, None) if no frame was usable.
  """
  weighted_skew_sum = 0.0
  weighted_slope_sum = 0.0
  total_confidence = 0.0
  frames_used = 0

  banner = '------------------------'
  for frame_index, frame in enumerate(frames):
    debug_print(banner)
    debug_print('| PROCESSING FRAME %03d |' % frame_index)
    debug_print(banner)
    skew, confidence, slope = calculate_shutter_skew(
        frame, led_time, frame_index, debug_dir=debug_dir)
    if skew is None:
      debug_print('Skipped frame.')
      continue

    debug_print('Shutter skew is %sms (confidence: %s).' %
                (skew, confidence))
    # Each frame contributes in proportion to its confidence.
    weighted_skew_sum += skew * confidence
    weighted_slope_sum += slope * confidence
    total_confidence += confidence
    frames_used += 1

  debug_print('\n')
  if frames_used == 0:
    return None, None

  mean_skew = weighted_skew_sum / total_confidence
  mean_slope = weighted_slope_sum / total_confidence
  slope_err_str = ('The average slope of the fitted line was too %s '
                   'to get an accurate measurement (slope was %s).  '
                   'Try making the LED panel %s.')
  if mean_slope < SLOPE_MIN_THRESHOLD:
    print(slope_err_str % ('flat', mean_slope, 'slower'),
          file=sys.stderr)
  elif mean_slope > SLOPE_MAX_THRESHOLD:
    print(slope_err_str % ('steep', mean_slope, 'faster'),
          file=sys.stderr)
  return mean_skew, frames_used
288
289
def calculate_shutter_skew(frame, led_time, frame_num=None, debug_dir=None):
  """Calculates the shutter skew of the camera being used for this test.

  A line is fitted through the centers of the largest cluster of detected
  circles.  The horizontal distance that line covers over the full frame
  height is converted to time using how many LED columns the cluster spans.

  Args:
      frame:     A single RGB image captured by the camera being tested.
      led_time:  How long a single LED column is lit for (in milliseconds).
      frame_num: (optional) Number of the given frame.  Only used to name
                 the debug images; treated as 0 when omitted.
      debug_dir: (optional) Directory to write debugging information to.

  Returns:
      The shutter skew (in milliseconds), the confidence in the accuracy of
      the measurement (useful for weighting averages), and the slope of the
      fitted line.  All three are None if the frame has no usable cluster.
  """
  if frame_num is None:
    # '%03d' cannot format None; fall back to a fixed number for file names.
    frame_num = 0

  contours, scratch_img, contour_img, mono_img = find_contours(frame.copy())
  if debug_dir is not None:
    cv2.imwrite('%s/contour/%03d.png' % (debug_dir, frame_num), contour_img)
    cv2.imwrite('%s/mono/%03d.png' % (debug_dir, frame_num), mono_img)

  largest_cluster, cluster_percentage = find_largest_cluster(contours,
                                                             scratch_img)
  if largest_cluster is None:
    debug_print('No majority cluster found.')
    return None, None, None
  elif len(largest_cluster) <= 1:
    debug_print('Majority cluster was too small.')
    return None, None, None
  debug_print('%s points in the largest cluster.' % len(largest_cluster))

  np_cluster = np.array([[c.x, c.y] for c in largest_cluster])
  # The cv2.cv submodule only exists in OpenCV 2.x; later releases expose
  # the same constant directly on cv2.
  if hasattr(cv2, 'DIST_L2'):
    dist_type = cv2.DIST_L2
  else:
    dist_type = cv2.cv.CV_DIST_L2
  [vx], [vy], [x0], [y0] = cv2.fitLine(np_cluster, dist_type,
                                       0, 0.01, 0.01)
  slope = vy / vx
  debug_print('Slope is %s.' % slope)
  (frame_h, frame_w, _) = frame.shape
  # Draw line onto scratch frame.
  pt1 = tuple(map(int, (x0 - vx * 1000, y0 - vy * 1000)))
  pt2 = tuple(map(int, (x0 + vx * 1000, y0 + vy * 1000)))
  cv2.line(scratch_img, pt1, pt2, (0, 255, 255), thickness=3)

  # We only need the width of the cluster.
  _, _, cluster_w, _ = find_cluster_bounding_rect(largest_cluster,
                                                  scratch_img)

  num_columns = find_num_columns_spanned(largest_cluster)
  debug_print('%s columns spanned by cluster.' % num_columns)
  # How long it takes for a column to move from the left of the bounding
  # rectangle to the right.
  left_to_right_time = led_time * num_columns
  milliseconds_per_x_pixel = left_to_right_time / cluster_w
  # The distance between the line's intersection at the top of the frame and
  # the intersection at the bottom.
  x_range = frame_h / slope
  shutter_skew = milliseconds_per_x_pixel * x_range
  # If the aspect ratio is different from 4:3 (the aspect ratio of the actual
  # sensor), we need to correct, because it will be cropped.
  shutter_skew *= (float(frame_w) / float(frame_h)) / (4.0 / 3.0)

  if debug_dir is not None:
    cv2.imwrite('%s/scratch/%03d.png' % (debug_dir, frame_num),
                scratch_img)

  return shutter_skew, cluster_percentage, slope
353
354
def find_contours(img):
  """Finds contours in the given image.

  The input array is not modified; a scaled copy is made instead.

  Args:
    img: Image in Android camera RGB format.

  Returns:
    OpenCV-formatted contours, the original image in OpenCV format (BGR
    ordering, byte-ranged values), a thresholded image with the contours
    drawn on (None unless DEBUG is set), and the red channel of the image,
    which serves as the grayscale version.
  """
  # Convert to format OpenCV can work with (BGR ordering with byte-ranged
  # values).  Scaling into a new array keeps the caller's image untouched.
  img = (img * 255).astype(np.uint8)
  img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

  # Since the LED colors for the panel we're using are red, we can get better
  # contours for the LEDs if we ignore the green and blue channels.  This also
  # makes it so we don't pick up the blue control screen of the LED panel.
  red_img = img[:, :, 2]
  _, thresh = cv2.threshold(red_img, 0, 255, cv2.THRESH_BINARY +
                            cv2.THRESH_OTSU)

  # Remove noise before finding contours by eroding the thresholded image and
  # then re-dilating it.  The size of the kernel represents how many
  # neighboring pixels to consider for the result of a single pixel.
  kernel = np.ones((3, 3), np.uint8)
  opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)

  if DEBUG:
    # Need to convert it back to BGR if we want to draw colored contours.
    contour_img = cv2.cvtColor(opening, cv2.COLOR_GRAY2BGR)
  else:
    contour_img = None
  # The cv2.cv submodule is gone after OpenCV 2.x, so use the constants that
  # live directly on cv2.  OpenCV 3.x returns (image, contours, hierarchy)
  # while 2.x and 4.x return (contours, hierarchy); in both layouts the
  # contours are the second-to-last element.
  contours = cv2.findContours(
      opening, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
  if DEBUG:
    cv2.drawContours(contour_img, contours, -1, (0, 0, 255), thickness=2)
  return contours, img, contour_img, red_img
400
401
def convert_to_circles(contours):
  """Approximates every contour with a circle.

  Args:
    contours: Contours generated by OpenCV.

  Returns:
    A list with one circle per contour, in the same order.
  """

  class Circle(object):
    """A circle given by an integer center (x, y) and a radius r."""

    def __init__(self, contour):
      xs = contour[:, 0, 0]
      ys = contour[:, 0, 1]
      # The center is the mean of the contour points, truncated to ints.
      self.x = int(np.mean(xs))
      self.y = int(np.mean(ys))
      # Half of the contour's extent along each axis.
      half_width = (np.max(xs) - np.min(xs)) / 2.0
      half_height = (np.max(ys) - np.min(ys)) / 2.0
      # The radius is the mean of the two half extents.
      self.r = (half_width + half_height) / 2.0
      assert self.r > 0.0

    def distance_to(self, other):
      # Gap between the two outlines; negative when the circles overlap.
      dx = other.x - self.x
      dy = other.y - self.y
      center_gap = math.sqrt(dx**2 + dy**2)
      return center_gap - self.r - other.r

    def intersects(self, other):
      gap = self.distance_to(other)
      return gap <= 0.0

  return [Circle(contour) for contour in contours]
434
435
def find_largest_cluster(contours, frame):
  """Picks the cluster that holds the majority of the contours.

  Args:
    contours: Contours generated by OpenCV.
    frame:    Image the chosen cluster is drawn onto when DEBUG is set.

  Returns:
    The cluster with the most contours in it and the fraction of all
    contours that the cluster contains.  Both are None when there are no
    clusters or when that fraction is below MAJORITY_THRESHOLD.
  """
  clusters = proximity_clusters(contours)
  if not clusters:
    # Nothing was clustered at all.
    return None, None

  biggest = max(clusters, key=len)
  share = len(biggest) / len(contours)
  if share < MAJORITY_THRESHOLD:
    return None, None

  if DEBUG:
    # Outline every member of the winning cluster on the scratch frame.
    for member in biggest:
      center = (int(member.x), int(member.y))
      cv2.circle(frame, center, int(member.r), (0, 255, 0), thickness=2)

  return biggest, share
465
466
def proximity_clusters(contours):
  """Sorts the given contours into groups by distance.

  Converts every given contour to a circle and clusters by adding a circle to
  a cluster only if it is close to at least one other circle in the cluster.
  Two circles count as close when the value returned by distance_to() is
  below the module-level CLUSTER_DISTANCE.

  TODO: Make algorithm faster (currently O(n**2)).

  Args:
    contours: Contours generated by OpenCV.

  Returns:
    The values of a dict keyed by canonical index: one entry per cluster,
    where each cluster is a list of the circles contained in the cluster.
  """
  circles = convert_to_circles(contours)

  # Use disjoint-set data structure to store assignments.  Start every point
  # in their own cluster.  A negative entry marks a root and holds the
  # negated size of its cluster (-1 is a cluster of one); a non-negative
  # entry is the index of the parent.
  cluster_assignments = [-1 for i in range(len(circles))]

  def get_canonical_index(i):
    """Returns the root index of the cluster that circle i belongs to."""
    if cluster_assignments[i] >= 0:
      index = get_canonical_index(cluster_assignments[i])
      # Collapse tree for better runtime.
      cluster_assignments[i] = index
      return index
    else:
      return i

  def get_cluster_size(i):
    """Returns the (positive) size of the cluster that contains circle i."""
    return -cluster_assignments[get_canonical_index(i)]

  for i, curr in enumerate(circles):
    # Indices of all other circles that are close enough to the current one.
    close_circles = [j for j, p in enumerate(circles) if i != j and
                     curr.distance_to(p) < CLUSTER_DISTANCE]
    if close_circles:
      # Note: largest_cluster is an index into cluster_assignments.
      # get_cluster_size() returns positive sizes, so despite its name this
      # min() picks the close circle whose cluster is currently the
      # smallest.
      largest_cluster = min(close_circles, key=get_cluster_size)
      largest_size = get_cluster_size(largest_cluster)
      curr_index = get_canonical_index(i)
      curr_size = get_cluster_size(i)
      if largest_size > curr_size:
        # The selected neighboring cluster is larger than ours, so we join
        # it.
        target_index = get_canonical_index(largest_cluster)
        # Add our cluster size to the bigger one.
        cluster_assignments[target_index] -= curr_size
        # Reroute our group to the bigger one.
        cluster_assignments[curr_index] = target_index
      else:
        # Our cluster is at least as large as the selected neighboring
        # cluster.  Reroute all close groups to us.
        for j in close_circles:
          smaller_size = get_cluster_size(j)
          smaller_index = get_canonical_index(j)
          if smaller_index != curr_index:
            # We only want to modify clusters that aren't already in
            # the current one.

            # Add the smaller cluster's size to ours.
            cluster_assignments[curr_index] -= smaller_size
            # Reroute their group to us.
            cluster_assignments[smaller_index] = curr_index

  # Convert assignments list into list of clusters.
  clusters_dict = {}
  for i in range(len(cluster_assignments)):
    canonical_index = get_canonical_index(i)
    if canonical_index not in clusters_dict:
      clusters_dict[canonical_index] = []
    clusters_dict[canonical_index].append(circles[i])
  return clusters_dict.values()
539
540
def find_cluster_bounding_rect(cluster, scratch_frame):
  """Computes an axis-aligned rectangle that encloses the given cluster.

  The rectangle covers every circle in full and is grown on all four sides
  by the average distance between neighboring circles.

  Args:
    cluster:       Cluster being used to find the bounding rectangle.
    scratch_frame: Image that rectangle is drawn onto for debugging
                   purposes.

  Returns:
    The leftmost and topmost x and y coordinates, respectively, along with
    the width and height of the rectangle.
  """
  padding = find_average_neighbor_distance(cluster)
  debug_print('Average distance between points in largest cluster is %s '
              'pixels.' % padding)

  # Outermost circle edges on each side, pushed outwards by the padding.
  left = min(c.x - c.r for c in cluster) - padding
  top = min(c.y - c.r for c in cluster) - padding
  width = (max(c.x + c.r for c in cluster) + padding) - left
  height = (max(c.y + c.r for c in cluster) + padding) - top

  if DEBUG:
    right = left + width
    bottom = top + height
    corners = np.array(
        [[left, top], [right, top], [right, bottom], [left, bottom]],
        np.int32)
    cv2.polylines(scratch_frame, [corners], True, (255, 0, 0), thickness=2)

  return left, top, width, height
575
576
def find_average_neighbor_distance(cluster):
  """Finds the average distance between every circle and its closest neighbor.

  Args:
    cluster: List of circles.  Each one must provide distance_to(other).

  Returns:
    The average distance.  For a cluster with fewer than two circles there
    are no neighbors to measure, so 0.0 is returned.
  """
  if len(cluster) < 2:
    # No circle has a neighbor; avoid dividing by zero or adding None.
    return 0.0

  total_distance = 0.0
  for a in cluster:
    # Distance from this circle to the nearest of all the other circles.
    total_distance += min(a.distance_to(b) for b in cluster if b is not a)
  return total_distance / len(cluster)
600
601
def find_num_columns_spanned(circles):
  """Counts the LED panel columns covered by the given circles.

  The circles are walked from left to right.  A circle whose horizontal
  extent overlaps that of the current column's first circle belongs to the
  same column; any other circle opens a new column.

  Args:
    circles: List of circles (assumed to be from the LED panel).

  Returns:
    The number of columns spanned.
  """
  if not circles:
    return 0

  by_x = sorted(circles, key=lambda circle: circle.x)
  anchor = by_x[0]
  columns = 1
  for candidate in by_x[1:]:
    overlaps = abs(candidate.x - anchor.x) < (candidate.r + anchor.r)
    if overlaps:
      continue
    # No horizontal overlap: this circle starts the next column.
    anchor = candidate
    columns += 1

  return columns
626
627
def setup_debug_dir(dir_name=None):
  """Prepares the directory tree that debugging output is written to.

  Each subdirectory holds the images of one processing step.  Image files
  left over in those subdirectories are removed.

  Args:
    dir_name: The directory to create.  If none is specified, a temp
    directory is created.

  Returns:
    The name of the directory that is used.
  """
  if dir_name is not None:
    force_mkdir(dir_name)
  else:
    dir_name = tempfile.mkdtemp()
  print('Saving debugging files to "%s"' % dir_name)

  # raw:     original captured images.
  # mono:    monochrome images.
  # contour: contours generated from the monochrome images.
  # scratch: post-contour debugging information.
  for subdir in ('/raw', '/mono', '/contour', '/scratch'):
    force_mkdir(dir_name + subdir, clean=True)
  return dir_name
654
655
def force_mkdir(dir_name, clean=False):
  """Makes sure the given directory exists.

  Args:
    dir_name: Name of the directory to create.
    clean:    (optional) If set to true and the directory already exists,
              the '*.png' files directly inside it are deleted.
  """
  if not os.path.exists(dir_name):
    os.makedirs(dir_name)
    return
  if not clean:
    return
  for stale_image in glob.glob('%s/*.png' % dir_name):
    os.remove(stale_image)
670
671
def debug_print(s, *args, **kwargs):
  """Forwards its arguments to print(), but only when DEBUG is set."""
  if not DEBUG:
    return
  print(s, *args, **kwargs)
676
677
# Run the test only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
  main()
680