1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 //! Weighted index sampling
10 
11 use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
12 use crate::distributions::Distribution;
13 use crate::Rng;
14 use core::cmp::PartialOrd;
15 use core::fmt;
16 
17 // Note that this whole module is only imported if feature="alloc" is enabled.
18 use alloc::vec::Vec;
19 
20 #[cfg(feature = "serde1")]
21 use serde::{Serialize, Deserialize};
22 
/// A distribution using weighted sampling of discrete items
///
/// Sampling a `WeightedIndex` distribution returns the index of a randomly
/// selected element from the iterator used when the `WeightedIndex` was
/// created. The chance of a given element being picked is proportional to the
/// value of the element. The weights can use any type `X` for which an
/// implementation of [`Uniform<X>`] exists.
///
/// # Performance
///
/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
/// `N` is the number of weights. As an alternative,
/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html)
/// supports `O(1)` sampling, but with much higher initialisation cost.
///
/// A `WeightedIndex<X>` contains a `Vec<X>`, an `X` (the total weight) and an
/// `X::Sampler`, and so its size is the sum of the size of those objects,
/// possibly plus some alignment.
///
/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
/// weights of type `X`, where `N` is the number of weights. However, since
/// `Vec` doesn't guarantee a particular growth strategy, additional memory
/// might be allocated but not used. Since the `WeightedIndex` object also
/// contains an `X::Sampler`, this might cause additional allocations, though
/// for primitive types, [`Uniform<X>`] doesn't allocate any memory.
///
/// Sampling from `WeightedIndex` will result in a single call to
/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
/// will request a single value from the underlying [`RngCore`], though the
/// exact number depends on the implementation of `Uniform<X>::sample`.
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand::distributions::WeightedIndex;
///
/// let choices = ['a', 'b', 'c'];
/// let weights = [2,   1,   1];
/// let dist = WeightedIndex::new(&weights).unwrap();
/// let mut rng = thread_rng();
/// for _ in 0..100 {
///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
///     println!("{}", choices[dist.sample(&mut rng)]);
/// }
///
/// let items = [('a', 0), ('b', 3), ('c', 7)];
/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
/// for _ in 0..100 {
///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
///     println!("{}", items[dist2.sample(&mut rng)].0);
/// }
/// ```
///
/// [`Uniform<X>`]: crate::distributions::Uniform
/// [`RngCore`]: crate::RngCore
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
    // Running sums of the weights: entry `i` is the sum of weights `0..=i`.
    // The final running sum is not stored here (it lives in `total_weight`),
    // so this holds one entry fewer than there are weights.
    cumulative_weights: Vec<X>,
    // Sum of all weights.
    total_weight: X,
    // Sampler built as `X::Sampler::new(X::default(), total_weight)`; it is
    // rebuilt whenever `total_weight` changes.
    weight_distribution: X::Sampler,
}
86 
87 impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
88     /// Creates a new a `WeightedIndex` [`Distribution`] using the values
89     /// in `weights`. The weights can use any type `X` for which an
90     /// implementation of [`Uniform<X>`] exists.
91     ///
92     /// Returns an error if the iterator is empty, if any weight is `< 0`, or
93     /// if its total value is 0.
94     ///
95     /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> where I: IntoIterator, I::Item: SampleBorrow<X>, X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,96     pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
97     where
98         I: IntoIterator,
99         I::Item: SampleBorrow<X>,
100         X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
101     {
102         let mut iter = weights.into_iter();
103         let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
104 
105         let zero = <X as Default>::default();
106         if !(total_weight >= zero) {
107             return Err(WeightedError::InvalidWeight);
108         }
109 
110         let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
111         for w in iter {
112             // Note that `!(w >= x)` is not equivalent to `w < x` for partially
113             // ordered types due to NaNs which are equal to nothing.
114             if !(w.borrow() >= &zero) {
115                 return Err(WeightedError::InvalidWeight);
116             }
117             weights.push(total_weight.clone());
118             total_weight += w.borrow();
119         }
120 
121         if total_weight == zero {
122             return Err(WeightedError::AllWeightsZero);
123         }
124         let distr = X::Sampler::new(zero, total_weight.clone());
125 
126         Ok(WeightedIndex {
127             cumulative_weights: weights,
128             total_weight,
129             weight_distribution: distr,
130         })
131     }
132 
133     /// Update a subset of weights, without changing the number of weights.
134     ///
135     /// `new_weights` must be sorted by the index.
136     ///
137     /// Using this method instead of `new` might be more efficient if only a small number of
138     /// weights is modified. No allocations are performed, unless the weight type `X` uses
139     /// allocation internally.
140     ///
141     /// In case of error, `self` is not modified.
update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> where X: for<'a> ::core::ops::AddAssign<&'a X> + for<'a> ::core::ops::SubAssign<&'a X> + Clone + Default142     pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
143     where X: for<'a> ::core::ops::AddAssign<&'a X>
144             + for<'a> ::core::ops::SubAssign<&'a X>
145             + Clone
146             + Default {
147         if new_weights.is_empty() {
148             return Ok(());
149         }
150 
151         let zero = <X as Default>::default();
152 
153         let mut total_weight = self.total_weight.clone();
154 
155         // Check for errors first, so we don't modify `self` in case something
156         // goes wrong.
157         let mut prev_i = None;
158         for &(i, w) in new_weights {
159             if let Some(old_i) = prev_i {
160                 if old_i >= i {
161                     return Err(WeightedError::InvalidWeight);
162                 }
163             }
164             if !(*w >= zero) {
165                 return Err(WeightedError::InvalidWeight);
166             }
167             if i > self.cumulative_weights.len() {
168                 return Err(WeightedError::TooMany);
169             }
170 
171             let mut old_w = if i < self.cumulative_weights.len() {
172                 self.cumulative_weights[i].clone()
173             } else {
174                 self.total_weight.clone()
175             };
176             if i > 0 {
177                 old_w -= &self.cumulative_weights[i - 1];
178             }
179 
180             total_weight -= &old_w;
181             total_weight += w;
182             prev_i = Some(i);
183         }
184         if total_weight <= zero {
185             return Err(WeightedError::AllWeightsZero);
186         }
187 
188         // Update the weights. Because we checked all the preconditions in the
189         // previous loop, this should never panic.
190         let mut iter = new_weights.iter();
191 
192         let mut prev_weight = zero.clone();
193         let mut next_new_weight = iter.next();
194         let &(first_new_index, _) = next_new_weight.unwrap();
195         let mut cumulative_weight = if first_new_index > 0 {
196             self.cumulative_weights[first_new_index - 1].clone()
197         } else {
198             zero.clone()
199         };
200         for i in first_new_index..self.cumulative_weights.len() {
201             match next_new_weight {
202                 Some(&(j, w)) if i == j => {
203                     cumulative_weight += w;
204                     next_new_weight = iter.next();
205                 }
206                 _ => {
207                     let mut tmp = self.cumulative_weights[i].clone();
208                     tmp -= &prev_weight; // We know this is positive.
209                     cumulative_weight += &tmp;
210                 }
211             }
212             prev_weight = cumulative_weight.clone();
213             core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
214         }
215 
216         self.total_weight = total_weight;
217         self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());
218 
219         Ok(())
220     }
221 }
222 
223 impl<X> Distribution<usize> for WeightedIndex<X>
224 where X: SampleUniform + PartialOrd
225 {
sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize226     fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
227         use ::core::cmp::Ordering;
228         let chosen_weight = self.weight_distribution.sample(rng);
229         // Find the first item which has a weight *higher* than the chosen weight.
230         self.cumulative_weights
231             .binary_search_by(|w| {
232                 if *w <= chosen_weight {
233                     Ordering::Less
234                 } else {
235                     Ordering::Greater
236                 }
237             })
238             .unwrap_err()
239     }
240 }
241 
#[cfg(test)]
mod test {
    use super::*;

