1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // LINT.IfChange
16 
17 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
18 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
19 
20 #include <algorithm>
21 #include <cmath>
22 #include <limits>
23 #include <memory>
24 #include <vector>
25 
26 #include "third_party/eigen3/Eigen/Core"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/gtl/top_n.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/util/ctc/ctc_beam_entry.h"
34 #include "tensorflow/core/util/ctc/ctc_beam_scorer.h"
35 #include "tensorflow/core/util/ctc/ctc_decoder.h"
36 #include "tensorflow/core/util/ctc/ctc_loss_util.h"
37 
38 namespace tensorflow {
39 namespace ctc {
40 
41 template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
42           typename CTCBeamComparer =
43               ctc_beam_search::BeamComparer<CTCBeamState>>
44 class CTCBeamSearchDecoder : public CTCDecoder {
45   // Beam Search
46   //
47   // Example (GravesTh Fig. 7.5):
48   //         a    -
49   //  P = [ 0.3  0.7 ]  t = 0
50   //      [ 0.4  0.6 ]  t = 1
51   //
52   // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
53   //      P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58
54   //
55   // In this case, Best Path decoding is suboptimal.
56   //
57   // For Beam Search, we use the following main recurrence relations:
58   //
59   // Relation 1:
60   // ---------------------------------------------------------- Eq. 1
61   //      P(l=abcd @ t=7) = P(l=abc  @ t=6) * P(d @ 7)
62   //                      + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
63   // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
64   // updated recursively in the beam entry.
65   //
66   // Relation 2:
67   // ---------------------------------------------------------- Eq. 2
68   //      P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
69   // for ? in a, b, d, ..., (not including c or the blank index),
70   // and the recurrence starts from the beam entry for P(l=abc @ t=2).
71   //
72   // For this case, the length of the new sequence equals t+1 (t
73   // starts at 0).  This special case can be calculated as:
74   //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
75   // but we calculate it recursively for speed purposes.
76   typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
77   typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
78   typedef ctc_beam_search::BeamProbability BeamProbability;
79 
80  public:
81   typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
82 
83   // The beam search decoder is constructed specifying the beam_width (number of
84   // candidates to keep at each decoding timestep) and a beam scorer (used for
85   // custom scoring, for example enabling the use of a language model).
86   // The ownership of the scorer remains with the caller. The default
87   // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
88   // standard beam search.
89   CTCBeamSearchDecoder(int num_classes, int beam_width,
90                        BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
91                        bool merge_repeated = false)
CTCDecoder(num_classes,batch_size,merge_repeated)92       : CTCDecoder(num_classes, batch_size, merge_repeated),
93         beam_width_(beam_width),
94         leaves_(beam_width),
95         beam_scorer_(CHECK_NOTNULL(scorer)) {
96     Reset();
97   }
98 
~CTCBeamSearchDecoder()99   ~CTCBeamSearchDecoder() override {}
100 
101   // Run the hibernating beam search algorithm on the given input.
102   Status Decode(const CTCDecoder::SequenceLength& seq_len,
103                 const std::vector<CTCDecoder::Input>& input,
104                 std::vector<CTCDecoder::Output>* output,
105                 CTCDecoder::ScoreOutput* scores) override;
106 
107   // Calculate the next step of the beam search and update the internal state.
108   template <typename Vector>
109   void Step(const Vector& log_input_t);
110 
111   template <typename Vector>
112   float GetTopK(const int K, const Vector& input,
113                 std::vector<float>* top_k_logits,
114                 std::vector<int>* top_k_indices);
115 
116   // Retrieve the beam scorer instance used during decoding.
GetBeamScorer()117   BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
118 
119   // Set label selection parameters for faster decoding.
120   // See comments for label_selection_size_ and label_selection_margin_.
SetLabelSelectionParameters(int label_selection_size,float label_selection_margin)121   void SetLabelSelectionParameters(int label_selection_size,
122                                    float label_selection_margin) {
123     label_selection_size_ = label_selection_size;
124     label_selection_margin_ = label_selection_margin;
125   }
126 
127   // Reset the beam search
128   void Reset();
129 
130   // Extract the top n paths at current time step
131   Status TopPaths(int n, std::vector<std::vector<int>>* paths,
132                   std::vector<float>* log_probs, bool merge_repeated) const;
133 
134  private:
135   int beam_width_;
136 
137   // Label selection is designed to avoid possibly very expensive scorer calls,
138   // by pruning the hypotheses based on the input alone.
139   // Label selection size controls how many items in each beam are passed
140   // through to the beam scorer. Only items with top N input scores are
141   // considered.
142   // Label selection margin controls the difference between minimal input score
143   // (versus the best scoring label) for an item to be passed to the beam
144   // scorer. This margin is expressed in terms of log-probability.
145   // Default is to do no label selection.
146   // For more detail: https://research.google.com/pubs/pub44823.html
147   int label_selection_size_ = 0;       // zero means unlimited
148   float label_selection_margin_ = -1;  // -1 means unlimited.
149 
150   gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
151   std::unique_ptr<BeamRoot> beam_root_;
152   BaseBeamScorer<CTCBeamState>* beam_scorer_;
153 
154   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
155 };
156 
157 template <typename CTCBeamState, typename CTCBeamComparer>
Decode(const CTCDecoder::SequenceLength & seq_len,const std::vector<CTCDecoder::Input> & input,std::vector<CTCDecoder::Output> * output,ScoreOutput * scores)158 Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
159     const CTCDecoder::SequenceLength& seq_len,
160     const std::vector<CTCDecoder::Input>& input,
161     std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
162   // Storage for top paths.
163   std::vector<std::vector<int>> beams;
164   std::vector<float> beam_log_probabilities;
165   int top_n = output->size();
166   if (std::any_of(output->begin(), output->end(),
167                   [this](const CTCDecoder::Output& output) -> bool {
168                     return output.size() < this->batch_size_;
169                   })) {
170     return errors::InvalidArgument(
171         "output needs to be of size at least (top_n, batch_size).");
172   }
173   if (scores->rows() < batch_size_ || scores->cols() < top_n) {
174     return errors::InvalidArgument(
175         "scores needs to be of size at least (batch_size, top_n).");
176   }
177 
178   for (int b = 0; b < batch_size_; ++b) {
179     int seq_len_b = seq_len[b];
180     Reset();
181 
182     for (int t = 0; t < seq_len_b; ++t) {
183       // Pass log-probabilities for this example + time.
184       Step(input[t].row(b));
185     }  // for (int t...
186 
187     // O(n * log(n))
188     std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
189     leaves_.Reset();
190     for (int i = 0; i < branches->size(); ++i) {
191       BeamEntry* entry = (*branches)[i];
192       beam_scorer_->ExpandStateEnd(&entry->state);
193       entry->newp.total +=
194           beam_scorer_->GetStateEndExpansionScore(entry->state);
195       leaves_.push(entry);
196     }
197 
198     Status status =
199         TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
200     if (!status.ok()) {
201       return status;
202     }
203 
204     CHECK_EQ(top_n, beam_log_probabilities.size());
205     CHECK_EQ(beams.size(), beam_log_probabilities.size());
206 
207     for (int i = 0; i < top_n; ++i) {
208       // Copy output to the correct beam + batch
209       (*output)[i][b].swap(beams[i]);
210       (*scores)(b, i) = -beam_log_probabilities[i];
211     }
212   }  // for (int b...
213   return Status::OK();
214 }
215 
216 template <typename CTCBeamState, typename CTCBeamComparer>
217 template <typename Vector>
GetTopK(const int K,const Vector & input,std::vector<float> * top_k_logits,std::vector<int> * top_k_indices)218 float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
219     const int K, const Vector& input, std::vector<float>* top_k_logits,
220     std::vector<int>* top_k_indices) {
221   // Find Top K choices, complexity nk in worst case. The array input is read
222   // just once.
223   CHECK_EQ(num_classes_, input.size());
224   top_k_logits->clear();
225   top_k_indices->clear();
226   top_k_logits->resize(K, -INFINITY);
227   top_k_indices->resize(K, -1);
228   for (int j = 0; j < num_classes_ - 1; ++j) {
229     const float logit = input(j);
230     if (logit > (*top_k_logits)[K - 1]) {
231       int k = K - 1;
232       while (k > 0 && logit > (*top_k_logits)[k - 1]) {
233         (*top_k_logits)[k] = (*top_k_logits)[k - 1];
234         (*top_k_indices)[k] = (*top_k_indices)[k - 1];
235         k--;
236       }
237       (*top_k_logits)[k] = logit;
238       (*top_k_indices)[k] = j;
239     }
240   }
241   // Return max value which is in 0th index or blank character logit
242   return std::max((*top_k_logits)[0], input(num_classes_ - 1));
243 }
244 
245 template <typename CTCBeamState, typename CTCBeamComparer>
246 template <typename Vector>
Step(const Vector & raw_input)247 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
248     const Vector& raw_input) {
249   std::vector<float> top_k_logits;
250   std::vector<int> top_k_indices;
251   const bool top_k =
252       (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
253   // Number of character classes to consider in each step.
254   const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
255   // Get max coefficient and remove it from raw_input later.
256   float max_coeff;
257   if (top_k) {
258     max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
259                         &top_k_indices);
260   } else {
261     max_coeff = raw_input.maxCoeff();
262   }
263 
264   // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
265   float logsumexp = 0.0;
266   for (int j = 0; j < raw_input.size(); ++j) {
267     logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
268   }
269   logsumexp = Eigen::numext::log(logsumexp);
270   // Final normalization offset to get correct log probabilities.
271   float norm_offset = max_coeff + logsumexp;
272 
273   const float label_selection_input_min =
274       (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
275                                      : -std::numeric_limits<float>::infinity();
276 
277   // Extract the beams sorted in decreasing new probability
278   CHECK_EQ(num_classes_, raw_input.size());
279 
280   std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
281   leaves_.Reset();
282 
283   for (BeamEntry* b : *branches) {
284     // P(.. @ t) becomes the new P(.. @ t-1)
285     b->oldp = b->newp;
286   }
287 
288   for (BeamEntry* b : *branches) {
289     if (b->parent != nullptr) {  // if not the root
290       if (b->parent->Active()) {
291         // If last two sequence characters are identical:
292         //   Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
293         //                          + Pblank(l=ac @ t=5))
294         // else:
295         //   Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
296         //                          + P(l=ab @ t=5))
297         float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
298                                                         : b->parent->oldp.total;
299         b->newp.label =
300             LogSumExp(b->newp.label,
301                       beam_scorer_->GetStateExpansionScore(b->state, previous));
302       }
303       // Plabel(l=abc @ t=6) *= P(c @ 6)
304       b->newp.label += raw_input(b->label) - norm_offset;
305     }
306     // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
307     b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
308     // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
309     b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
310 
311     // Push the entry back to the top paths list.
312     // Note, this will always fill leaves back up in sorted order.
313     leaves_.push(b);
314   }
315 
316   // we need to resort branches in descending oldp order.
317 
318   // branches is in descending oldp order because it was
319   // originally in descending newp order and we copied newp to oldp.
320 
321   // Grow new leaves
322   for (BeamEntry* b : *branches) {
323     // A new leaf (represented by its BeamProbability) is a candidate
324     // iff its total probability is nonzero and either the beam list
325     // isn't full, or the lowest probability entry in the beam has a
326     // lower probability than the leaf.
327     auto is_candidate = [this](const BeamProbability& prob) {
328       return (prob.total > kLogZero &&
329               (leaves_.size() < beam_width_ ||
330                prob.total > leaves_.peek_bottom()->newp.total));
331     };
332 
333     if (!is_candidate(b->oldp)) {
334       continue;
335     }
336 
337     for (int ind = 0; ind < max_classes; ind++) {
338       const int label = top_k ? top_k_indices[ind] : ind;
339       const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
340       // Perform label selection: if input for this label looks very
341       // unpromising, never evaluate it with a scorer.
342       // We may compare logits instead of log probabilities,
343       // since the difference is the same in both cases.
344       if (logit < label_selection_input_min) {
345         continue;
346       }
347       BeamEntry& c = b->GetChild(label);
348       if (!c.Active()) {
349         //   Pblank(l=abcd @ t=6) = 0
350         c.newp.blank = kLogZero;
351         // If new child label is identical to beam label:
352         //   Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
353         // Otherwise:
354         //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
355         beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
356         float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
357         c.newp.label = logit - norm_offset +
358                        beam_scorer_->GetStateExpansionScore(c.state, previous);
359         // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
360         c.newp.total = c.newp.label;
361 
362         if (is_candidate(c.newp)) {
363           // Before adding the new node to the beam, check if the beam
364           // is already at maximum width.
365           if (leaves_.size() == beam_width_) {
366             // Bottom is no longer in the beam search.  Reset
367             // its probability; signal it's no longer in the beam search.
368             BeamEntry* bottom = leaves_.peek_bottom();
369             bottom->newp.Reset();
370           }
371           leaves_.push(&c);
372         } else {
373           // Deactivate child.
374           c.oldp.Reset();
375           c.newp.Reset();
376         }
377       }
378     }
379   }  // for (BeamEntry* b...
380 }
381 
382 template <typename CTCBeamState, typename CTCBeamComparer>
Reset()383 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
384   leaves_.Reset();
385 
386   // This beam root, and all of its children, will be in memory until
387   // the next reset.
388   beam_root_.reset(new BeamRoot(nullptr, -1));
389   beam_root_->RootEntry()->newp.total = 0.0;  // ln(1)
390   beam_root_->RootEntry()->newp.blank = 0.0;  // ln(1)
391 
392   // Add the root as the initial leaf.
393   leaves_.push(beam_root_->RootEntry());
394 
395   // Call initialize state on the root object.
396   beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
397 }
398 
399 template <typename CTCBeamState, typename CTCBeamComparer>
TopPaths(int n,std::vector<std::vector<int>> * paths,std::vector<float> * log_probs,bool merge_repeated)400 Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
401     int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
402     bool merge_repeated) const {
403   CHECK_NOTNULL(paths)->clear();
404   CHECK_NOTNULL(log_probs)->clear();
405   if (n > beam_width_) {
406     return errors::InvalidArgument("requested more paths than the beam width.");
407   }
408   if (n > leaves_.size()) {
409     return errors::InvalidArgument(
410         "Less leaves in the beam search than requested.");
411   }
412 
413   gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);
414 
415   // O(beam_width_ * log(n)), space complexity is O(n)
416   for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
417     top_branches.push(*it);
418   }
419   // O(n * log(n))
420   std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());
421 
422   for (int i = 0; i < n; ++i) {
423     BeamEntry* e((*branches)[i]);
424     paths->push_back(e->LabelSeq(merge_repeated));
425     log_probs->push_back(e->newp.total);
426   }
427   return Status::OK();
428 }
429 
430 }  // namespace ctc
431 }  // namespace tensorflow
432 
433 #endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
434 // LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_search.h)
435