/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// LINT.IfChange

#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ctc/ctc_beam_entry.h"
#include "tensorflow/core/util/ctc/ctc_beam_scorer.h"
#include "tensorflow/core/util/ctc/ctc_decoder.h"
#include "tensorflow/core/util/ctc/ctc_loss_util.h"

namespace tensorflow {
namespace ctc {

template <typename T, typename CTCBeamState = ctc_beam_search::EmptyBeamState,
          typename CTCBeamComparer =
              ctc_beam_search::BeamComparer<T, CTCBeamState>>
class CTCBeamSearchDecoder : public CTCDecoder<T> {
  // Beam Search
  //
  // Example (GravesTh Fig. 7.5):
  //          a    -
  //   P = [ 0.3  0.7 ]  t = 0
  //       [ 0.4  0.6 ]  t = 1
  //
  // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
  //      P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58
  //
  // In this case, Best Path decoding is suboptimal.
  //
  // For Beam Search, we use the following main recurrence relations:
  //
  // Relation 1:
  // ---------------------------------------------------------- Eq. 1
  //   P(l=abcd @ t=7) = P(l=abc  @ t=6) * P(d @ 7)
  //                   + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
  // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
  // updated recursively in the beam entry.
  //
  // Relation 2:
  // ---------------------------------------------------------- Eq. 2
  //   P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
  // for ? in a, b, d, ..., (not including c or the blank index),
  // and the recurrence starts from the beam entry for P(l=abc @ t=2).
  //
  // For this case, the length of the new sequence equals t+1 (t
  // starts at 0).  This special case can be calculated as:
  //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
  // but we calculate it recursively for speed purposes.
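  //
  // As an illustrative check (not part of the algorithm), the path sum above
  // expands to 0.3*0.6 + 0.3*0.4 + 0.7*0.4 = 0.18 + 0.12 + 0.28 = 0.58, and
  // applying Eq. 1 to the same two-step example reproduces it:
  //   P(l=a @ t=1) = P(l='' @ t=0) * P(a @ 1)
  //                + P(l=a  @ t=0) * (P(a @ 1) + P(- @ 1))
  //                = 0.7*0.4 + 0.3*(0.4 + 0.6) = 0.58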
  typedef ctc_beam_search::BeamEntry<T, CTCBeamState> BeamEntry;
  typedef ctc_beam_search::BeamRoot<T, CTCBeamState> BeamRoot;
  typedef ctc_beam_search::BeamProbability<T> BeamProbability;

 public:
  typedef BaseBeamScorer<T, CTCBeamState> DefaultBeamScorer;

  // The beam search decoder is constructed by specifying beam_width (the
  // number of candidates to keep at each decoding timestep) and a beam scorer
  // (used for custom scoring, for example to enable the use of a language
  // model).  Ownership of the scorer remains with the caller.  The default
  // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
  // standard beam search.
  CTCBeamSearchDecoder(int num_classes, int beam_width,
                       BaseBeamScorer<T, CTCBeamState>* scorer,
                       int batch_size = 1, bool merge_repeated = false)
      : CTCDecoder<T>(num_classes, batch_size, merge_repeated),
        beam_width_(beam_width),
        leaves_(beam_width),
        beam_scorer_(CHECK_NOTNULL(scorer)) {
    Reset();
  }
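  //
  // A minimal usage sketch (illustrative only; `num_classes`, `beam_width`,
  // `batch_size`, `seq_len`, `inputs`, `outputs`, and `scores` are
  // placeholders the caller provides, using the types declared in
  // ctc_decoder.h):
  //
  //   CTCBeamSearchDecoder<float>::DefaultBeamScorer default_scorer;
  //   CTCBeamSearchDecoder<float> decoder(num_classes, beam_width,
  //                                       &default_scorer, batch_size);
  //   // seq_len: sequence length per batch element.
  //   // inputs:  one [batch_size x num_classes] logits matrix per timestep.
  //   // outputs: top_paths entries, each holding batch_size label sequences.
  //   // scores:  [batch_size x top_paths] matrix of negated beam log-probs.
  //   Status s = decoder.Decode(seq_len, inputs, &outputs, &scores);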

  ~CTCBeamSearchDecoder() override {}

  // Run the hibernating beam search algorithm on the given input.
  Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
                const std::vector<typename CTCDecoder<T>::Input>& input,
                std::vector<typename CTCDecoder<T>::Output>* output,
                typename CTCDecoder<T>::ScoreOutput* scores) override;

  // Calculate the next step of the beam search and update the internal state.
  template <typename Vector>
  void Step(const Vector& log_input_t);

  template <typename Vector>
  T GetTopK(const int K, const Vector& input, std::vector<T>* top_k_logits,
            std::vector<int>* top_k_indices);

  // Retrieve the beam scorer instance used during decoding.
  BaseBeamScorer<T, CTCBeamState>* GetBeamScorer() const {
    return beam_scorer_;
  }

  // Set label selection parameters for faster decoding.
  // See comments for label_selection_size_ and label_selection_margin_.
  void SetLabelSelectionParameters(int label_selection_size,
                                   T label_selection_margin) {
    label_selection_size_ = label_selection_size;
    label_selection_margin_ = label_selection_margin;
  }
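  // For example (illustrative values only, not defaults): to consider at most
  // the 40 highest-scoring labels per step, and to skip labels whose input
  // score falls more than 5.0 (in log-probability) below the best label's,
  // a caller could use:
  //   decoder.SetLabelSelectionParameters(40, 5.0);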

  // Reset the beam search.
  void Reset();

  // Extract the top n paths at the current time step.
  Status TopPaths(int n, std::vector<std::vector<int>>* paths,
                  std::vector<T>* log_probs, bool merge_repeated) const;

 private:
  int beam_width_;

  // Label selection is designed to avoid possibly very expensive scorer calls
  // by pruning the hypotheses based on the input alone.
  // Label selection size controls how many items in each beam are passed
  // through to the beam scorer: only the items with the top N input scores
  // are considered.
  // Label selection margin controls how far below the best-scoring label an
  // item's input score may fall while still being passed to the beam scorer.
  // This margin is expressed in terms of log-probability.
  // The default is to do no label selection.
  // For more detail: https://research.google.com/pubs/pub44823.html
  int label_selection_size_ = 0;   // zero means unlimited
  T label_selection_margin_ = -1;  // -1 means unlimited.

  gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
  std::unique_ptr<BeamRoot> beam_root_;
  BaseBeamScorer<T, CTCBeamState>* beam_scorer_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
};
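
// Note (illustrative, not exhaustive): custom scoring, e.g. with a language
// model, is done by passing a subclass of BaseBeamScorer<T, CTCBeamState> to
// the constructor.  The decoder calls the scorer's hooks at the points marked
// in Decode() and Step() below: InitializeState() on the root beam entry,
// ExpandState() / GetStateExpansionScore() whenever a beam is extended by a
// label, and ExpandStateEnd() / GetStateEndExpansionScore() once decoding of
// a sequence finishes.  See ctc_beam_scorer.h for the exact interface.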

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Decode(
    const typename CTCDecoder<T>::SequenceLength& seq_len,
    const std::vector<typename CTCDecoder<T>::Input>& input,
    std::vector<typename CTCDecoder<T>::Output>* output,
    typename CTCDecoder<T>::ScoreOutput* scores) {
  // Storage for top paths.
  std::vector<std::vector<int>> beams;
  std::vector<T> beam_log_probabilities;
  int top_n = output->size();
  if (std::any_of(output->begin(), output->end(),
                  [this](const typename CTCDecoder<T>::Output& output) -> bool {
                    return output.size() < this->batch_size_;
                  })) {
    return errors::InvalidArgument(
        "output needs to be of size at least (top_n, batch_size).");
  }
  if (scores->rows() < this->batch_size_ || scores->cols() < top_n) {
    return errors::InvalidArgument(
        "scores needs to be of size at least (batch_size, top_n).");
  }

  for (int b = 0; b < this->batch_size_; ++b) {
    int seq_len_b = seq_len[b];
    Reset();

    for (int t = 0; t < seq_len_b; ++t) {
      // Pass log-probabilities for this example + time.
      Step(input[t].row(b));
    }  // for (int t...

    // O(n * log(n))
    std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
    leaves_.Reset();
    for (int i = 0; i < branches->size(); ++i) {
      BeamEntry* entry = (*branches)[i];
      beam_scorer_->ExpandStateEnd(&entry->state);
      entry->newp.total +=
          beam_scorer_->GetStateEndExpansionScore(entry->state);
      leaves_.push(entry);
    }

    Status status =
        TopPaths(top_n, &beams, &beam_log_probabilities, this->merge_repeated_);
    if (!status.ok()) {
      return status;
    }

    CHECK_EQ(top_n, beam_log_probabilities.size());
    CHECK_EQ(beams.size(), beam_log_probabilities.size());

    for (int i = 0; i < top_n; ++i) {
      // Copy output to the correct beam + batch.
      (*output)[i][b].swap(beams[i]);
      (*scores)(b, i) = -beam_log_probabilities[i];
    }
  }  // for (int b...
  return Status::OK();
}

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
T CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::GetTopK(
    const int K, const Vector& input, std::vector<T>* top_k_logits,
    std::vector<int>* top_k_indices) {
  // Find Top K choices, complexity nk in worst case. The array input is read
  // just once.
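  // As an illustrative (hypothetical) example: with num_classes_ = 4, K = 2,
  // and input = {0.1, 2.0, -1.0, 0.5} (the last entry being the blank logit),
  // the loop below produces top_k_logits = {2.0, 0.1} and
  // top_k_indices = {1, 0}, and the function returns max(2.0, 0.5) = 2.0.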
  CHECK_EQ(this->num_classes_, input.size());
  top_k_logits->clear();
  top_k_indices->clear();
  top_k_logits->resize(K, -INFINITY);
  top_k_indices->resize(K, -1);
  for (int j = 0; j < this->num_classes_ - 1; ++j) {
    const T logit = input(j);
    if (logit > (*top_k_logits)[K - 1]) {
      int k = K - 1;
      while (k > 0 && logit > (*top_k_logits)[k - 1]) {
        (*top_k_logits)[k] = (*top_k_logits)[k - 1];
        (*top_k_indices)[k] = (*top_k_indices)[k - 1];
        k--;
      }
      (*top_k_logits)[k] = logit;
      (*top_k_indices)[k] = j;
    }
  }
  // Return the max value, which is either the largest top-k logit (stored at
  // index 0) or the blank character's logit.
  return std::max((*top_k_logits)[0], input(this->num_classes_ - 1));
}

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
    const Vector& raw_input) {
  std::vector<T> top_k_logits;
  std::vector<int> top_k_indices;
  const bool top_k =
      (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
  // Number of character classes to consider in each step.
  const int max_classes =
      top_k ? label_selection_size_ : (this->num_classes_ - 1);
  // Get max coefficient and remove it from raw_input later.
  T max_coeff;
  if (top_k) {
    max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
                        &top_k_indices);
  } else {
    max_coeff = raw_input.maxCoeff();
  }
  // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
  T logsumexp = T(0.0);
  for (int j = 0; j < raw_input.size(); ++j) {
    logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
  }
  logsumexp = Eigen::numext::log(logsumexp);
  // Final normalization offset to get correct log probabilities.
  T norm_offset = max_coeff + logsumexp;
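  // Illustrative numeric check (hypothetical values): for
  // raw_input = {1.0, 2.0, 0.0}, max_coeff = 2.0 and
  // logsumexp = log(e^-1 + e^0 + e^-2) ~= 0.41, so norm_offset ~= 2.41 and
  // raw_input(j) - norm_offset is the log-softmax, e.g. log p(1) ~= -0.41
  // (p(1) ~= 0.66).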

  const T label_selection_input_min =
      (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
                                     : -std::numeric_limits<T>::infinity();

  // Extract the beams sorted in decreasing new probability.
  CHECK_EQ(this->num_classes_, raw_input.size());

  std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
  leaves_.Reset();

  for (BeamEntry* b : *branches) {
    // P(.. @ t) becomes the new P(.. @ t-1)
    b->oldp = b->newp;
  }

  for (BeamEntry* b : *branches) {
    if (b->parent != nullptr) {  // if not the root
      if (b->parent->Active()) {
        // If last two sequence characters are identical:
        //   Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
        //                          + Pblank(l=ac @ t=5))
        // else:
        //   Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
        //                          + P(l=ab @ t=5))
        T previous = (b->label == b->parent->label) ? b->parent->oldp.blank
                                                    : b->parent->oldp.total;
        b->newp.label =
            LogSumExp(b->newp.label,
                      beam_scorer_->GetStateExpansionScore(b->state, previous));
      }
      // Plabel(l=abc @ t=6) *= P(c @ 6)
      b->newp.label += raw_input(b->label) - norm_offset;
    }
    // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
    b->newp.blank = b->oldp.total + raw_input(this->blank_index_) - norm_offset;
    // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
    b->newp.total = LogSumExp(b->newp.blank, b->newp.label);

    // Push the entry back to the top paths list.
    // Note, this will always fill leaves back up in sorted order.
    leaves_.push(b);
  }

  // The loop below must visit branches in descending oldp order.  No resort
  // is needed: branches was extracted in descending newp order, and we copied
  // newp to oldp above, so it is already in descending oldp order.

  // Grow new leaves.
  for (BeamEntry* b : *branches) {
    // A new leaf (represented by its BeamProbability) is a candidate
    // iff its total probability is nonzero and either the beam list
    // isn't full, or the lowest probability entry in the beam has a
    // lower probability than the leaf.
    auto is_candidate = [this](const BeamProbability& prob) {
      return (prob.total > kLogZero<T>() &&
              (leaves_.size() < beam_width_ ||
               prob.total > leaves_.peek_bottom()->newp.total));
    };

    if (!is_candidate(b->oldp)) {
      continue;
    }

    for (int ind = 0; ind < max_classes; ind++) {
      const int label = top_k ? top_k_indices[ind] : ind;
      const T logit = top_k ? top_k_logits[ind] : raw_input(ind);
      // Perform label selection: if input for this label looks very
      // unpromising, never evaluate it with a scorer.
      // We may compare logits instead of log probabilities,
      // since the difference is the same in both cases.
      if (logit < label_selection_input_min) {
        continue;
      }
      BeamEntry& c = b->GetChild(label);
      if (!c.Active()) {
        //   Pblank(l=abcd @ t=6) = 0
        c.newp.blank = kLogZero<T>();
        // If new child label is identical to beam label:
        //   Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
        // Otherwise:
        //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
        beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
        T previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
        c.newp.label = logit - norm_offset +
                       beam_scorer_->GetStateExpansionScore(c.state, previous);
        // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
        c.newp.total = c.newp.label;

        if (is_candidate(c.newp)) {
          // Before adding the new node to the beam, check if the beam
          // is already at maximum width.
          if (leaves_.size() == beam_width_) {
            // Bottom is no longer in the beam search.  Reset
            // its probability; signal it's no longer in the beam search.
            BeamEntry* bottom = leaves_.peek_bottom();
            bottom->newp.Reset();
          }
          leaves_.push(&c);
        } else {
          // Deactivate child.
          c.oldp.Reset();
          c.newp.Reset();
        }
      }
    }
  }  // for (BeamEntry* b...
}

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Reset() {
  leaves_.Reset();

  // This beam root, and all of its children, will be in memory until
  // the next reset.
  beam_root_.reset(new BeamRoot(nullptr, -1));
  beam_root_->RootEntry()->newp.total = T(0.0);  // ln(1)
  beam_root_->RootEntry()->newp.blank = T(0.0);  // ln(1)

  // Add the root as the initial leaf.
  leaves_.push(beam_root_->RootEntry());

  // Call initialize state on the root object.
  beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
}

template <typename T, typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::TopPaths(
    int n, std::vector<std::vector<int>>* paths, std::vector<T>* log_probs,
    bool merge_repeated) const {
  CHECK_NOTNULL(paths)->clear();
  CHECK_NOTNULL(log_probs)->clear();
  if (n > beam_width_) {
    return errors::InvalidArgument("requested more paths than the beam width.");
  }
  if (n > leaves_.size()) {
    return errors::InvalidArgument(
        "Fewer leaves in the beam search than requested.");
  }

  gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);

  // O(beam_width_ * log(n)), space complexity is O(n)
  for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
    top_branches.push(*it);
  }
  // O(n * log(n))
  std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());

  for (int i = 0; i < n; ++i) {
    BeamEntry* e((*branches)[i]);
    paths->push_back(e->LabelSeq(merge_repeated));
    log_probs->push_back(e->newp.total);
  }
  return Status::OK();
}

}  // namespace ctc
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
// LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_search.h)