1# Copyright 2014 The Android Open Source Project.
2
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Noise model utility functions."""
15
import collections
import logging
import math
import os.path
import pickle
from typing import Any, Dict, List, Tuple
import warnings
import capture_request_utils
import image_processing_utils
from matplotlib import pylab
import matplotlib.pyplot as plt
import noise_model_constants
import numpy as np
import scipy.optimize
import scipy.stats
30
31
# Module-local alias of the default outlier threshold defined in
# noise_model_constants. It is used below as the default value of the
# `deviations` parameter of filter_stats and of the
# `outlier_median_abs_deviations` parameter of capture_stats_images.
_OUTLIER_MEDIAN_ABS_DEVS_DEFAULT = (
    noise_model_constants.OUTLIER_MEDIAN_ABS_DEVS_DEFAULT
)
35
36
37def _check_auto_exposure_targets(
38    auto_exposure_ns: float,
39    sens_min: int,
40    sens_max: int,
41    bracket_factor: int,
42    min_exposure_ns: int,
43    max_exposure_ns: int,
44) -> None:
45  """Checks if AE too bright for highest gain & too dark for lowest gain.
46
47  Args:
48    auto_exposure_ns: The auto exposure value in nanoseconds.
49    sens_min: The minimum sensitivity value.
50    sens_max: The maximum sensitivity value.
51    bracket_factor: Exposure bracket factor.
52    min_exposure_ns: The minimum exposure time in nanoseconds.
53    max_exposure_ns: The maximum exposure time in nanoseconds.
54  """
55
56  if auto_exposure_ns < min_exposure_ns * sens_max:
57    raise AssertionError(
58        'Scene is too bright to properly expose at highest '
59        f'sensitivity: {sens_max}'
60    )
61  if auto_exposure_ns * bracket_factor > max_exposure_ns * sens_min:
62    raise AssertionError(
63        'Scene is too dark to properly expose at lowest '
64        f'sensitivity: {sens_min}'
65    )
66
67
def check_noise_model_shape(noise_model: np.ndarray) -> None:
  """Verifies that a noise model array has a supported shape.

  Args:
    noise_model: A numpy array of shape (num_channels, num_parameters).

  Raises:
    AssertionError: If the number of channels is not one of
      noise_model_constants.VALID_NUM_CHANNELS, or if the number of parameters
      per channel is not 4.
  """
  expected_num_parameters = 4
  num_channels, num_parameters = noise_model.shape

  valid_num_channels = noise_model_constants.VALID_NUM_CHANNELS
  if num_channels not in valid_num_channels:
    raise AssertionError(
        f'The number of channels {num_channels} is not in'
        f' {valid_num_channels}.'
    )

  if num_parameters != expected_num_parameters:
    raise AssertionError(
        f'The number of parameters of each channel {num_parameters} !='
        f' {expected_num_parameters}.'
    )
84
85
def validate_noise_model(
    noise_model: np.ndarray,
    color_channels: List[str],
    sens_min: int,
) -> None:
  """Performs validation checks on the noise model.

  For each color channel this function checks that the scale gradient is
  non-negative, that the intercept gradient is positive, and that the read
  noise at the minimum sensitivity
  (offset_a * sens_min^2 + offset_b) is positive.

  Args:
      noise_model: Noise model parameters of each channel, in the order
        scale_a, scale_b, offset_a, offset_b.
      color_channels: Names of the color channels, one per noise model row.
      sens_min: Minimum sensitivity value.

  Raises:
      AssertionError: If the noise model shape is invalid, if the number of
        color channels does not match the number of noise model rows, or if
        any of the per-channel checks above fails.
  """
  check_noise_model_shape(noise_model)
  num_channels = noise_model.shape[0]
  if len(color_channels) != num_channels:
    # Report each count next to the label that actually describes it.
    raise AssertionError(
        f'Number of color channels {len(color_channels)} != number of noise '
        f'model channels {num_channels}.'
    )

  for color_channel, (scale_a, _, offset_a, offset_b) in zip(
      color_channels, noise_model
  ):
    if scale_a < 0:
      raise AssertionError(
          f'{color_channel} model API scale gradient < 0: {scale_a:.4e}'
      )

    # Zero is rejected too, so the message says "<= 0".
    if offset_a <= 0:
      raise AssertionError(
          f'{color_channel} model API intercept gradient <= 0: {offset_a:.4e}'
      )

    read_noise = offset_a * sens_min * sens_min + offset_b
    if read_noise <= 0:
      raise AssertionError(
          f'{color_channel} model min ISO noise <= 0! '
          f'API intercept gradient: {offset_a:.4e}, '
          f'API intercept offset: {offset_b:.4e}, '
          f'read_noise: {read_noise:.4e}'
      )
130
131
def compute_digital_gains(
    gains: np.ndarray,
    sens_max_analog: np.ndarray,
) -> np.ndarray:
  """Returns the digital gains implied by the given gains.

  The digital gain is the gain divided by the maximum analog gain sensitivity,
  clamped from below at 1. Every digital gain is required to be exactly 1,
  i.e. no gain may exceed the maximum analog gain sensitivity.

  Args:
    gains: An array of gains.
    sens_max_analog: The maximum analog gain sensitivity.

  Returns:
    A numpy array of digital gains (all equal to 1).

  Raises:
    AssertionError: If any digital gain differs from 1.
  """
  gain_ratios = gains / sens_max_analog
  digital_gains = np.maximum(gain_ratios, 1)
  if np.any(digital_gains != 1):
    message = (
        f'Digital gains are not all 1! gains: {gains}, '
        f'Max analog gain sensitivity: {sens_max_analog}.'
    )
    raise AssertionError(message)
  return digital_gains
156
157
def crop_and_save_capture(
    cap,
    props,
    capture_path: str,
    num_tiles_crop: int,
) -> None:
  """Crops and saves a capture image.

  The capture is converted to an RGB image, `num_tiles_crop` rows and columns
  are removed from every edge of that image, and the result is written to
  `capture_path`.

  Args:
    cap: The capture to be cropped and saved.
    props: The properties to be used to convert the capture to an RGB image.
    capture_path: The path to which the capture image should be saved.
    num_tiles_crop: The number of rows/columns removed from each edge of the
      converted image.

  Raises:
    AssertionError: If `num_tiles_crop` is not smaller than half of the
      shorter image side.
  """
  img = image_processing_utils.convert_capture_to_rgb_image(cap, props=props)
  height, width, _ = img.shape
  num_tiles_crop_max = min(height, width) // 2
  if num_tiles_crop >= num_tiles_crop_max:
    raise AssertionError(
        f'Number of tiles to crop {num_tiles_crop} >= {num_tiles_crop_max}.'
    )
  img = img[
      num_tiles_crop: height - num_tiles_crop,
      num_tiles_crop: width - num_tiles_crop,
      :,
  ]

  # NOTE(review): the meaning of the third positional argument (True) is
  # defined by image_processing_utils.write_image - confirm there.
  image_processing_utils.write_image(img, capture_path, True)
186
187
def crop_and_reorder_stats_images(
    mean_img: np.ndarray,
    var_img: np.ndarray,
    num_tiles_crop: int,
    channel_indices: List[int],
) -> Tuple[np.ndarray, np.ndarray]:
  """Crops the stats images and stacks their channels in the requested order.

  Args:
      mean_img: The mean image, indexed as (row, column, channel).
      var_img: The variance image, with the same shape as `mean_img`.
      num_tiles_crop: The number of tiles to crop from each side of the image.
      channel_indices: The channel indices, listed in the order in which the
        channels should appear in the output (canonical order).

  Returns:
      The cropped mean and variance images, each with the selected channels
      stacked along the first axis.

  Raises:
      AssertionError: If the two images differ in shape, or if the crop would
        remove more than the shorter image side.
  """
  if mean_img.shape != var_img.shape:
    raise AssertionError(
        'Unmatched shapes of mean and variance image: '
        f'shape of mean image is {mean_img.shape}, '
        f'shape of variance image is {var_img.shape}.'
    )

  height, width, _ = mean_img.shape
  if min(height, width) < 2 * num_tiles_crop:
    raise AssertionError(
        f'The number of tiles to crop ({num_tiles_crop}) is so large that'
        ' images cannot be cropped.'
    )

  # The same window is applied to every channel of both images.
  rows = slice(num_tiles_crop, height - num_tiles_crop)
  cols = slice(num_tiles_crop, width - num_tiles_crop)

  means = np.asarray([mean_img[rows, cols, ch] for ch in channel_indices])
  vars_ = np.asarray([var_img[rows, cols, ch] for ch in channel_indices])
  return means, vars_
236
237
def filter_stats(
    means: np.ndarray,
    vars_: np.ndarray,
    black_levels: List[float],
    white_level: float,
    max_signal_value: float = 0.25,
    is_remove_var_outliers: bool = False,
    deviations: int = _OUTLIER_MEDIAN_ABS_DEVS_DEFAULT,
) -> Tuple[np.ndarray, np.ndarray]:
  """Filters out invalid, outlying and potentially clipped stats samples.

  For each plane, a sample is kept only if its mean lies within
  [black_level, white_level] and its variance is non-negative. Optionally,
  variances that are too far from the median variance are dropped as well.
  The remaining samples are normalized with (white_level - black_level), and
  samples whose normalized mean plus two standard deviations is not below
  `max_signal_value` are discarded because they might be clipped.

  Args:
      means: A numpy ndarray of pixel mean values, indexed by plane first.
      vars_: A numpy ndarray of pixel variance values, same shape as `means`.
      black_levels: A list of black levels, one per plane.
      white_level: A scalar white level.
      max_signal_value: The maximum (normalized) signal value allowed for
        mean + 2 * sqrt(variance).
      is_remove_var_outliers: A boolean value indicating whether to remove
        variance outliers.
      deviations: The number of median absolute deviations around the median
        variance within which a variance is kept when removing outliers.

  Returns:
      A tuple of (means_filtered, vars_filtered), object arrays holding the
      filtered and normalized mean and variance values of each plane. The
      number of samples may differ between planes.

  Raises:
      AssertionError: If `means` and `vars_` have different shapes.
  """
  if means.shape != vars_.shape:
    raise AssertionError(
        f'Unmatched shapes of means and vars: means.shape={means.shape},'
        f' vars.shape={vars_.shape}.'
    )
  means_filtered = []
  vars_filtered = []

  for pidx, (means_i, vars_i) in enumerate(zip(means, vars_)):
    black_level = black_levels[pidx]
    signal_range = white_level - black_level

    # Basic constraints:
    # (1) means are within the range [black_level, white_level],
    # (2) vars are non-negative values.
    constraints = [
        means_i >= black_level,
        means_i <= white_level,
        vars_i >= 0,
    ]
    if is_remove_var_outliers:
      # Filter out variances that differ too much from the median of variances.
      mad = scipy.stats.median_abs_deviation(vars_i, axis=None, scale=1)
      med = np.median(vars_i)
      constraints.extend([
          vars_i > med - deviations * mad,
          vars_i < med + deviations * mad,
      ])

    # Use the boolean mask itself to detect an empty selection. Testing the
    # indices returned by np.where would wrongly report "empty" whenever the
    # only surviving sample sits at index 0.
    keep_mask = np.logical_and.reduce(constraints)
    if not np.any(keep_mask):
      logging.info('After filter channel %d, stats array is empty.', pidx)

    # Normalizes the range to [0, 1].
    means_i = (means_i[keep_mask] - black_level) / signal_range
    vars_i = vars_i[keep_mask] / (signal_range**2)

    # Filter out the tiles if they have samples that might be clipped.
    mean_var_pairs = [
        (mean, var)
        for mean, var in zip(means_i, vars_i)
        if mean + 2 * math.sqrt(var) < max_signal_value
    ]
    if mean_var_pairs:
      means_i, vars_i = zip(*mean_var_pairs)
    else:
      means_i, vars_i = [], []
    means_filtered.append(np.asarray(means_i))
    vars_filtered.append(np.asarray(vars_i))

  # After filtering, means_filtered and vars_filtered may have different shapes
  # in each color planes.
  means_filtered = np.asarray(means_filtered, dtype=object)
  vars_filtered = np.asarray(vars_filtered, dtype=object)
  return means_filtered, vars_filtered
326
327
def get_next_iso(
    iso: float,
    max_iso: int,
    iso_multiplier: float,
) -> float:
  """Computes the sensitivity that follows `iso`.

  The next sensitivity is normally `iso * iso_multiplier`. When that step
  would jump over `max_iso` while the current (rounded) sensitivity is still
  below it, `max_iso` is returned instead so that the maximum is not skipped.

  Args:
    iso: The current ISO sensitivity.
    max_iso: The maximum ISO sensitivity.
    iso_multiplier: The ISO multiplier to use.

  Returns:
    The next ISO sensitivity.

  Raises:
    AssertionError: If `iso_multiplier` is not greater than 1.
  """
  if iso_multiplier <= 1:
    raise AssertionError(
        f'ISO multiplier is {iso_multiplier}, which should be greater than 1.'
    )

  next_iso = iso * iso_multiplier
  steps_over_max = round(iso) < max_iso < round(next_iso)
  return max_iso if steps_over_max else next_iso
352
353
def capture_stats_images(
    cam,
    props,
    stats_config: Dict[str, Any],
    sens_min: int,
    sens_max_meas: int,
    zoom_ratio: float,
    num_tiles_crop: int,
    max_signal_value: float,
    iso_multiplier: float,
    max_bracket: int,
    bracket_factor: int,
    capture_path_prefix: str,
    stats_file_name: str = '',
    is_remove_var_outliers: bool = False,
    outlier_median_abs_deviations: int = _OUTLIER_MEDIAN_ABS_DEVS_DEFAULT,
    is_debug_mode: bool = False,
) -> Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]]:
  """Capture stats images and saves the stats in a dictionary.

  This function captures stats images at different ISO values and exposure
  times. For every captured ISO it stores the exposure time together with the
  filtered means and variances of each plane. When `stats_file_name` is given,
  previously saved stats are loaded from that file (capturing resumes after
  the largest saved ISO within range) and the stats are written back to the
  file after each ISO is finished.

  Args:
    cam: The camera session (its_session_utils.ItsSession) for capturing stats
      images.
    props: Camera property object.
    stats_config: The stats format config, a dictionary that specifies the raw
      stats image format and tile size.
    sens_min: The minimum sensitivity.
    sens_max_meas: The maximum sensitivity to measure.
    zoom_ratio: The zoom ratio to use.
    num_tiles_crop: The number of tiles to crop from each side of the stats
      images.
    max_signal_value: The maximum signal value to allow.
    iso_multiplier: The ISO multiplier to use.
    max_bracket: The maximum number of bracketed exposures to capture.
    bracket_factor: The bracket factor with default value 2^max_bracket.
    capture_path_prefix: The path prefix to use for captured images.
    stats_file_name: The name of the file to save the stats images to.
    is_remove_var_outliers: Whether to remove variance outliers.
    outlier_median_abs_deviations: The number of median absolute deviations to
      use for detecting outliers.
    is_debug_mode: Whether to enable debug mode.

  Returns:
    A dictionary mapping each captured ISO value to a list of
    (exposure_ms, means, vars_) tuples, one per bracketed exposure.

  Raises:
    AssertionError: If the scene is too bright or too dark for the requested
      sensitivity range, or if `iso_multiplier` is not greater than 1.
  """
  if is_debug_mode:
    logging.info('Capturing stats images with stats config: %s.', stats_config)
    capture_folder = os.path.join(capture_path_prefix, 'captures')
    os.makedirs(capture_folder, exist_ok=True)
    logging.info('Capture folder: %s', capture_folder)

  white_level = props['android.sensor.info.whiteLevel']
  min_exposure_ns, max_exposure_ns = props[
      'android.sensor.info.exposureTimeRange'
  ]
  # Focus at zero to intentionally blur the scene as much as possible.
  f_dist = 0.0
  # Whether the stats images are quad Bayer or standard Bayer.
  is_quad_bayer = 'QuadBayer' in stats_config['format']
  if is_quad_bayer:
    num_channels = noise_model_constants.NUM_QUAD_BAYER_CHANNELS
  else:
    num_channels = noise_model_constants.NUM_BAYER_CHANNELS
  # A dict maps iso to stats images of different exposure times.
  iso_to_stats_dict = collections.defaultdict(list)
  # Start the sensitivity at the minimum.
  iso = sens_min
  # Previous iso cap.
  pre_iso_cap = None
  if stats_file_name:
    stats_file_path = os.path.join(capture_path_prefix, stats_file_name)
    if os.path.isfile(stats_file_path):
      try:
        # NOTE: pickle.load can execute arbitrary code; only stats files
        # written by this function should be placed at this path.
        with open(stats_file_path, 'rb') as f:
          saved_iso_to_stats_dict = pickle.load(f)
      except (OSError, EOFError, pickle.UnpicklingError) as e:
        # Loading is best effort: a missing, truncated or corrupt stats file
        # only means that capturing starts from scratch.
        logging.exception(
            'Failed to load stats file stored at %s. Error message: %s',
            stats_file_path,
            e,
        )
      else:
        # Filter saved stats data. A dedicated loop variable is used so that
        # the starting `iso` is not overwritten when no saved ISO is in range.
        if saved_iso_to_stats_dict:
          for saved_iso, stats in saved_iso_to_stats_dict.items():
            if sens_min <= saved_iso <= sens_max_meas:
              iso_to_stats_dict[saved_iso] = stats

        # Set the starting iso to the last iso in saved stats file.
        if iso_to_stats_dict:
          pre_iso_cap = max(iso_to_stats_dict)
          iso = get_next_iso(pre_iso_cap, sens_max_meas, iso_multiplier)

  if round(iso) <= sens_max_meas:
    # Wait until camera is repositioned for noise model calibration.
    input(
        f'\nPress <ENTER> after covering camera lense {cam.get_camera_name()} '
        'with frosted glass diffuser, and facing lense at evenly illuminated'
        ' surface.\n'
    )
    # Do AE to get a rough idea of where we are.
    iso_ae, exp_ae, _, _, _ = cam.do_3a(
        get_results=True, do_awb=False, do_af=False
    )

    # Underexpose to get more data for low signal levels.
    auto_exposure_ns = iso_ae * exp_ae / bracket_factor
    _check_auto_exposure_targets(
        auto_exposure_ns,
        sens_min,
        sens_max_meas,
        bracket_factor,
        min_exposure_ns,
        max_exposure_ns,
    )

  # `auto_exposure_ns` is always bound when this loop runs, because the loop
  # condition is the same as the condition of the block above.
  while round(iso) <= sens_max_meas:
    req = capture_request_utils.manual_capture_request(
        round(iso), min_exposure_ns, f_dist
    )
    cap = cam.do_capture(req, stats_config)
    # Instead of raising an error when the sensitivity readback != requested
    # use the readback value for calculations instead.
    iso_cap = cap['metadata']['android.sensor.sensitivity']

    # Different iso values may result in captures with the same iso_cap
    # value, so skip this capture if it's redundant.
    if iso_cap == pre_iso_cap:
      logging.info(
          'Skip current capture because of the same iso %d with the previous'
          ' capture.',
          iso_cap,
      )
      iso = get_next_iso(iso, sens_max_meas, iso_multiplier)
      continue
    pre_iso_cap = iso_cap

    logging.info('Request ISO: %d, Capture ISO: %d.', iso, iso_cap)

    for bracket in range(max_bracket):
      # Get the exposure for this sensitivity and exposure time.
      # NOTE(review): the exposure is derived from the requested `iso`, while
      # the capture request below uses the readback `iso_cap` - confirm that
      # this is intended.
      exposure_ns = round(math.pow(2, bracket) * auto_exposure_ns / iso)
      exposure_ms = round(exposure_ns * 1.0e-6, 3)
      logging.info('ISO: %d, exposure time: %.3f ms.', iso_cap, exposure_ms)
      req = capture_request_utils.manual_capture_request(
          iso_cap,
          exposure_ns,
          f_dist,
      )
      req['android.control.zoomRatio'] = zoom_ratio
      cap = cam.do_capture(req, stats_config)

      if is_debug_mode:
        capture_path = os.path.join(
            capture_folder, f'iso{iso_cap}_exposure{exposure_ns}ns.jpg'
        )
        crop_and_save_capture(cap, props, capture_path, num_tiles_crop)

      mean_img, var_img = image_processing_utils.unpack_rawstats_capture(
          cap, num_channels=num_channels
      )
      cfa_order = image_processing_utils.get_canonical_cfa_order(
          props, is_quad_bayer
      )

      means, vars_ = crop_and_reorder_stats_images(
          mean_img,
          var_img,
          num_tiles_crop,
          cfa_order,
      )
      if is_debug_mode:
        logging.info('Raw stats image size: %s', mean_img.shape)
        logging.info('R plane means image size: %s', means[0].shape)
        logging.info(
            'means min: %.3f, median: %.3f, max: %.3f',
            np.min(means), np.median(means), np.max(means),
        )
        logging.info(
            'vars_ min: %.4f, median: %.4f, max: %.4f',
            np.min(vars_), np.median(vars_), np.max(vars_),
        )

      black_levels = image_processing_utils.get_black_levels(
          props,
          cap['metadata'],
          is_quad_bayer,
      )

      means, vars_ = filter_stats(
          means,
          vars_,
          black_levels,
          white_level,
          max_signal_value,
          is_remove_var_outliers,
          outlier_median_abs_deviations,
      )

      iso_to_stats_dict[iso_cap].append((exposure_ms, means, vars_))

    # Persist after every ISO so that an interrupted run can be resumed.
    if stats_file_name:
      with open(stats_file_path, 'wb+') as f:
        pickle.dump(iso_to_stats_dict, f)
    iso = get_next_iso(iso, sens_max_meas, iso_multiplier)

  return iso_to_stats_dict
565
566
def measure_linear_noise_models(
    iso_to_stats_dict: Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]],
    color_planes: List[str],
):
  """Measures linear noise models.

  For each ISO setting and color plane, this function gathers the (mean,
  variance) samples of all exposures and fits the line
  variance = slope * mean + intercept by linear regression.

  Args:
      iso_to_stats_dict: A dictionary mapping ISO settings to a list of
        (exposure, means, vars_) tuples, where means and vars_ are indexed by
        color plane.
      color_planes: A list of color planes.

  Returns:
      A tuple containing:
          measured_models: For each color plane, a list of
              (iso, slope, intercept) tuples ordered by ascending ISO.
          samples: A list of samples, one for each color plane. Each sample is a
              tuple of (iso, mean, var).

  Raises:
      ValueError: If a color plane has no samples for some ISO.
  """
  num_planes = len(color_planes)
  # Model parameters for each color plane.
  measured_models = [[] for _ in range(num_planes)]
  # Samples (ISO, mean and var) of each color plane.
  samples = [[] for _ in range(num_planes)]

  for iso in sorted(iso_to_stats_dict):
    logging.info('Calculating measured models for ISO %d.', iso)
    stats_per_plane = [[] for _ in range(num_planes)]
    for _, means, vars_ in iso_to_stats_dict[iso]:
      for pidx in range(num_planes):
        means_p = means[pidx]
        vars_p = vars_[pidx]
        if means_p.size > 0 and vars_p.size > 0:
          stats_per_plane[pidx].extend(zip(means_p, vars_p))

    for pidx, mean_var_pairs in enumerate(stats_per_plane):
      if not mean_var_pairs:
        raise ValueError(
            f'For ISO {iso}, samples are empty in color plane'
            f' {color_planes[pidx]}.'
        )
      # Pass x and y explicitly. Letting linregress split a single array of
      # pairs is ambiguous for exactly two samples (a 2x2 array is split along
      # the wrong axis) and that calling form is deprecated in newer SciPy.
      pairs = np.asarray(mean_var_pairs)
      slope, intercept, rvalue, _, _ = scipy.stats.linregress(
          pairs[:, 0], pairs[:, 1]
      )

      measured_models[pidx].append((iso, slope, intercept))
      logging.info(
          (
              'Measured model for ISO %d and color plane %s: '
              'y = %e * x + %e (R=%.6f).'
          ),
          iso, color_planes[pidx], slope, intercept, rvalue,
      )

      # Add the samples for this sensitivity to the global samples list.
      samples[pidx].extend([(iso, mean, var) for (mean, var) in mean_var_pairs])

  return measured_models, samples
624
625
def compute_noise_model(
    samples: List[List[Tuple[float, np.ndarray, np.ndarray]]],
    sens_max_analog: int,
    offset_a: np.ndarray,
    offset_b: np.ndarray,
    is_two_stage_model: bool = False,
) -> np.ndarray:
  """Computes noise model parameters from samples.

  The noise model is defined by the following equation:
    f(x) = scale * x + offset

  where we have:
    scale = scale_a * analog_gain * digital_gain + scale_b,
    offset = (offset_a * analog_gain^2 + offset_b) * digital_gain^2.
    scale is the multiplicative factor and offset is the offset term.

  Assume digital_gain is 1.0 and scale_a, scale_b, offset_a, offset_b are
  sa, sb, oa, ob respectively, so we have noise model function:
  f(x) = (sa * analog_gain + sb) * x + (oa * analog_gain^2 + ob).

  The noise model is fit to the measured data with
  scipy.optimize.curve_fit. Both the model and the measured variances are
  divided by the gains before fitting.

  Args:
    samples: A list with one entry per channel, each of which is a list of
      tuples of `(gain, mean, var)`.
    sens_max_analog: The maximum analog gain.
    offset_a: The gradient coefficients from the read noise calibration.
      Only read when `is_two_stage_model` is True.
    offset_b: The intercept coefficients from the read noise calibration.
      Only read when `is_two_stage_model` is True.
    is_two_stage_model: A boolean flag indicating if the noise model is
      calibrated in the two-stage mode. In that mode only scale_a and scale_b
      are fitted; offset_a and offset_b are taken from the arguments.

  Returns:
    A numpy array containing noise model parameters (scale_a, scale_b,
    offset_a, offset_b) of each channel.

  Raises:
    AssertionError: If any gain exceeds `sens_max_analog` (digital gain is not
      1), or if the resulting noise model has an invalid shape.
  """
  noise_model = []
  for pidx, samples_p in enumerate(samples):
    gains, means, vars_ = zip(*samples_p)
    gains = np.asarray(gains).flatten()
    means = np.asarray(means).flatten()
    vars_ = np.asarray(vars_).flatten()

    # Called for its validation only: raises if any digital gain is not 1.
    compute_digital_gains(gains, sens_max_analog)

    # Use a global optimization over all gains to fit the noise model.
    # Noise model function:
    # f(x) = scale * x + offset
    # Where:
    # scale = scale_a * analog_gain * digital_gain + scale_b.
    # offset = (offset_a * analog_gain^2 + offset_b) * digital_gain^2.
    # Function f will be used to train the scale and offset coefficients
    # scale_a, scale_b, offset_a, offset_b. In f, x[0] holds the gains and
    # x[1] holds the means.
    if is_two_stage_model:
      # For the two-stage model, we want to use the line fit coefficients
      # found from capturing read noise data (offset_a and offset_b) to
      # train the scale coefficients.
      oa, ob = offset_a[pidx], offset_b[pidx]

      # Cannot pass oa and ob as the parameters of f since we only want
      # curve_fit return 2 parameters.
      def f(x, sa, sb):
        scale = sa * x[0] + sb
        # pylint: disable=cell-var-from-loop
        offset = oa * x[0] ** 2 + ob
        return (scale * x[1] + offset) / x[0]

    else:
      def f(x, sa, sb, oa, ob):
        scale = sa * x[0] + sb
        offset = oa * x[0] ** 2 + ob
        return (scale * x[1] + offset) / x[0]

    # Divide the whole system (model output and measured variances) by gains.
    coeffs, _ = scipy.optimize.curve_fit(f, (gains, means), vars_ / gains)

    # If using two-stage model, two of the coefficients calculated above are
    # constant, so we need to append them to the coeffs ndarray.
    if is_two_stage_model:
      coeffs = np.append(coeffs, offset_a[pidx])
      coeffs = np.append(coeffs, offset_b[pidx])

    # coeffs[0:4] = (scale_a, scale_b, offset_a, offset_b).
    noise_model.append(coeffs[0:4])

  noise_model = np.asarray(noise_model)
  check_noise_model_shape(noise_model)
  return noise_model
716
717
def create_stats_figure(
    iso: int,
    color_channel_names: List[str],
):
  """Creates a figure with one titled subplot per color channel.

  The figure uses a 4x4 grid when the number of channel names equals
  noise_model_constants.NUM_QUAD_BAYER_CHANNELS and a 2x2 grid otherwise. In
  both layouts the figure is titled with the ISO value and gets a horizontal
  color bar built from noise_model_constants.COLOR_BAR. This function only
  prepares the axes; it does not plot any samples itself.

  Args:
    iso: The ISO setting for the images.
    color_channel_names: A list of strings containing the names of the color
      channels.

  Returns:
    A tuple of the figure and a list of the subplots, ordered like
    `color_channel_names`.

  Raises:
    AssertionError: If the number of channel names is not in
      noise_model_constants.VALID_NUM_CHANNELS.
  """
  if len(color_channel_names) not in noise_model_constants.VALID_NUM_CHANNELS:
    raise AssertionError(
        'The number of channels should be in'
        f' {noise_model_constants.VALID_NUM_CHANNELS}, but found'
        f' {len(color_channel_names)}. '
    )

  is_quad_bayer = (
      len(color_channel_names) == noise_model_constants.NUM_QUAD_BAYER_CHANNELS
  )
  if is_quad_bayer:
    # Adds a plot of the mean and variance samples for each color plane.
    fig, axes = plt.subplots(4, 4, figsize=(22, 22))
    # NOTE(review): the return value is unused - confirm whether this call is
    # still needed.
    fig.gca()
    fig.suptitle('ISO %d' % iso, x=0.52, y=0.99)

    # Small dedicated axes that hosts the color bar.
    cax = fig.add_axes([0.65, 0.995, 0.33, 0.003])
    cax.set_title('log(exposure_ms):', x=-0.13, y=-2.0)
    fig.colorbar(
        noise_model_constants.COLOR_BAR, cax=cax, orientation='horizontal'
    )

    # Add a big axis, hide frame.
    fig.add_subplot(111, frameon=False)

    # Add a common x-axis and y-axis.
    # The pyplot calls below act on pyplot's current axes, presumably the
    # frameless axes added just above; their order matters.
    plt.tick_params(
        labelcolor='none',
        which='both',
        top=False,
        bottom=False,
        left=False,
        right=False,
    )
    plt.xlabel('Mean signal level', ha='center')
    plt.ylabel('Variance', va='center', rotation='vertical')

    subplots = []
    for pidx in range(noise_model_constants.NUM_QUAD_BAYER_CHANNELS):
      # Channels fill the 4x4 grid row by row.
      subplot = axes[pidx // 4, pidx % 4]
      subplot.set_title(color_channel_names[pidx])
      # Set 'y' axis to scientific notation for all numbers by setting
      # scilimits to (0, 0).
      subplot.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
      subplots.append(subplot)

  else:
    # Adds a plot of the mean and variance samples for each color plane.
    fig, [[plt_r, plt_gr], [plt_gb, plt_b]] = plt.subplots(
        2, 2, figsize=(11, 11)
    )
    # NOTE(review): the return value is unused - confirm whether this call is
    # still needed.
    fig.gca()
    # Add color bar to show exposure times.
    cax = fig.add_axes([0.73, 0.99, 0.25, 0.01])
    cax.set_title('log(exposure_ms):', x=-0.3, y=-1.0)
    fig.colorbar(
        noise_model_constants.COLOR_BAR, cax=cax, orientation='horizontal'
    )

    subplots = [plt_r, plt_gr, plt_gb, plt_b]
    fig.suptitle('ISO %d' % iso, x=0.54, y=0.99)
    # Unlike the quad Bayer layout, every subplot carries its own axis labels.
    for pidx, subplot in enumerate(subplots):
      subplot.set_title(color_channel_names[pidx])
      subplot.set_xlabel('Mean signal level')
      subplot.set_ylabel('Variance')
      subplot.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

  # UserWarnings raised while tightening the layout are suppressed.
  with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    pylab.tight_layout()

  return fig, subplots
804