1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Resampling dataset transformations.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python.data.experimental.ops import interleave_ops 23from tensorflow.python.data.experimental.ops import scan_ops 24from tensorflow.python.data.ops import dataset_ops 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import control_flow_ops 30from tensorflow.python.ops import logging_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops import random_ops 33from tensorflow.python.util.tf_export import tf_export 34 35 36@tf_export("data.experimental.rejection_resample") 37def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): 38 """A transformation that resamples a dataset to achieve a target distribution. 39 40 **NOTE** Resampling is performed via rejection sampling; some fraction 41 of the input values will be dropped. 42 43 Args: 44 class_func: A function mapping an element of the input dataset to a scalar 45 `tf.int32` tensor. Values should be in `[0, num_classes)`. 46 target_dist: A floating point type tensor, shaped `[num_classes]`. 47 initial_dist: (Optional.) A floating point type tensor, shaped 48 `[num_classes]`. If not provided, the true class distribution is 49 estimated live in a streaming fashion. 50 seed: (Optional.) Python integer seed for the resampler. 51 52 Returns: 53 A `Dataset` transformation function, which can be passed to 54 `tf.data.Dataset.apply`. 55 """ 56 def _apply_fn(dataset): 57 """Function from `Dataset` to `Dataset` that applies the transformation.""" 58 target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") 59 60 # Get initial distribution. 61 if initial_dist is not None: 62 initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") 63 acceptance_dist, prob_of_original = ( 64 _calculate_acceptance_probs_with_mixing(initial_dist_t, 65 target_dist_t)) 66 initial_dist_ds = dataset_ops.Dataset.from_tensors( 67 initial_dist_t).repeat() 68 acceptance_dist_ds = dataset_ops.Dataset.from_tensors( 69 acceptance_dist).repeat() 70 prob_of_original_ds = dataset_ops.Dataset.from_tensors( 71 prob_of_original).repeat() 72 else: 73 initial_dist_ds = _estimate_initial_dist_ds(target_dist_t, 74 dataset.map(class_func)) 75 acceptance_and_original_prob_ds = initial_dist_ds.map( 76 lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda 77 initial, target_dist_t)) 78 acceptance_dist_ds = acceptance_and_original_prob_ds.map( 79 lambda accept_prob, _: accept_prob) 80 prob_of_original_ds = acceptance_and_original_prob_ds.map( 81 lambda _, prob_original: prob_original) 82 filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, 83 class_func, seed) 84 # Prefetch filtered dataset for speed. 85 filtered_ds = filtered_ds.prefetch(3) 86 87 prob_original_static = _get_prob_original_static( 88 initial_dist_t, target_dist_t) if initial_dist is not None else None 89 90 def add_class_value(*x): 91 if len(x) == 1: 92 return class_func(*x), x[0] 93 else: 94 return class_func(*x), x 95 96 if prob_original_static == 1: 97 return dataset.map(add_class_value) 98 elif prob_original_static == 0: 99 return filtered_ds 100 else: 101 return interleave_ops.sample_from_datasets( 102 [dataset.map(add_class_value), filtered_ds], 103 weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), 104 seed=seed) 105 106 return _apply_fn 107 108 109def _get_prob_original_static(initial_dist_t, target_dist_t): 110 """Returns the static probability of sampling from the original. 111 112 `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters 113 an Op that it isn't defined for. We have some custom logic to avoid this. 114 115 Args: 116 initial_dist_t: A tensor of the initial distribution. 117 target_dist_t: A tensor of the target distribution. 118 119 Returns: 120 The probability of sampling from the original distribution as a constant, 121 if it is a constant, or `None`. 122 """ 123 init_static = tensor_util.constant_value(initial_dist_t) 124 target_static = tensor_util.constant_value(target_dist_t) 125 126 if init_static is None or target_static is None: 127 return None 128 else: 129 return np.min(target_static / init_static) 130 131 132def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed): 133 """Filters a dataset based on per-class acceptance probabilities. 134 135 Args: 136 dataset: The dataset to be filtered. 137 acceptance_dist_ds: A dataset of acceptance probabilities. 138 initial_dist_ds: A dataset of the initial probability distribution, given or 139 estimated. 140 class_func: A function mapping an element of the input dataset to a scalar 141 `tf.int32` tensor. Values should be in `[0, num_classes)`. 142 seed: (Optional.) Python integer seed for the resampler. 143 144 Returns: 145 A dataset of (class value, data) after filtering. 146 """ 147 def maybe_warn_on_large_rejection(accept_dist, initial_dist): 148 proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) 149 return control_flow_ops.cond( 150 math_ops.less(proportion_rejected, .5), 151 lambda: accept_dist, 152 lambda: logging_ops.Print( # pylint: disable=g-long-lambda 153 accept_dist, [proportion_rejected, initial_dist, accept_dist], 154 message="Proportion of examples rejected by sampler is high: ", 155 summarize=100, 156 first_n=10)) 157 158 acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, 159 initial_dist_ds)) 160 .map(maybe_warn_on_large_rejection)) 161 162 def _gather_and_copy(acceptance_prob, data): 163 if isinstance(data, tuple): 164 class_val = class_func(*data) 165 else: 166 class_val = class_func(data) 167 return class_val, array_ops.gather(acceptance_prob, class_val), data 168 169 current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( 170 (acceptance_dist_ds, dataset)).map(_gather_and_copy) 171 filtered_ds = ( 172 current_probabilities_and_class_and_data_ds.filter( 173 lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) 174 return filtered_ds.map(lambda class_value, _, data: (class_value, data)) 175 176 177def _estimate_initial_dist_ds( 178 target_dist_t, class_values_ds, dist_estimation_batch_size=32, 179 smoothing_constant=10): 180 num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0]) 181 initial_examples_per_class_seen = array_ops.fill( 182 [num_classes], np.int64(smoothing_constant)) 183 184 def update_estimate_and_tile(num_examples_per_class_seen, c): 185 updated_examples_per_class_seen, dist = _estimate_data_distribution( 186 c, num_examples_per_class_seen) 187 tiled_dist = array_ops.tile( 188 array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) 189 return updated_examples_per_class_seen, tiled_dist 190 191 initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) 192 .apply(scan_ops.scan(initial_examples_per_class_seen, 193 update_estimate_and_tile)) 194 .unbatch()) 195 196 return initial_dist_ds 197 198 199def _get_target_to_initial_ratio(initial_probs, target_probs): 200 # Add tiny to initial_probs to avoid divide by zero. 201 denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) 202 return target_probs / denom 203 204 205def _estimate_data_distribution(c, num_examples_per_class_seen): 206 """Estimate data distribution as labels are seen. 207 208 Args: 209 c: The class labels. Type `int32`, shape `[batch_size]`. 210 num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, 211 containing counts. 212 213 Returns: 214 num_examples_per_lass_seen: Updated counts. Type `int64`, shape 215 `[num_classes]`. 216 dist: The updated distribution. Type `float32`, shape `[num_classes]`. 217 """ 218 num_classes = num_examples_per_class_seen.get_shape()[0] 219 # Update the class-count based on what labels are seen in batch. 220 num_examples_per_class_seen = math_ops.add( 221 num_examples_per_class_seen, math_ops.reduce_sum( 222 array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) 223 init_prob_estimate = math_ops.truediv( 224 num_examples_per_class_seen, 225 math_ops.reduce_sum(num_examples_per_class_seen)) 226 dist = math_ops.cast(init_prob_estimate, dtypes.float32) 227 return num_examples_per_class_seen, dist 228 229 230def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): 231 """Calculates the acceptance probabilities and mixing ratio. 232 233 In this case, we assume that we can *either* sample from the original data 234 distribution with probability `m`, or sample from a reshaped distribution 235 that comes from rejection sampling on the original distribution. This 236 rejection sampling is done on a per-class basis, with `a_i` representing the 237 probability of accepting data from class `i`. 238 239 This method is based on solving the following analysis for the reshaped 240 distribution: 241 242 Let F be the probability of a rejection (on any example). 243 Let p_i be the proportion of examples in the data in class i (init_probs) 244 Let a_i is the rate the rejection sampler should *accept* class i 245 Let t_i is the target proportion in the minibatches for class i (target_probs) 246 247 ``` 248 F = sum_i(p_i * (1-a_i)) 249 = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 250 ``` 251 252 An example with class `i` will be accepted if `k` rejections occur, then an 253 example with class `i` is seen by the rejector, and it is accepted. This can 254 be written as follows: 255 256 ``` 257 t_i = sum_k=0^inf(F^k * p_i * a_i) 258 = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 259 = p_i * a_i / sum_j(p_j * a_j) using F from above 260 ``` 261 262 Note that the following constraints hold: 263 ``` 264 0 <= p_i <= 1, sum_i(p_i) = 1 265 0 <= a_i <= 1 266 0 <= t_i <= 1, sum_i(t_i) = 1 267 ``` 268 269 A solution for a_i in terms of the other variables is the following: 270 ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` 271 272 If we try to minimize the amount of data rejected, we get the following: 273 274 M_max = max_i [ t_i / p_i ] 275 M_min = min_i [ t_i / p_i ] 276 277 The desired probability of accepting data if it comes from class `i`: 278 279 a_i = (t_i/p_i - m) / (M_max - m) 280 281 The desired probability of pulling a data element from the original dataset, 282 rather than the filtered one: 283 284 m = M_min 285 286 Args: 287 initial_probs: A Tensor of the initial probability distribution, given or 288 estimated. 289 target_probs: A Tensor of the corresponding classes. 290 291 Returns: 292 (A 1D Tensor with the per-class acceptance probabilities, the desired 293 probability of pull from the original distribution.) 294 """ 295 ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) 296 max_ratio = math_ops.reduce_max(ratio_l) 297 min_ratio = math_ops.reduce_min(ratio_l) 298 299 # Target prob to sample from original distribution. 300 m = min_ratio 301 302 # TODO(joelshor): Simplify fraction, if possible. 303 a_i = (ratio_l - m) / (max_ratio - m) 304 return a_i, m 305