/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// LINT.IfChange

#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ctc/ctc_beam_entry.h"
#include "tensorflow/core/util/ctc/ctc_beam_scorer.h"
#include "tensorflow/core/util/ctc/ctc_decoder.h"
#include "tensorflow/core/util/ctc/ctc_loss_util.h"

namespace tensorflow {
namespace ctc {

template <typename T, typename CTCBeamState = ctc_beam_search::EmptyBeamState,
          typename CTCBeamComparer =
              ctc_beam_search::BeamComparer<T, CTCBeamState>>
class CTCBeamSearchDecoder : public CTCDecoder<T> {
  // Beam Search
  //
  // Example (GravesTh Fig. 7.5):
  //         a    -
  //  P = [ 0.3  0.7 ]  t = 0
  //      [ 0.4  0.6 ]  t = 1
  //
  // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
  //      P(l = a) = P(a-) + P(aa) + P(-a)
  //               = 0.3*0.6 + 0.3*0.4 + 0.7*0.4 = 0.58
  //
  // In this case, Best Path decoding is suboptimal: it picks the blank at
  // both timesteps and outputs the empty label (probability 0.42), even
  // though the label "a" is more likely (probability 0.58).
  //
  // For Beam Search, we use the following main recurrence relations:
  //
  // Relation 1:
  // ---------------------------------------------------------- Eq. 1
  //      P(l=abcd @ t=7) = P(l=abc  @ t=6) * P(d @ 7)
  //                      + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
  // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
  // updated recursively in the beam entry.
  //
  // Relation 2:
  // ---------------------------------------------------------- Eq. 2
  //      P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
  // for ? in a, b, d, ..., (not including c or the blank index),
  // and the recurrence starts from the beam entry for P(l=abc @ t=2).
  //
  // For this case, the length of the new sequence equals t+1 (t
  // starts at 0).  This special case can be calculated as:
  //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
  // but we calculate it recursively for speed purposes.
  typedef ctc_beam_search::BeamEntry<T, CTCBeamState> BeamEntry;
  typedef ctc_beam_search::BeamRoot<T, CTCBeamState> BeamRoot;
  typedef ctc_beam_search::BeamProbability<T> BeamProbability;

 public:
  typedef BaseBeamScorer<T, CTCBeamState> DefaultBeamScorer;

  // The beam search decoder is constructed by specifying the beam_width (the
  // number of candidates to keep at each decoding timestep) and a beam scorer
  // (used for custom scoring, for example to enable the use of a language
  // model).
  // Ownership of the scorer remains with the caller. The default
  // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, performs
  // standard beam search with no additional scoring.
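  //
  // A minimal usage sketch (variable names and sizes are illustrative; see
  // CTCDecoder<T> for the exact Input/Output/SequenceLength/ScoreOutput
  // container types expected by Decode):
  //
  //   CTCBeamSearchDecoder<float>::DefaultBeamScorer default_scorer;
  //   CTCBeamSearchDecoder<float> decoder(num_classes, /*beam_width=*/20,
  //                                       &default_scorer, batch_size);
  //   // inputs[t] holds the (batch_size x num_classes) scores for timestep t.
  //   Status s = decoder.Decode(seq_len, inputs, &outputs, &scores);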
  CTCBeamSearchDecoder(int num_classes, int beam_width,
                       BaseBeamScorer<T, CTCBeamState>* scorer,
                       int batch_size = 1, bool merge_repeated = false)
      : CTCDecoder<T>(num_classes, batch_size, merge_repeated),
        beam_width_(beam_width),
        leaves_(beam_width),
        beam_scorer_(CHECK_NOTNULL(scorer)) {
    Reset();
  }

  ~CTCBeamSearchDecoder() override {}

  // Run the hibernating beam search algorithm on the given input.
  Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
                const std::vector<typename CTCDecoder<T>::Input>& input,
                std::vector<typename CTCDecoder<T>::Output>* output,
                typename CTCDecoder<T>::ScoreOutput* scores) override;

  // Calculate the next step of the beam search and update the internal state.
  template <typename Vector>
  void Step(const Vector& log_input_t);

  template <typename Vector>
  T GetTopK(const int K, const Vector& input, std::vector<T>* top_k_logits,
            std::vector<int>* top_k_indices);

  // Retrieve the beam scorer instance used during decoding.
  BaseBeamScorer<T, CTCBeamState>* GetBeamScorer() const {
    return beam_scorer_;
  }

  // Set label selection parameters for faster decoding.
  // See the comments for label_selection_size_ and label_selection_margin_.
  void SetLabelSelectionParameters(int label_selection_size,
                                   T label_selection_margin) {
    label_selection_size_ = label_selection_size;
    label_selection_margin_ = label_selection_margin;
  }

  // Reset the beam search.
  void Reset();

  // Extract the top n paths at the current time step.
  Status TopPaths(int n, std::vector<std::vector<int>>* paths,
                  std::vector<T>* log_probs, bool merge_repeated) const;

 private:
  int beam_width_;

  // Label selection is designed to avoid possibly very expensive scorer calls,
  // by pruning the hypotheses based on the input alone.
  // Label selection size controls how many items in each beam are passed
  // through to the beam scorer. Only the items with the top N input scores are
  // considered.
  // Label selection margin controls how far below the best-scoring label an
  // item's input score may fall and still be passed to the beam scorer. The
  // margin is expressed as a log-probability difference.
  // The default is to do no label selection.
  // For more detail: https://research.google.com/pubs/pub44823.html
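  // As an illustrative setting (not a recommendation): with
  // label_selection_size_ = 40 and label_selection_margin_ = 6, at most the 40
  // highest-scoring labels are expanded at each step, and among those only
  // labels whose score lies within 6 (in log space) of the best label's score
  // are handed to the scorer.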
  int label_selection_size_ = 0;       // zero means unlimited
  T label_selection_margin_ = -1;      // -1 means unlimited

  gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
  std::unique_ptr<BeamRoot> beam_root_;
  BaseBeamScorer<T, CTCBeamState>* beam_scorer_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
};

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Decode(
    const typename CTCDecoder<T>::SequenceLength& seq_len,
    const std::vector<typename CTCDecoder<T>::Input>& input,
    std::vector<typename CTCDecoder<T>::Output>* output,
    typename CTCDecoder<T>::ScoreOutput* scores) {
  // Storage for top paths.
  std::vector<std::vector<int>> beams;
  std::vector<T> beam_log_probabilities;
  int top_n = output->size();
  if (std::any_of(output->begin(), output->end(),
                  [this](const typename CTCDecoder<T>::Output& output) -> bool {
                    return output.size() < this->batch_size_;
                  })) {
    return errors::InvalidArgument(
        "output needs to be of size at least (top_n, batch_size).");
  }
  if (scores->rows() < this->batch_size_ || scores->cols() < top_n) {
    return errors::InvalidArgument(
        "scores needs to be of size at least (batch_size, top_n).");
  }

  for (int b = 0; b < this->batch_size_; ++b) {
    int seq_len_b = seq_len[b];
    Reset();

    for (int t = 0; t < seq_len_b; ++t) {
      // Pass log-probabilities for this example + time.
      Step(input[t].row(b));
    }  // for (int t...

    // O(n * log(n))
    std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
    leaves_.Reset();
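    // Let the scorer add its final, sequence-end contribution (for example a
    // language model score for terminating the sequence) to every surviving
    // beam entry before the top paths are extracted.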
    for (int i = 0; i < branches->size(); ++i) {
      BeamEntry* entry = (*branches)[i];
      beam_scorer_->ExpandStateEnd(&entry->state);
      entry->newp.total +=
          beam_scorer_->GetStateEndExpansionScore(entry->state);
      leaves_.push(entry);
    }

    Status status =
        TopPaths(top_n, &beams, &beam_log_probabilities, this->merge_repeated_);
    if (!status.ok()) {
      return status;
    }

    CHECK_EQ(top_n, beam_log_probabilities.size());
    CHECK_EQ(beams.size(), beam_log_probabilities.size());

    for (int i = 0; i < top_n; ++i) {
      // Copy output to the correct beam + batch
      (*output)[i][b].swap(beams[i]);
      (*scores)(b, i) = -beam_log_probabilities[i];
    }
  }  // for (int b...
  return Status::OK();
}

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
T CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::GetTopK(
    const int K, const Vector& input, std::vector<T>* top_k_logits,
    std::vector<int>* top_k_indices) {
  // Find the top K choices; complexity is O(n*K) in the worst case. The input
  // array is read just once.
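  // The K entries are kept sorted in descending order; each qualifying logit
  // is placed by a bounded insertion that shifts smaller entries toward the
  // back, so the extra work per input element is at most O(K).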
  CHECK_EQ(this->num_classes_, input.size());
  top_k_logits->clear();
  top_k_indices->clear();
  top_k_logits->resize(K, -INFINITY);
  top_k_indices->resize(K, -1);
  for (int j = 0; j < this->num_classes_ - 1; ++j) {
    const T logit = input(j);
    if (logit > (*top_k_logits)[K - 1]) {
      int k = K - 1;
      while (k > 0 && logit > (*top_k_logits)[k - 1]) {
        (*top_k_logits)[k] = (*top_k_logits)[k - 1];
        (*top_k_indices)[k] = (*top_k_indices)[k - 1];
        k--;
      }
      (*top_k_logits)[k] = logit;
      (*top_k_indices)[k] = j;
    }
  }
  // Return the overall maximum: either the best top-K logit (stored at index
  // 0) or the blank label's logit, which is excluded from the top-K list.
  return std::max((*top_k_logits)[0], input(this->num_classes_ - 1));
}

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
    const Vector& raw_input) {
  std::vector<T> top_k_logits;
  std::vector<int> top_k_indices;
  const bool top_k =
      (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
  // Number of character classes to consider in each step (the blank label is
  // handled separately when computing newp.blank below).
  const int max_classes =
      top_k ? label_selection_size_ : (this->num_classes_ - 1);
  // Get max coefficient and remove it from raw_input later.
  T max_coeff;
  if (top_k) {
    max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
                        &top_k_indices);
  } else {
    max_coeff = raw_input.maxCoeff();
  }
  // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
  T logsumexp = T(0.0);
  for (int j = 0; j < raw_input.size(); ++j) {
    logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
  }
  logsumexp = Eigen::numext::log(logsumexp);
  // Final normalization offset to get correct log probabilities.
  T norm_offset = max_coeff + logsumexp;
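  // With this offset, the log probability of label j used below is simply
  // raw_input(j) - norm_offset, i.e. the log-softmax computed with the usual
  // max-subtraction (log-sum-exp) trick for numerical stability.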

  const T label_selection_input_min =
      (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
                                     : -std::numeric_limits<T>::infinity();

  // Extract the beams sorted in decreasing new probability
  CHECK_EQ(this->num_classes_, raw_input.size());

  std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
  leaves_.Reset();

  for (BeamEntry* b : *branches) {
    // P(.. @ t) becomes the new P(.. @ t-1)
    b->oldp = b->newp;
  }

  for (BeamEntry* b : *branches) {
    if (b->parent != nullptr) {  // if not the root
      if (b->parent->Active()) {
        // If last two sequence characters are identical:
        //   Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
        //                          + Pblank(l=ac @ t=5))
        // else:
        //   Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
        //                          + P(l=ab @ t=5))
        T previous = (b->label == b->parent->label) ? b->parent->oldp.blank
                                                    : b->parent->oldp.total;
        b->newp.label =
            LogSumExp(b->newp.label,
                      beam_scorer_->GetStateExpansionScore(b->state, previous));
      }
      // Plabel(l=abc @ t=6) *= P(c @ 6)
      b->newp.label += raw_input(b->label) - norm_offset;
    }
    // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
    b->newp.blank = b->oldp.total + raw_input(this->blank_index_) - norm_offset;
    // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
    b->newp.total = LogSumExp(b->newp.blank, b->newp.label);

    // Push the entry back to the top paths list.
    // Note, this will always fill leaves back up in sorted order.
    leaves_.push(b);
  }

  // The candidate pruning below relies on iterating branches in descending
  // oldp order. No resort is needed: branches was extracted in descending
  // newp order and newp was copied to oldp above, so it is already sorted.

  // Grow new leaves
  for (BeamEntry* b : *branches) {
    // A new leaf (represented by its BeamProbability) is a candidate
    // iff its total probability is nonzero and either the beam list
    // isn't full, or the lowest probability entry in the beam has a
    // lower probability than the leaf.
    auto is_candidate = [this](const BeamProbability& prob) {
      return (prob.total > kLogZero<T>() &&
              (leaves_.size() < beam_width_ ||
               prob.total > leaves_.peek_bottom()->newp.total));
    };

    if (!is_candidate(b->oldp)) {
      continue;
    }

    for (int ind = 0; ind < max_classes; ind++) {
      const int label = top_k ? top_k_indices[ind] : ind;
      const T logit = top_k ? top_k_logits[ind] : raw_input(ind);
      // Perform label selection: if the input for this label looks very
      // unpromising, never evaluate it with a scorer.
      // We may compare logits instead of log probabilities,
      // since the difference is the same in both cases.
      if (logit < label_selection_input_min) {
        continue;
      }
      BeamEntry& c = b->GetChild(label);
      if (!c.Active()) {
        //   Pblank(l=abcd @ t=6) = 0
        c.newp.blank = kLogZero<T>();
        // If the new child label is identical to the beam label:
        //   Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
        // Otherwise:
        //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
        beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
        T previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
        c.newp.label = logit - norm_offset +
                       beam_scorer_->GetStateExpansionScore(c.state, previous);
        // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
        c.newp.total = c.newp.label;

        if (is_candidate(c.newp)) {
          // Before adding the new node to the beam, check if the beam
          // is already at maximum width.
          if (leaves_.size() == beam_width_) {
            // The current bottom entry falls out of the beam.  Reset its
            // probability to signal that it is no longer in the beam search.
            BeamEntry* bottom = leaves_.peek_bottom();
            bottom->newp.Reset();
          }
          leaves_.push(&c);
        } else {
          // Deactivate child.
          c.oldp.Reset();
          c.newp.Reset();
        }
      }
    }
  }  // for (BeamEntry* b...
}

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Reset() {
  leaves_.Reset();

  // This beam root, and all of its children, will be in memory until
  // the next reset.
  beam_root_.reset(new BeamRoot(nullptr, -1));
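  // Beam probabilities are tracked in log space: T(0.0) is log(1), i.e.
  // certainty, and kLogZero<T>() stands in for log(0).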
  beam_root_->RootEntry()->newp.total = T(0.0);  // ln(1)
  beam_root_->RootEntry()->newp.blank = T(0.0);  // ln(1)

  // Add the root as the initial leaf.
  leaves_.push(beam_root_->RootEntry());

  // Call InitializeState on the root object.
  beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
}

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::TopPaths(
    int n, std::vector<std::vector<int>>* paths, std::vector<T>* log_probs,
    bool merge_repeated) const {
  CHECK_NOTNULL(paths)->clear();
  CHECK_NOTNULL(log_probs)->clear();
  if (n > beam_width_) {
    return errors::InvalidArgument("requested more paths than the beam width.");
  }
  if (n > leaves_.size()) {
    return errors::InvalidArgument(
        "Fewer leaves in the beam search than requested.");
  }

  gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);

  // O(beam_width_ * log(n)), space complexity is O(n)
  for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
    top_branches.push(*it);
  }
  // O(n * log(n))
  std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());

  for (int i = 0; i < n; ++i) {
    BeamEntry* e((*branches)[i]);
    paths->push_back(e->LabelSeq(merge_repeated));
    log_probs->push_back(e->newp.total);
  }
  return Status::OK();
}

}  // namespace ctc
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
// LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_search.h)