# Copyright 2014 The Android Open Source Project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Noise model utility functions."""

import collections
import logging
import math
import os.path
import pickle
from typing import Any, Dict, List, Tuple
import warnings
import capture_request_utils
import image_processing_utils
from matplotlib import pylab
import matplotlib.pyplot as plt
import noise_model_constants
import numpy as np
import scipy.optimize
import scipy.stats


_OUTLIER_MEDIAN_ABS_DEVS_DEFAULT = (
    noise_model_constants.OUTLIER_MEDIAN_ABS_DEVS_DEFAULT
)


def _check_auto_exposure_targets(
    auto_exposure_ns: float,
    sens_min: int,
    sens_max: int,
    bracket_factor: int,
    min_exposure_ns: int,
    max_exposure_ns: int,
) -> None:
  """Checks if AE is too bright for highest gain or too dark for lowest gain.

  Args:
    auto_exposure_ns: The auto exposure value in nanoseconds.
    sens_min: The minimum sensitivity value.
    sens_max: The maximum sensitivity value.
    bracket_factor: Exposure bracket factor.
    min_exposure_ns: The minimum exposure time in nanoseconds.
    max_exposure_ns: The maximum exposure time in nanoseconds.
  """

  if auto_exposure_ns < min_exposure_ns * sens_max:
    raise AssertionError(
        'Scene is too bright to properly expose at highest '
        f'sensitivity: {sens_max}'
    )
  if auto_exposure_ns * bracket_factor > max_exposure_ns * sens_min:
    raise AssertionError(
        'Scene is too dark to properly expose at lowest '
        f'sensitivity: {sens_min}'
    )


def check_noise_model_shape(noise_model: np.ndarray) -> None:
  """Checks if the shape of the noise model is valid.

  Args:
    noise_model: A numpy array of shape (num_channels, num_parameters).
  """
  num_channels, num_parameters = noise_model.shape
  if num_channels not in noise_model_constants.VALID_NUM_CHANNELS:
    raise AssertionError(
        f'The number of channels {num_channels} is not in'
        f' {noise_model_constants.VALID_NUM_CHANNELS}.'
    )
  if num_parameters != 4:
    raise AssertionError(
        f'The number of parameters of each channel {num_parameters} != 4.'
    )


def validate_noise_model(
    noise_model: np.ndarray,
    color_channels: List[str],
    sens_min: int,
) -> None:
  """Performs validation checks on the noise model.

  This function checks that the read noise and the intercept gradient are
  positive for each color channel.

  Args:
    noise_model: Noise model parameters of each channel, including scale_a,
      scale_b, offset_a, offset_b.
    color_channels: Array of color channels.
    sens_min: Minimum sensitivity value.
  """
  check_noise_model_shape(noise_model)
  num_channels = noise_model.shape[0]
  if len(color_channels) != num_channels:
    raise AssertionError(
        f'Number of color channels {len(color_channels)} != number of noise '
        f'model channels {num_channels}.'
    )

  scale_a, _, offset_a, offset_b = zip(*noise_model)
  for i, color_channel in enumerate(color_channels):
    if scale_a[i] < 0:
      raise AssertionError(
          f'{color_channel} model API scale gradient < 0: {scale_a[i]:.4e}'
      )

    if offset_a[i] <= 0:
      raise AssertionError(
          f'{color_channel} model API intercept gradient <= 0: '
          f'{offset_a[i]:.4e}'
      )

    read_noise = offset_a[i] * sens_min * sens_min + offset_b[i]
    if read_noise <= 0:
      raise AssertionError(
          f'{color_channel} model min ISO noise <= 0! '
          f'API intercept gradient: {offset_a[i]:.4e}, '
          f'API intercept offset: {offset_b[i]:.4e}, '
          f'read_noise: {read_noise:.4e}'
      )
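

# Illustrative check of the read-noise validation above (made-up values):
# with offset_a = 3.2e-7, offset_b = 1.5e-6 and sens_min = 100,
# read_noise = 3.2e-7 * 100**2 + 1.5e-6 = 3.2015e-3 > 0, so the channel
# passes validation.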


def compute_digital_gains(
    gains: np.ndarray,
    sens_max_analog: np.ndarray,
) -> np.ndarray:
  """Computes the digital gains for the given gains and maximum analog gain.

  The digital gain is defined as the gain divided by the maximum analog gain
  sensitivity, clamped to a minimum of 1. This function ensures that the
  digital gains are always equal to 1. If any of the digital gains is not
  equal to 1, an AssertionError is raised.

  Args:
    gains: An array of gains.
    sens_max_analog: The maximum analog gain sensitivity.

  Returns:
    A numpy array of digital gains.
  """
  digital_gains = np.maximum(gains / sens_max_analog, 1)
  if not np.all(digital_gains == 1):
    raise AssertionError(
        f'Digital gains are not all 1! gains: {gains}, '
        f'Max analog gain sensitivity: {sens_max_analog}.'
    )
  return digital_gains
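

# Illustrative example (made-up values): with sens_max_analog = 800, gains of
# [100, 400, 800] give np.maximum([0.125, 0.5, 1.0], 1) == [1, 1, 1], so
# compute_digital_gains passes; a gain of 1600 would produce a digital gain of
# 2 and trigger the AssertionError.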


def crop_and_save_capture(
    cap,
    props,
    capture_path: str,
    num_tiles_crop: int,
) -> None:
  """Crops and saves a capture image.

  Args:
    cap: The capture to be cropped and saved.
    props: The properties to be used to convert the capture to an RGB image.
    capture_path: The path to which the capture image should be saved.
    num_tiles_crop: The number of tiles to crop.
  """
  img = image_processing_utils.convert_capture_to_rgb_image(cap, props=props)
  height, width, _ = img.shape
  num_tiles_crop_max = min(height, width) // 2
  if num_tiles_crop >= num_tiles_crop_max:
    raise AssertionError(
        f'Number of tiles to crop {num_tiles_crop} >= {num_tiles_crop_max}.'
    )
  img = img[
      num_tiles_crop: height - num_tiles_crop,
      num_tiles_crop: width - num_tiles_crop,
      :,
  ]

  image_processing_utils.write_image(img, capture_path, True)


def crop_and_reorder_stats_images(
    mean_img: np.ndarray,
    var_img: np.ndarray,
    num_tiles_crop: int,
    channel_indices: List[int],
) -> Tuple[np.ndarray, np.ndarray]:
  """Crops the stats images and sorts stats image channels in canonical order.

  Args:
    mean_img: The mean image.
    var_img: The variance image.
    num_tiles_crop: The number of tiles to crop from each side of the image.
    channel_indices: The channel indices to sort stats image channels in
      canonical order.

  Returns:
    The cropped and reordered mean image and variance image.
  """
  if mean_img.shape != var_img.shape:
    raise AssertionError(
        'Unmatched shapes of mean and variance image: '
        f'shape of mean image is {mean_img.shape}, '
        f'shape of variance image is {var_img.shape}.'
    )
  height, width, _ = mean_img.shape
  if 2 * num_tiles_crop > min(height, width):
    raise AssertionError(
        f'The number of tiles to crop ({num_tiles_crop}) is so large that'
        ' images cannot be cropped.'
    )

  means = []
  vars_ = []
  for i in channel_indices:
    means_i = mean_img[
        num_tiles_crop: height - num_tiles_crop,
        num_tiles_crop: width - num_tiles_crop,
        i,
    ]
    vars_i = var_img[
        num_tiles_crop: height - num_tiles_crop,
        num_tiles_crop: width - num_tiles_crop,
        i,
    ]
    means.append(means_i)
    vars_.append(vars_i)
  means, vars_ = np.asarray(means), np.asarray(vars_)
  return means, vars_


def filter_stats(
    means: np.ndarray,
    vars_: np.ndarray,
    black_levels: List[float],
    white_level: float,
    max_signal_value: float = 0.25,
    is_remove_var_outliers: bool = False,
    deviations: int = _OUTLIER_MEDIAN_ABS_DEVS_DEFAULT,
) -> Tuple[np.ndarray, np.ndarray]:
  """Filters out mean outliers and variance outliers.

  Args:
    means: A numpy ndarray of pixel mean values.
    vars_: A numpy ndarray of pixel variance values.
    black_levels: A list of black levels for each pixel.
    white_level: A scalar white level.
    max_signal_value: The maximum signal (mean) value.
    is_remove_var_outliers: A boolean value indicating whether to remove
      variance outliers.
    deviations: A scalar value specifying the number of median absolute
      deviations to use when removing variance outliers.

  Returns:
    A tuple of (means_filtered, vars_filtered) where means_filtered and
    vars_filtered are numpy ndarrays of filtered pixel mean and variance
    values, respectively.
  """
  if means.shape != vars_.shape:
    raise AssertionError(
        f'Unmatched shapes of means and vars: means.shape={means.shape},'
        f' vars.shape={vars_.shape}.'
    )
  num_planes = len(means)
  means_filtered = []
  vars_filtered = []

  for pidx in range(num_planes):
    black_level = black_levels[pidx]
    means_i = means[pidx]
    vars_i = vars_[pidx]

    # Basic constraints:
    # (1) means are within the range [black_level, white_level],
    # (2) vars are non-negative values.
    constraints = [
        means_i >= black_level,
        means_i <= white_level,
        vars_i >= 0,
    ]
    if is_remove_var_outliers:
      # Filter out variances that differ too much from the median of variances.
      std_dev = scipy.stats.median_abs_deviation(vars_i, axis=None, scale=1)
      med = np.median(vars_i)
      constraints.extend([
          vars_i > med - deviations * std_dev,
          vars_i < med + deviations * std_dev,
      ])

    keep_indices = np.where(np.logical_and.reduce(constraints))
    if not np.any(keep_indices):
      logging.info('After filtering channel %d, stats array is empty.', pidx)

    # Normalizes the range to [0, 1].
    means_i = (means_i[keep_indices] - black_level) / (
        white_level - black_level
    )
    vars_i = vars_i[keep_indices] / ((white_level - black_level) ** 2)
    # Filter out tiles that contain samples that might be clipped.
    mean_var_pairs = list(
        filter(
            lambda x: x[0] + 2 * math.sqrt(x[1]) < max_signal_value,
            zip(means_i, vars_i),
        )
    )
    if mean_var_pairs:
      means_i, vars_i = zip(*mean_var_pairs)
    else:
      means_i, vars_i = [], []
    means_i = np.asarray(means_i)
    vars_i = np.asarray(vars_i)
    means_filtered.append(means_i)
    vars_filtered.append(vars_i)

  # After filtering, means_filtered and vars_filtered may have different
  # shapes in each color plane, so they are stored as object arrays.
  means_filtered = np.asarray(means_filtered, dtype=object)
  vars_filtered = np.asarray(vars_filtered, dtype=object)
  return means_filtered, vars_filtered
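

# Illustrative example of the clipping filter in filter_stats (made-up
# numbers): with max_signal_value = 0.25, a tile with normalized mean 0.20 and
# variance 9e-4 gives 0.20 + 2 * sqrt(9e-4) = 0.26 >= 0.25 and is dropped,
# while a tile with mean 0.18 and the same variance gives 0.24 and is kept.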


def get_next_iso(
    iso: float,
    max_iso: int,
    iso_multiplier: float,
) -> float:
  """Moves to the next sensitivity.

  Args:
    iso: The current ISO sensitivity.
    max_iso: The maximum ISO sensitivity.
    iso_multiplier: The ISO multiplier to use.

  Returns:
    The next ISO sensitivity.
  """
  if iso_multiplier <= 1:
    raise AssertionError(
        f'ISO multiplier is {iso_multiplier}, which should be greater than 1.'
    )

  if round(iso) < max_iso < round(iso * iso_multiplier):
    return max_iso
  else:
    return iso * iso_multiplier
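

# Illustrative ISO progression (made-up values): starting at ISO 100 with
# iso_multiplier = 2 and max_iso = 1600, successive calls to get_next_iso
# yield 200, 400, 800, 1600; the following call returns 3200, which exceeds
# max_iso and ends the capture loop in capture_stats_images below.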


def capture_stats_images(
    cam,
    props,
    stats_config: Dict[str, Any],
    sens_min: int,
    sens_max_meas: int,
    zoom_ratio: float,
    num_tiles_crop: int,
    max_signal_value: float,
    iso_multiplier: float,
    max_bracket: int,
    bracket_factor: int,
    capture_path_prefix: str,
    stats_file_name: str = '',
    is_remove_var_outliers: bool = False,
    outlier_median_abs_deviations: int = _OUTLIER_MEDIAN_ABS_DEVS_DEFAULT,
    is_debug_mode: bool = False,
) -> Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]]:
  """Captures stats images and saves the stats in a dictionary.

  This function captures stats images at different ISO values and exposure
  times, and stores the stats data in a file with the specified name.
  The stats data includes the mean and variance of each plane, as well as
  exposure times.

  Args:
    cam: The camera session (its_session_utils.ItsSession) for capturing stats
      images.
    props: Camera property object.
    stats_config: The stats format config, a dictionary that specifies the raw
      stats image format and tile size.
    sens_min: The minimum sensitivity.
    sens_max_meas: The maximum sensitivity to measure.
    zoom_ratio: The zoom ratio to use.
    num_tiles_crop: The number of tiles to crop the images into.
    max_signal_value: The maximum signal value to allow.
    iso_multiplier: The ISO multiplier to use.
    max_bracket: The maximum number of bracketed exposures to capture.
    bracket_factor: The bracket factor with default value 2^max_bracket.
    capture_path_prefix: The path prefix to use for captured images.
    stats_file_name: The name of the file to save the stats images to.
    is_remove_var_outliers: Whether to remove variance outliers.
    outlier_median_abs_deviations: The number of median absolute deviations to
      use for detecting outliers.
    is_debug_mode: Whether to enable debug mode.

  Returns:
    A dictionary mapping each capture ISO value to a list of
    (exposure_ms, means, vars_) tuples, one per bracketed exposure.
  """
  if is_debug_mode:
    logging.info('Capturing stats images with stats config: %s.', stats_config)
    capture_folder = os.path.join(capture_path_prefix, 'captures')
    if not os.path.exists(capture_folder):
      os.makedirs(capture_folder)
    logging.info('Capture folder: %s', capture_folder)

  white_level = props['android.sensor.info.whiteLevel']
  min_exposure_ns, max_exposure_ns = props[
      'android.sensor.info.exposureTimeRange'
  ]
  # Focus at zero to intentionally blur the scene as much as possible.
  f_dist = 0.0
  # Whether the stats images are quad Bayer or standard Bayer.
  is_quad_bayer = 'QuadBayer' in stats_config['format']
  if is_quad_bayer:
    num_channels = noise_model_constants.NUM_QUAD_BAYER_CHANNELS
  else:
    num_channels = noise_model_constants.NUM_BAYER_CHANNELS
  # A dict that maps ISO to stats images of different exposure times.
  iso_to_stats_dict = collections.defaultdict(list)
  # Start the sensitivity at the minimum.
  iso = sens_min
  # Previous ISO cap.
  pre_iso_cap = None
  if stats_file_name:
    stats_file_path = os.path.join(capture_path_prefix, stats_file_name)
    if os.path.isfile(stats_file_path):
      try:
        with open(stats_file_path, 'rb') as f:
          saved_iso_to_stats_dict = pickle.load(f)
        # Filter saved stats data.
        if saved_iso_to_stats_dict:
          for iso, stats in saved_iso_to_stats_dict.items():
            if sens_min <= iso <= sens_max_meas:
              iso_to_stats_dict[iso] = stats

        # Set the starting ISO to the last ISO in the saved stats file.
        if iso_to_stats_dict.keys():
          pre_iso_cap = sorted(iso_to_stats_dict.keys())[-1]
          iso = get_next_iso(pre_iso_cap, sens_max_meas, iso_multiplier)
      except OSError as e:
        logging.exception(
            'Failed to load stats file stored at %s. Error message: %s',
            stats_file_path,
            e,
        )

  if round(iso) <= sens_max_meas:
    # Wait until the camera is repositioned for noise model calibration.
    input(
        f'\nPress <ENTER> after covering camera lens {cam.get_camera_name()} '
        'with a frosted glass diffuser, and facing the lens at an evenly'
        ' illuminated surface.\n'
    )
    # Do AE to get a rough idea of where we are.
    iso_ae, exp_ae, _, _, _ = cam.do_3a(
        get_results=True, do_awb=False, do_af=False
    )

    # Underexpose to get more data for low signal levels.
    auto_exposure_ns = iso_ae * exp_ae / bracket_factor
    _check_auto_exposure_targets(
        auto_exposure_ns,
        sens_min,
        sens_max_meas,
        bracket_factor,
        min_exposure_ns,
        max_exposure_ns,
    )

  while round(iso) <= sens_max_meas:
    req = capture_request_utils.manual_capture_request(
        round(iso), min_exposure_ns, f_dist
    )
    cap = cam.do_capture(req, stats_config)
    # Instead of raising an error when the sensitivity readback != requested
    # value, use the readback value for calculations.
    iso_cap = cap['metadata']['android.sensor.sensitivity']

    # Different ISO values may result in captures with the same iso_cap
    # value, so skip this capture if it is redundant.
    if iso_cap == pre_iso_cap:
      logging.info(
          'Skipping current capture because its ISO %d is the same as the'
          ' previous capture.',
          iso_cap,
      )
      iso = get_next_iso(iso, sens_max_meas, iso_multiplier)
      continue
    pre_iso_cap = iso_cap

    logging.info('Request ISO: %d, Capture ISO: %d.', iso, iso_cap)

    for bracket in range(max_bracket):
      # Get the exposure time for this sensitivity and bracket.
      exposure_ns = round(math.pow(2, bracket) * auto_exposure_ns / iso)
      exposure_ms = round(exposure_ns * 1.0e-6, 3)
      logging.info('ISO: %d, exposure time: %.3f ms.', iso_cap, exposure_ms)
      req = capture_request_utils.manual_capture_request(
          iso_cap,
          exposure_ns,
          f_dist,
      )
      req['android.control.zoomRatio'] = zoom_ratio
      cap = cam.do_capture(req, stats_config)

      if is_debug_mode:
        capture_path = os.path.join(
            capture_folder, f'iso{iso_cap}_exposure{exposure_ns}ns.jpg'
        )
        crop_and_save_capture(cap, props, capture_path, num_tiles_crop)

      mean_img, var_img = image_processing_utils.unpack_rawstats_capture(
          cap, num_channels=num_channels
      )
      cfa_order = image_processing_utils.get_canonical_cfa_order(
          props, is_quad_bayer
      )

      means, vars_ = crop_and_reorder_stats_images(
          mean_img,
          var_img,
          num_tiles_crop,
          cfa_order,
      )
      if is_debug_mode:
        logging.info('Raw stats image size: %s', mean_img.shape)
        logging.info('R plane means image size: %s', means[0].shape)
        logging.info(
            'means min: %.3f, median: %.3f, max: %.3f',
            np.min(means), np.median(means), np.max(means),
        )
        logging.info(
            'vars_ min: %.4f, median: %.4f, max: %.4f',
            np.min(vars_), np.median(vars_), np.max(vars_),
        )

      black_levels = image_processing_utils.get_black_levels(
          props,
          cap['metadata'],
          is_quad_bayer,
      )

      means, vars_ = filter_stats(
          means,
          vars_,
          black_levels,
          white_level,
          max_signal_value,
          is_remove_var_outliers,
          outlier_median_abs_deviations,
      )

      iso_to_stats_dict[iso_cap].append((exposure_ms, means, vars_))

    if stats_file_name:
      with open(stats_file_path, 'wb+') as f:
        pickle.dump(iso_to_stats_dict, f)
    iso = get_next_iso(iso, sens_max_meas, iso_multiplier)

  return iso_to_stats_dict
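

# Note on the bracketing in capture_stats_images: for bracket b in
# [0, max_bracket), the exposure time is 2**b * auto_exposure_ns / iso, and
# auto_exposure_ns was already divided by bracket_factor (2**max_bracket), so
# even the longest bracket keeps the iso * exposure product at roughly half of
# the AE-estimated target.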


def measure_linear_noise_models(
    iso_to_stats_dict: Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]],
    color_planes: List[str],
):
  """Measures linear noise models.

  This function measures linear noise models from means and variances for each
  color plane and ISO setting.

  Args:
    iso_to_stats_dict: A dictionary mapping ISO settings to a list of stats
      data.
    color_planes: A list of color planes.

  Returns:
    A tuple containing:
      measured_models: A list of linear models, one for each color plane.
      samples: A list of samples, one for each color plane. Each sample is a
        tuple of (iso, mean, var).
  """
  num_planes = len(color_planes)
  # Model parameters for each color plane.
  measured_models = [[] for _ in range(num_planes)]
  # Samples (ISO, mean and var) of each color channel.
  samples = [[] for _ in range(num_planes)]

  for iso in sorted(iso_to_stats_dict.keys()):
    logging.info('Calculating measured models for ISO %d.', iso)
    stats_per_plane = [[] for _ in range(num_planes)]
    for _, means, vars_ in iso_to_stats_dict[iso]:
      for pidx in range(num_planes):
        means_p = means[pidx]
        vars_p = vars_[pidx]
        if means_p.size > 0 and vars_p.size > 0:
          stats_per_plane[pidx].extend(list(zip(means_p, vars_p)))

    for pidx, mean_var_pairs in enumerate(stats_per_plane):
      if not mean_var_pairs:
        raise ValueError(
            f'For ISO {iso}, samples are empty in color plane'
            f' {color_planes[pidx]}.'
        )
      slope, intercept, rvalue, _, _ = scipy.stats.linregress(mean_var_pairs)

      measured_models[pidx].append((iso, slope, intercept))
      logging.info(
          (
              'Measured model for ISO %d and color plane %s: '
              'y = %e * x + %e (R=%.6f).'
          ),
          iso, color_planes[pidx], slope, intercept, rvalue,
      )

      # Add the samples for this sensitivity to the global samples list.
      samples[pidx].extend([(iso, mean, var) for (mean, var) in mean_var_pairs])

  return measured_models, samples
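

# Each entry of measured_models above is a per-ISO linear fit of variance
# against mean (var ~= slope * mean + intercept). compute_noise_model below
# refits the raw (iso, mean, var) samples globally so that the slope and
# intercept become functions of the analog gain.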


def compute_noise_model(
    samples: List[List[Tuple[float, np.ndarray, np.ndarray]]],
    sens_max_analog: int,
    offset_a: np.ndarray,
    offset_b: np.ndarray,
    is_two_stage_model: bool = False,
) -> np.ndarray:
  """Computes noise model parameters from samples.

  The noise model is defined by the following equation:
    f(x) = scale * x + offset

  where we have:
    scale = scale_a * analog_gain * digital_gain + scale_b,
    offset = (offset_a * analog_gain^2 + offset_b) * digital_gain^2.
  scale is the multiplicative factor and offset is the offset term.

  Assume digital_gain is 1.0 and scale_a, scale_b, offset_a, offset_b are
  sa, sb, oa, ob respectively, so we have the noise model function:
    f(x) = (sa * analog_gain + sb) * x + (oa * analog_gain^2 + ob).

  The noise model is fit to the measured data with scipy.optimize.curve_fit,
  which uses an iterative Levenberg-Marquardt algorithm to find the model
  parameters that minimize the mean squared error.

  Args:
    samples: A list of samples, each of which is a list of tuples of `(gains,
      means, vars_)`.
    sens_max_analog: The maximum analog gain.
    offset_a: The gradient coefficients from the read noise calibration.
    offset_b: The intercept coefficients from the read noise calibration.
    is_two_stage_model: A boolean flag indicating if the noise model is
      calibrated in the two-stage mode.

  Returns:
    A numpy array containing noise model parameters (scale_a, scale_b,
    offset_a, offset_b) of each channel.
  """
  noise_model = []
  for pidx, samples_p in enumerate(samples):
    gains, means, vars_ = zip(*samples_p)
    gains = np.asarray(gains).flatten()
    means = np.asarray(means).flatten()
    vars_ = np.asarray(vars_).flatten()

    compute_digital_gains(gains, sens_max_analog)

    # Use a global linear optimization to fit the noise model.
    # Noise model function:
    #   f(x) = scale * x + offset
    # where:
    #   scale = scale_a * analog_gain * digital_gain + scale_b,
    #   offset = (offset_a * analog_gain^2 + offset_b) * digital_gain^2.
    # Function f will be used to train the scale and offset coefficients
    # scale_a, scale_b, offset_a, offset_b.
    if is_two_stage_model:
      # For the two-stage model, we want to use the line fit coefficients
      # found from capturing read noise data (offset_a and offset_b) to
      # train the scale coefficients.
      oa, ob = offset_a[pidx], offset_b[pidx]

      # Cannot pass oa and ob as parameters of f since we only want
      # curve_fit to return 2 parameters.
      def f(x, sa, sb):
        scale = sa * x[0] + sb
        # pylint: disable=cell-var-from-loop
        offset = oa * x[0] ** 2 + ob
        return (scale * x[1] + offset) / x[0]

    else:
      def f(x, sa, sb, oa, ob):
        scale = sa * x[0] + sb
        offset = oa * x[0] ** 2 + ob
        return (scale * x[1] + offset) / x[0]

    # Divide the whole system by gains.
    coeffs, _ = scipy.optimize.curve_fit(f, (gains, means), vars_ / gains)

    # If using the two-stage model, two of the coefficients calculated above
    # are constant, so we need to append them to the coeffs ndarray.
    if is_two_stage_model:
      coeffs = np.append(coeffs, offset_a[pidx])
      coeffs = np.append(coeffs, offset_b[pidx])

    # coeffs[0:4] = (scale_a, scale_b, offset_a, offset_b).
    noise_model.append(coeffs[0:4])

  noise_model = np.asarray(noise_model)
  check_noise_model_shape(noise_model)
  return noise_model
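

# Sketch of how a fitted noise model row could be evaluated (illustrative
# only; analog_gain and mean_signal below are placeholders, not names used in
# this module):
#
#   scale_a, scale_b, offset_a, offset_b = noise_model[pidx]
#   scale = scale_a * analog_gain + scale_b        # digital gain assumed 1.0
#   offset = offset_a * analog_gain ** 2 + offset_b
#   predicted_var = scale * mean_signal + offset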


def create_stats_figure(
    iso: int,
    color_channel_names: List[str],
):
  """Creates a figure with subplots showing the mean and variance samples.

  Args:
    iso: The ISO setting for the images.
    color_channel_names: A list of strings containing the names of the color
      channels.

  Returns:
    A tuple of the figure and a list of the subplots.
  """
  if len(color_channel_names) not in noise_model_constants.VALID_NUM_CHANNELS:
    raise AssertionError(
        'The number of channels should be in'
        f' {noise_model_constants.VALID_NUM_CHANNELS}, but found'
        f' {len(color_channel_names)}.'
    )

  is_quad_bayer = (
      len(color_channel_names) == noise_model_constants.NUM_QUAD_BAYER_CHANNELS
  )
  if is_quad_bayer:
    # Adds a plot of the mean and variance samples for each color plane.
    fig, axes = plt.subplots(4, 4, figsize=(22, 22))
    fig.gca()
    fig.suptitle('ISO %d' % iso, x=0.52, y=0.99)

    cax = fig.add_axes([0.65, 0.995, 0.33, 0.003])
    cax.set_title('log(exposure_ms):', x=-0.13, y=-2.0)
    fig.colorbar(
        noise_model_constants.COLOR_BAR, cax=cax, orientation='horizontal'
    )

    # Add a big axis, hide frame.
    fig.add_subplot(111, frameon=False)

    # Add a common x-axis and y-axis.
    plt.tick_params(
        labelcolor='none',
        which='both',
        top=False,
        bottom=False,
        left=False,
        right=False,
    )
    plt.xlabel('Mean signal level', ha='center')
    plt.ylabel('Variance', va='center', rotation='vertical')

    subplots = []
    for pidx in range(noise_model_constants.NUM_QUAD_BAYER_CHANNELS):
      subplot = axes[pidx // 4, pidx % 4]
      subplot.set_title(color_channel_names[pidx])
      # Set 'y' axis to scientific notation for all numbers by setting
      # scilimits to (0, 0).
      subplot.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
      subplots.append(subplot)

  else:
    # Adds a plot of the mean and variance samples for each color plane.
    fig, [[plt_r, plt_gr], [plt_gb, plt_b]] = plt.subplots(
        2, 2, figsize=(11, 11)
    )
    fig.gca()
    # Add a color bar to show exposure times.
    cax = fig.add_axes([0.73, 0.99, 0.25, 0.01])
    cax.set_title('log(exposure_ms):', x=-0.3, y=-1.0)
    fig.colorbar(
        noise_model_constants.COLOR_BAR, cax=cax, orientation='horizontal'
    )

    subplots = [plt_r, plt_gr, plt_gb, plt_b]
    fig.suptitle('ISO %d' % iso, x=0.54, y=0.99)
    for pidx, subplot in enumerate(subplots):
      subplot.set_title(color_channel_names[pidx])
      subplot.set_xlabel('Mean signal level')
      subplot.set_ylabel('Variance')
      subplot.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

  with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    pylab.tight_layout()

  return fig, subplots
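

# Typical calibration flow using these utilities (a sketch; the exact
# orchestration lives in the calling test, not in this module):
#   1. iso_to_stats_dict = capture_stats_images(...)
#   2. measured_models, samples = measure_linear_noise_models(
#          iso_to_stats_dict, color_planes)
#   3. noise_model = compute_noise_model(samples, sens_max_analog, offset_a,
#          offset_b, is_two_stage_model)
#   4. validate_noise_model(noise_model, color_planes, sens_min)
# create_stats_figure can be used alongside step 2 to plot per-ISO samples.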