/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/math/softmax.h"

#include <cmath>
#include <limits>

#include "utils/base/logging.h"
#include "utils/math/fastexp.h"

namespace libtextclassifier3 {

float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
  if ((label < 0) || (label >= scores.size())) {
    TC3_LOG(ERROR) << "label " << label << " outside range "
                   << "[0, " << scores.size() << ")";
    return 0.0f;
  }

  // Standard softmax formula for label's probability is
  //
  //   exp(scores[label]) / sum_i exp(scores[i])
  //
  // We compute the mathematically equivalent
  //
  //   1 / (1 + sum_{i != label} exp(scores[i] - scores[label]))
  //
  // which saves two calls to exp().
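  //
  // E.g., for scores = {0.0f, 1.0f} and label = 1, both forms evaluate to
  // exp(1) / (exp(0) + exp(1)) = 1 / (1 + exp(-1)) ~= 0.731.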
  const float label_score = scores[label];
  float denominator = 1.0f;  // Contribution of i == label.
  for (int i = 0; i < scores.size(); ++i) {
    if (i == label) continue;
    const float delta_score = scores[i] - label_score;

    // TODO(salcianu): one can optimize the test below, to avoid any float
    // operation: extract exponent (via bit mask + shift) and check it's >= 4.
    if (std::fabs(delta_score) >= 16.0f) {
      if (delta_score > 0.0f) {
        // If delta_score >= 16, the denominator (e^delta_score + other positive
        // terms) is very big and its inverse can be approximated with 0.
        return 0.0f;
      } else {
        // If delta_score <= -16, then e^delta_score < 1.2e-7.  Even if we have
        // 1000 such labels i, their sum is < 1.2e-4 (which gets summed with
        // 1.0f for i == label).  Hence, we can approximate each such label with
        // 0 and skip the call to VeryFastExp and the update to denominator.
        continue;
      }
    }

    // At this point, delta_score is in (-16.0, 16.0).  For such values, vfexp
    // works fine: no under/overflows (we have tests for that in fastexp_test).
    // Also, even for 1000 labels, denominator will not overflow.
    denominator += VeryFastExp(delta_score);
  }
  return 1.0f / denominator;
}
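
// Example:
//   std::vector<float> scores = {0.2f, 1.3f, -0.4f};
//   // Index 1 has the highest score, so this is the largest of the three
//   // probabilities (a value in (0, 1), greater than 1/3).
//   float p1 = ComputeSoftmaxProbability(scores, 1);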

std::vector<float> ComputeSoftmax(const std::vector<float> &scores) {
  return ComputeSoftmax(scores.data(), scores.size());
}

std::vector<float> ComputeSoftmax(const float *scores, int scores_size) {
  std::vector<float> softmax;
  std::vector<float> exp_scores;
  exp_scores.reserve(scores_size);
  softmax.reserve(scores_size);

  // Find max value in "scores" vector and rescale to avoid overflows.
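  // Subtracting this max from every score makes all shifted scores <= 0, so
  // the exponentials below can only underflow towards 0, never overflow; the
  // maximal score shifts to 0 and contributes ~1, which keeps the denominator
  // away from zero.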
  float max = std::numeric_limits<float>::lowest();
  for (int i = 0; i < scores_size; ++i) {
    const float score = scores[i];
    if (score > max) max = score;
  }
  float denominator = 0;
  for (int i = 0; i < scores_size; ++i) {
    const float score = scores[i];
    // See comments above in ComputeSoftmaxProbability for the reasoning behind
    // this approximation.
    const float exp_score = score - max < -16.0f ? 0 : VeryFastExp(score - max);
    exp_scores.push_back(exp_score);
    denominator += exp_score;
  }

  for (int i = 0; i < scores_size; ++i) {
    softmax.push_back(exp_scores[i] / denominator);
  }
  return softmax;
}
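
// Example:
//   const std::vector<float> probs = ComputeSoftmax({1.0f, 2.0f, 3.0f});
//   // probs has three entries that sum to ~1.0, with probs[2] the largest.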

}  // namespace libtextclassifier3