    /// A serialize/deserialize round trip keeps the cumulative weights and
    /// the total weight.
    #[cfg(feature = "serde1")]
    #[test]
    fn test_weightedindex_serde1() {
        let original = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let bytes = bincode::serialize(&original).unwrap();
        let restored: WeightedIndex<i32> = bincode::deserialize(&bytes).unwrap();

        assert_eq!(restored.cumulative_weights, original.cumulative_weights);
        assert_eq!(restored.total_weight, original.total_weight);
    }

    /// NaN weights are rejected by both `new` and `update_weights`.
    #[test]
    fn test_accepting_nan() {
        fn assert_rejected(weights: &[f32]) {
            assert_eq!(
                WeightedIndex::new(weights).unwrap_err(),
                WeightedError::InvalidWeight,
            );
        }

        let nan = core::f32::NAN;
        assert_rejected(&[nan, 0.5]);
        assert_rejected(&[nan]);
        assert_rejected(&[0.5, nan]);

        let mut distr = WeightedIndex::new(&[0.5f32, 7.0]).unwrap();
        assert_eq!(
            distr.update_weights(&[(0, &nan)]).unwrap_err(),
            WeightedError::InvalidWeight,
        );
    }

    /// Sampling frequencies follow the weights, and invalid inputs are
    /// reported with the matching error.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weightedindex() {
        let mut r = crate::test::rng(700);
        const N_REPS: u32 = 5000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each observed count must be within 25% (relative) of its expectation.
        let verify = |counts: [i32; 14]| {
            for (&weight, &count) in weights.iter().zip(counts.iter()) {
                let expected = (weight * N_REPS) as f32 / total_weight;
                let abs_err = (count as f32 - expected).abs();
                let rel_err = if abs_err != 0.0 {
                    abs_err / expected
                } else {
                    abs_err
                };
                assert!(rel_err <= 0.25);
            }
        };

