1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Generic utils similar to those from the C++ header <algorithm>.
18 
19 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
20 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
21 
22 #include <algorithm>
23 #include <queue>
24 #include <vector>
25 
26 namespace libtextclassifier3 {
27 namespace mobile {
28 
// Finds the position of the largest value stored in |elements|.
//
// Values are compared with operator<, so T must support it.  When several
// elements share the maximum value, the position of the first one is reported.
// An empty |elements| yields 0.
template <typename T>
inline int GetArgMax(const std::vector<T> &elements) {
  const auto first = elements.cbegin();
  const auto largest = std::max_element(first, elements.cend());
  // For an empty vector |largest| == |first|, hence the result is 0.
  return static_cast<int>(largest - first);
}
37 
// Finds the position of the smallest value stored in |elements|.
//
// Values are compared with operator<, so T must support it.  When several
// elements share the minimum value, the position of the first one is reported.
// An empty |elements| yields 0.
template <typename T>
inline int GetArgMin(const std::vector<T> &elements) {
  const auto first = elements.cbegin();
  const auto smallest = std::min_element(first, elements.cend());
  // For an empty vector |smallest| == |first|, hence the result is 0.
  return static_cast<int>(smallest - first);
}
46 
// Returns indices of greatest k elements from |v|.
//
// The order between elements is indicated by |smaller|, which should be an
// object like std::less<T>, std::greater<T>, etc.  If smaller(a, b) is true,
// that means that "a is smaller than b".  Intuitively, |smaller| is a
// generalization of operator<.  Formally, it is a strict weak ordering, see
// https://en.cppreference.com/w/cpp/named_req/Compare
//
// Calling this function with std::less<T>() returns the indices of the largest
// k elements; calling it with std::greater<T>() returns the indices of the
// smallest k elements.  This is similar to e.g., std::priority_queue: using the
// default std::less gives you a max-heap, while using std::greater results in a
// min-heap.
//
// Returned indices are sorted in decreasing order of the corresponding elements
// (e.g., first element of the returned array is the index of the largest
// element).  In case of ties (e.g., equal elements) we select the one with the
// smallest index.  E.g., getting the indices of the top-2 elements from [3, 2,
// 1, 3, 0, 3] returns [0, 3] (the indices of the first and the second 3).
//
// Corner cases: If k <= 0, this function returns an empty vector.  If |v| has
// only n < k elements, this function returns all n indices [0, 1, 2, ..., n -
// 1], sorted according to the order that |smaller| induces on the indicated
// elements (largest first, ties by increasing index).
//
// Assuming each comparison is O(1), this function uses O(k) auxiliary space,
// and runs in O(n * log k) time.  Note: it is possible to use std::nth_element
// and obtain an O(n + k * log k) time algorithm, but that uses O(n) auxiliary
// space.  In our case, k << n, e.g., we may want to select the top-3 most
// likely classes from a set of 100 classes, so the time complexity difference
// should not matter in practice.
template <typename T, typename Smaller>
std::vector<int> GetTopKIndices(int k, const std::vector<T> &v,
                                Smaller smaller) {
  if (k <= 0) {
    return std::vector<int>();
  }

  // Convert the size once, so that all comparisons below are int vs. int
  // (the returned indices are ints anyway).
  const int num_elements = static_cast<int>(v.size());
  if (k > num_elements) {
    k = num_elements;
  }

  // An order between indices.  Intuitively, rev_vcomp(i1, i2) iff v[i2] is
  // smaller than v[i1].  No typo: this inversion is necessary for Invariant B
  // below.  "vcomp" stands for "value comparator" (we compare the values
  // indicated by the two indices) and "rev_" stands for the reverse order.
  const auto rev_vcomp = [&v, &smaller](int i1, int i2) -> bool {
    if (smaller(v[i2], v[i1])) return true;
    if (smaller(v[i1], v[i2])) return false;

    // Break ties in favor of earlier elements.
    return i1 < i2;
  };

  // Indices of the top-k elements seen so far.
  std::vector<int> heap(k);

  // First, we fill |heap| with the first k indices.
  for (int i = 0; i < k; ++i) {
    heap[i] = i;
  }
  std::make_heap(heap.begin(), heap.end(), rev_vcomp);

  // Next, we explore the rest of the vector v.  Loop invariants:
  //
  // Invariant A: |heap| contains the indices of the top-k elements from v[0:i].
  //
  // Invariant B: heap[0] is the index of the smallest element from all elements
  // indicated by the indices from |heap|.
  //
  // Invariant C: |heap| is a max heap, according to order rev_vcomp.
  for (int i = k; i < num_elements; ++i) {
    // We have to update |heap| iff v[i] is larger than the smallest of the
    // top-k seen so far.  This test is easy to do, due to Invariant B above.
    if (smaller(v[heap[0]], v[i])) {
      // Replace heap[0] with i while keeping every heap precondition intact:
      // std::pop_heap needs the whole range to be a valid heap, so we must not
      // append i before calling it.  Instead, move the current root to the
      // back, overwrite it with i, and sift i up with std::push_heap (which
      // only needs heap[0:k-1] to be a heap).
      std::pop_heap(heap.begin(), heap.end(), rev_vcomp);
      heap.back() = i;
      std::push_heap(heap.begin(), heap.end(), rev_vcomp);
    }
  }

  // Arrange indices from |heap| in decreasing order of corresponding elements.
  //
  // std::sort_heap sorts in increasing order according to rev_vcomp, i.e., the
  // index of the largest element comes first and the index of the smallest of
  // the top-k elements comes last.
  std::sort_heap(heap.begin(), heap.end(), rev_vcomp);
  return heap;
}
139 
// Convenience overload: orders the elements with std::less<T> and forwards to
// the three-argument GetTopKIndices.
template <typename T>
std::vector<int> GetTopKIndices(int k, const std::vector<T> &elements) {
  const std::less<T> default_order{};
  return GetTopKIndices(k, elements, default_order);
}
144 
145 }  // namespace mobile
}  // namespace libtextclassifier3
147 
148 #endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
149