1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Resampling dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.data.experimental.ops import interleave_ops
23from tensorflow.python.data.experimental.ops import scan_ops
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import logging_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import random_ops
33from tensorflow.python.util.tf_export import tf_export
34
35
36@tf_export("data.experimental.rejection_resample")
37def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
38  """A transformation that resamples a dataset to achieve a target distribution.
39
40  **NOTE** Resampling is performed via rejection sampling; some fraction
41  of the input values will be dropped.
42
43  Args:
44    class_func: A function mapping an element of the input dataset to a scalar
45      `tf.int32` tensor. Values should be in `[0, num_classes)`.
46    target_dist: A floating point type tensor, shaped `[num_classes]`.
47    initial_dist: (Optional.)  A floating point type tensor, shaped
48      `[num_classes]`.  If not provided, the true class distribution is
49      estimated live in a streaming fashion.
50    seed: (Optional.) Python integer seed for the resampler.
51
52  Returns:
53    A `Dataset` transformation function, which can be passed to
54    `tf.data.Dataset.apply`.
55  """
56  def _apply_fn(dataset):
57    """Function from `Dataset` to `Dataset` that applies the transformation."""
58    target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
59
60    # Get initial distribution.
61    if initial_dist is not None:
62      initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
63      acceptance_dist, prob_of_original = (
64          _calculate_acceptance_probs_with_mixing(initial_dist_t,
65                                                  target_dist_t))
66      initial_dist_ds = dataset_ops.Dataset.from_tensors(
67          initial_dist_t).repeat()
68      acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
69          acceptance_dist).repeat()
70      prob_of_original_ds = dataset_ops.Dataset.from_tensors(
71          prob_of_original).repeat()
72    else:
73      initial_dist_ds = _estimate_initial_dist_ds(target_dist_t,
74                                                  dataset.map(class_func))
75      acceptance_and_original_prob_ds = initial_dist_ds.map(
76          lambda initial: _calculate_acceptance_probs_with_mixing(  # pylint: disable=g-long-lambda
77              initial, target_dist_t))
78      acceptance_dist_ds = acceptance_and_original_prob_ds.map(
79          lambda accept_prob, _: accept_prob)
80      prob_of_original_ds = acceptance_and_original_prob_ds.map(
81          lambda _, prob_original: prob_original)
82    filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
83                             class_func, seed)
84    # Prefetch filtered dataset for speed.
85    filtered_ds = filtered_ds.prefetch(3)
86
87    prob_original_static = _get_prob_original_static(
88        initial_dist_t, target_dist_t) if initial_dist is not None else None
89
90    def add_class_value(*x):
91      if len(x) == 1:
92        return class_func(*x), x[0]
93      else:
94        return class_func(*x), x
95
96    if prob_original_static == 1:
97      return dataset.map(add_class_value)
98    elif prob_original_static == 0:
99      return filtered_ds
100    else:
101      return interleave_ops.sample_from_datasets(
102          [dataset.map(add_class_value), filtered_ds],
103          weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
104          seed=seed)
105
106  return _apply_fn
107
108
109def _get_prob_original_static(initial_dist_t, target_dist_t):
110  """Returns the static probability of sampling from the original.
111
112  `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
113  an Op that it isn't defined for. We have some custom logic to avoid this.
114
115  Args:
116    initial_dist_t: A tensor of the initial distribution.
117    target_dist_t: A tensor of the target distribution.
118
119  Returns:
120    The probability of sampling from the original distribution as a constant,
121    if it is a constant, or `None`.
122  """
123  init_static = tensor_util.constant_value(initial_dist_t)
124  target_static = tensor_util.constant_value(target_dist_t)
125
126  if init_static is None or target_static is None:
127    return None
128  else:
129    return np.min(target_static / init_static)
130
131
132def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed):
133  """Filters a dataset based on per-class acceptance probabilities.
134
135  Args:
136    dataset: The dataset to be filtered.
137    acceptance_dist_ds: A dataset of acceptance probabilities.
138    initial_dist_ds: A dataset of the initial probability distribution, given or
139        estimated.
140    class_func: A function mapping an element of the input dataset to a scalar
141      `tf.int32` tensor. Values should be in `[0, num_classes)`.
142    seed: (Optional.) Python integer seed for the resampler.
143
144  Returns:
145    A dataset of (class value, data) after filtering.
146  """
147  def maybe_warn_on_large_rejection(accept_dist, initial_dist):
148    proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
149    return control_flow_ops.cond(
150        math_ops.less(proportion_rejected, .5),
151        lambda: accept_dist,
152        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
153            accept_dist, [proportion_rejected, initial_dist, accept_dist],
154            message="Proportion of examples rejected by sampler is high: ",
155            summarize=100,
156            first_n=10))
157
158  acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
159                                                 initial_dist_ds))
160                        .map(maybe_warn_on_large_rejection))
161
162  def _gather_and_copy(acceptance_prob, data):
163    if isinstance(data, tuple):
164      class_val = class_func(*data)
165    else:
166      class_val = class_func(data)
167    return class_val, array_ops.gather(acceptance_prob, class_val), data
168
169  current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
170      (acceptance_dist_ds, dataset)).map(_gather_and_copy)
171  filtered_ds = (
172      current_probabilities_and_class_and_data_ds.filter(
173          lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
174  return filtered_ds.map(lambda class_value, _, data: (class_value, data))
175
176
177def _estimate_initial_dist_ds(
178    target_dist_t, class_values_ds, dist_estimation_batch_size=32,
179    smoothing_constant=10):
180  num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0])
181  initial_examples_per_class_seen = array_ops.fill(
182      [num_classes], np.int64(smoothing_constant))
183
184  def update_estimate_and_tile(num_examples_per_class_seen, c):
185    updated_examples_per_class_seen, dist = _estimate_data_distribution(
186        c, num_examples_per_class_seen)
187    tiled_dist = array_ops.tile(
188        array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
189    return updated_examples_per_class_seen, tiled_dist
190
191  initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
192                     .apply(scan_ops.scan(initial_examples_per_class_seen,
193                                          update_estimate_and_tile))
194                     .unbatch())
195
196  return initial_dist_ds
197
198
199def _get_target_to_initial_ratio(initial_probs, target_probs):
200  # Add tiny to initial_probs to avoid divide by zero.
201  denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
202  return target_probs / denom
203
204
205def _estimate_data_distribution(c, num_examples_per_class_seen):
206  """Estimate data distribution as labels are seen.
207
208  Args:
209    c: The class labels.  Type `int32`, shape `[batch_size]`.
210    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
211      containing counts.
212
213  Returns:
214    num_examples_per_lass_seen: Updated counts.  Type `int64`, shape
215      `[num_classes]`.
216    dist: The updated distribution.  Type `float32`, shape `[num_classes]`.
217  """
218  num_classes = num_examples_per_class_seen.get_shape()[0]
219  # Update the class-count based on what labels are seen in batch.
220  num_examples_per_class_seen = math_ops.add(
221      num_examples_per_class_seen, math_ops.reduce_sum(
222          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
223  init_prob_estimate = math_ops.truediv(
224      num_examples_per_class_seen,
225      math_ops.reduce_sum(num_examples_per_class_seen))
226  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
227  return num_examples_per_class_seen, dist
228
229
230def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
231  """Calculates the acceptance probabilities and mixing ratio.
232
233  In this case, we assume that we can *either* sample from the original data
234  distribution with probability `m`, or sample from a reshaped distribution
235  that comes from rejection sampling on the original distribution. This
236  rejection sampling is done on a per-class basis, with `a_i` representing the
237  probability of accepting data from class `i`.
238
239  This method is based on solving the following analysis for the reshaped
240  distribution:
241
242  Let F be the probability of a rejection (on any example).
243  Let p_i be the proportion of examples in the data in class i (init_probs)
244  Let a_i is the rate the rejection sampler should *accept* class i
245  Let t_i is the target proportion in the minibatches for class i (target_probs)
246
247  ```
248  F = sum_i(p_i * (1-a_i))
249    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
250  ```
251
252  An example with class `i` will be accepted if `k` rejections occur, then an
253  example with class `i` is seen by the rejector, and it is accepted. This can
254  be written as follows:
255
256  ```
257  t_i = sum_k=0^inf(F^k * p_i * a_i)
258      = p_i * a_j / (1 - F)    using geometric series identity, since 0 <= F < 1
259      = p_i * a_i / sum_j(p_j * a_j)        using F from above
260  ```
261
262  Note that the following constraints hold:
263  ```
264  0 <= p_i <= 1, sum_i(p_i) = 1
265  0 <= a_i <= 1
266  0 <= t_i <= 1, sum_i(t_i) = 1
267  ```
268
269  A solution for a_i in terms of the other variables is the following:
270    ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
271
272  If we try to minimize the amount of data rejected, we get the following:
273
274  M_max = max_i [ t_i / p_i ]
275  M_min = min_i [ t_i / p_i ]
276
277  The desired probability of accepting data if it comes from class `i`:
278
279  a_i = (t_i/p_i - m) / (M_max - m)
280
281  The desired probability of pulling a data element from the original dataset,
282  rather than the filtered one:
283
284  m = M_min
285
286  Args:
287    initial_probs: A Tensor of the initial probability distribution, given or
288      estimated.
289    target_probs: A Tensor of the corresponding classes.
290
291  Returns:
292    (A 1D Tensor with the per-class acceptance probabilities, the desired
293    probability of pull from the original distribution.)
294  """
295  ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
296  max_ratio = math_ops.reduce_max(ratio_l)
297  min_ratio = math_ops.reduce_min(ratio_l)
298
299  # Target prob to sample from original distribution.
300  m = min_ratio
301
302  # TODO(joelshor): Simplify fraction, if possible.
303  a_i = (ratio_l - m) / (max_ratio - m)
304  return a_i, m
305