        let distrs = [
            // WeightedIndex from vec
            WeightedIndex::new(weights.to_vec()).unwrap(),
            // WeightedIndex from slice
            WeightedIndex::new(&weights[..]).unwrap(),
            // WeightedIndex from iterator
            WeightedIndex::new(weights.iter()).unwrap(),
        ];
        for distr in distrs.iter() {
            let mut chosen = [0i32; 14];
            for _ in 0..N_REPS {
                chosen[distr.sample(&mut r)] += 1;
            }
            verify(chosen);
        }

        // With a single non-zero weight the outcome is fixed.
        let forced: [(&[i32], usize); 3] = [
            (&[0, 1][..], 1),
            (&[1, 0][..], 0),
            (&[0, 0, 0, 0, 10, 0][..], 4),
        ];
        for _ in 0..5 {
            for &(input, index) in forced.iter() {
                assert_eq!(WeightedIndex::new(input).unwrap().sample(&mut r), index);
            }
        }

        let failures: [(&[i32], WeightedError); 5] = [
            (&[10][0..0], WeightedError::NoItem),
            (&[0][..], WeightedError::AllWeightsZero),
            (&[10, 20, -1, 30][..], WeightedError::InvalidWeight),
            (&[-10, 20, 1, 30][..], WeightedError::InvalidWeight),
            (&[-10][..], WeightedError::InvalidWeight),
        ];
        for &(input, error) in failures.iter() {
            assert_eq!(WeightedIndex::new(input).unwrap_err(), error);
        }
    }

    /// `update_weights` leaves the same state as building the distribution
    /// from the expected weights with `new`.
    #[test]
    fn test_update_weights() {
        fn check(initial: &[u32], update: &[(usize, &u32)], expected: &[u32]) {
            let mut distr = WeightedIndex::new(initial.to_vec()).unwrap();
            assert_eq!(distr.total_weight, initial.iter().sum::<u32>());

            distr.update_weights(update).unwrap();
            let reference = WeightedIndex::new(expected.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected.iter().sum::<u32>());
            assert_eq!(distr.total_weight, reference.total_weight);
            assert_eq!(distr.cumulative_weights, reference.cumulative_weights);
        }

        // positive change
        check(&[10, 2, 3, 4], &[(1, &100), (2, &4)], &[10, 100, 4, 4]);
        // negative change and last element
        check(
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7],
            &[(2, &1), (5, &1), (13, &100)],
            &[1, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100],
        );
    }

    /// Pins the exact indices produced for a fixed seed.
    #[test]
    fn value_stability() {
        fn check_samples<X: SampleUniform + PartialOrd, I>(
            weights: I, out: &mut [usize], expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
            X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
        {
            assert_eq!(out.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            out.iter_mut().for_each(|slot| *slot = rng.sample(&distr));
            assert_eq!(out, expected);
        }

        let mut buf = [0; 10];
        check_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            0, 6, 2, 6, 3, 4, 7, 8, 2, 5,
        ]);
        check_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            0, 0, 0, 1, 0, 0, 2, 3, 0, 0,
        ]);
        check_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 2, 1, 3, 2, 1, 3, 3, 2, 1,
        ]);
    }
}
422 
/// Error type returned from `WeightedIndex::new` and
/// `WeightedIndex::update_weights`.
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightedError {
    /// The provided weight collection contains no items.
    NoItem,

    /// A weight is either less than zero, greater than the supported maximum,
    /// NaN, or otherwise invalid.
    ///
    /// `WeightedIndex::update_weights` also returns this variant when the
    /// indices of the new weights are not strictly increasing.
    InvalidWeight,

    /// All items in the provided weight collection are zero.
    ///
    /// `WeightedIndex::update_weights` returns this variant when the total
    /// weight after the update would not be greater than zero.
    AllWeightsZero,

    /// Too many weights are provided (length greater than `u32::MAX`)
    ///
    /// `WeightedIndex::update_weights` also returns this variant when an
    /// index does not refer to an existing weight.
    TooMany,
}
440 
// Marks `WeightedError` as a standard error type. The impl body is empty, so
// no trait method is overridden here. It is compiled only with the `std`
// feature because the trait is named through `::std`.
#[cfg(feature = "std")]
impl ::std::error::Error for WeightedError {}
443 
444 impl fmt::Display for WeightedError {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result445     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
446         match *self {
447             WeightedError::NoItem => write!(f, "No weights provided."),
448             WeightedError::InvalidWeight => write!(f, "A weight is invalid."),
449             WeightedError::AllWeightsZero => write!(f, "All weights are zero."),
450             WeightedError::TooMany => write!(f, "Too many weights (hit u32::MAX)"),
451         }
452     }
453 }
454