/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// LINT.IfChange

#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ctc/ctc_beam_entry.h"
#include "tensorflow/core/util/ctc/ctc_beam_scorer.h"
#include "tensorflow/core/util/ctc/ctc_decoder.h"
#include "tensorflow/core/util/ctc/ctc_loss_util.h"

namespace tensorflow {
namespace ctc {

template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
          typename CTCBeamComparer =
              ctc_beam_search::BeamComparer<CTCBeamState>>
class CTCBeamSearchDecoder : public CTCDecoder {
  // Beam Search
  //
  // Example (GravesTh Fig. 7.5):
  //         a    -
  //  P = [ 0.3  0.7 ]  t = 0
  //      [ 0.4  0.6 ]  t = 1
  //
  // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
  //      P(l = a) = P(a-) + P(aa) + P(-a)
  //               = 0.3*0.6 + 0.3*0.4 + 0.7*0.4 = 0.58
  //
  // In this case, Best Path decoding is suboptimal.
  //
  // For Beam Search, we use the following main recurrence relations:
  //
  // Relation 1:
  // ---------------------------------------------------------- Eq. 1
  //      P(l=abcd @ t=7) = P(l=abc  @ t=6) * P(d @ 7)
  //                      + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
  // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
  // updated recursively in the beam entry.
  //
  // Relation 2:
  // ---------------------------------------------------------- Eq. 2
  //      P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
  // for ? in a, b, d, ..., (not including c or the blank index),
  // and the recurrence starts from the beam entry for P(l=abc @ t=2).
  //
  // For this case, the length of the new sequence equals t+1 (t
  // starts at 0).  This special case can be calculated as:
  //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
  // but we calculate it recursively for speed purposes.
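  //
  // As a worked check, applying Relation 1 to the Fig. 7.5 example above
  // (using only the probabilities already listed there):
  //      P(l=a @ t=1) = P(l='' @ t=0) * P(a @ 1)
  //                   + P(l=a  @ t=0) * (P(a @ 1) + P(- @ 1))
  //                   = 0.7 * 0.4 + 0.3 * (0.4 + 0.6) = 0.58,
  // which matches the direct sum P(a-) + P(aa) + P(-a) computed above.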
  typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
  typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
  typedef ctc_beam_search::BeamProbability BeamProbability;

 public:
  typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;

  // The beam search decoder is constructed specifying the beam_width (number
  // of candidates to keep at each decoding timestep) and a beam scorer (used
  // for custom scoring, for example enabling the use of a language model).
  // The ownership of the scorer remains with the caller. The default
  // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
  // standard beam search.
  CTCBeamSearchDecoder(int num_classes, int beam_width,
                       BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
                       bool merge_repeated = false)
      : CTCDecoder(num_classes, batch_size, merge_repeated),
        beam_width_(beam_width),
        leaves_(beam_width),
        beam_scorer_(CHECK_NOTNULL(scorer)) {
    Reset();
  }
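
  // A minimal usage sketch. The names `num_classes`, `batch_size`,
  // `top_paths`, `seq_len`, and `inputs` are illustrative placeholders, not
  // part of this header; the typedefs come from CTCDecoder in ctc_decoder.h.
  //
  //   CTCBeamSearchDecoder<>::DefaultBeamScorer scorer;
  //   CTCBeamSearchDecoder<> decoder(num_classes, /*beam_width=*/10, &scorer,
  //                                  batch_size);
  //   // inputs: one CTCDecoder::Input (batch_size x num_classes logits
  //   // matrix) per time step; seq_len: per-batch sequence lengths.
  //   std::vector<CTCDecoder::Output> outputs(
  //       top_paths, CTCDecoder::Output(batch_size));
  //   Eigen::MatrixXf scores_storage(batch_size, top_paths);
  //   CTCDecoder::ScoreOutput scores(scores_storage.data(), batch_size,
  //                                  top_paths);
  //   TF_CHECK_OK(decoder.Decode(seq_len, inputs, &outputs, &scores));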

  ~CTCBeamSearchDecoder() override {}

  // Run the beam search algorithm on the given input.
102 Status Decode(const CTCDecoder::SequenceLength& seq_len,
103 const std::vector<CTCDecoder::Input>& input,
104 std::vector<CTCDecoder::Output>* output,
105 CTCDecoder::ScoreOutput* scores) override;
106
107 // Calculate the next step of the beam search and update the internal state.
108 template <typename Vector>
109 void Step(const Vector& log_input_t);
110
111 template <typename Vector>
112 float GetTopK(const int K, const Vector& input,
113 std::vector<float>* top_k_logits,
114 std::vector<int>* top_k_indices);
115
116 // Retrieve the beam scorer instance used during decoding.
GetBeamScorer()117 BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
118
119 // Set label selection parameters for faster decoding.
120 // See comments for label_selection_size_ and label_selection_margin_.
SetLabelSelectionParameters(int label_selection_size,float label_selection_margin)121 void SetLabelSelectionParameters(int label_selection_size,
122 float label_selection_margin) {
123 label_selection_size_ = label_selection_size;
124 label_selection_margin_ = label_selection_margin;
125 }
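  // For example, to consider only the 40 highest-scoring labels per step and
  // drop labels more than 3.5 log-probability units below the best one
  // (illustrative values, not tuned recommendations):
  //   decoder.SetLabelSelectionParameters(/*label_selection_size=*/40,
  //                                       /*label_selection_margin=*/3.5);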

  // Reset the beam search
  void Reset();

  // Extract the top n paths at current time step
  Status TopPaths(int n, std::vector<std::vector<int>>* paths,
                  std::vector<float>* log_probs, bool merge_repeated) const;

 private:
  int beam_width_;

  // Label selection is designed to avoid possibly very expensive scorer calls,
  // by pruning the hypotheses based on the input alone.
  // Label selection size controls how many items in each beam are passed
  // through to the beam scorer. Only items with the top N input scores are
  // considered.
  // Label selection margin controls how far below the best-scoring label's
  // input score a label may fall and still be passed to the beam scorer.
  // This margin is expressed in terms of log-probability.
  // Default is to do no label selection.
  // For more detail: https://research.google.com/pubs/pub44823.html
  int label_selection_size_ = 0;       // zero means unlimited
  float label_selection_margin_ = -1;  // -1 means unlimited.

  gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
  std::unique_ptr<BeamRoot> beam_root_;
  BaseBeamScorer<CTCBeamState>* beam_scorer_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
};

template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
    const CTCDecoder::SequenceLength& seq_len,
    const std::vector<CTCDecoder::Input>& input,
    std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
  // Storage for top paths.
  std::vector<std::vector<int>> beams;
  std::vector<float> beam_log_probabilities;
  int top_n = output->size();
  if (std::any_of(output->begin(), output->end(),
                  [this](const CTCDecoder::Output& output) -> bool {
                    return output.size() < this->batch_size_;
                  })) {
    return errors::InvalidArgument(
        "output needs to be of size at least (top_n, batch_size).");
  }
  if (scores->rows() < batch_size_ || scores->cols() < top_n) {
    return errors::InvalidArgument(
        "scores needs to be of size at least (batch_size, top_n).");
  }

  for (int b = 0; b < batch_size_; ++b) {
    int seq_len_b = seq_len[b];
    Reset();

    for (int t = 0; t < seq_len_b; ++t) {
      // Pass log-probabilities for this example + time.
      Step(input[t].row(b));
    }  // for (int t...

    // O(n * log(n))
    std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
    leaves_.Reset();
    for (int i = 0; i < branches->size(); ++i) {
      BeamEntry* entry = (*branches)[i];
      beam_scorer_->ExpandStateEnd(&entry->state);
      entry->newp.total +=
          beam_scorer_->GetStateEndExpansionScore(entry->state);
      leaves_.push(entry);
    }

    Status status =
        TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
    if (!status.ok()) {
      return status;
    }

    CHECK_EQ(top_n, beam_log_probabilities.size());
    CHECK_EQ(beams.size(), beam_log_probabilities.size());

    for (int i = 0; i < top_n; ++i) {
      // Copy output to the correct beam + batch
      (*output)[i][b].swap(beams[i]);
      (*scores)(b, i) = -beam_log_probabilities[i];
    }
  }  // for (int b...
  return Status::OK();
}

template <typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
    const int K, const Vector& input, std::vector<float>* top_k_logits,
    std::vector<int>* top_k_indices) {
  // Find the top K choices; worst-case complexity is O(n * K). The input
  // array is read just once.
  CHECK_EQ(num_classes_, input.size());
  top_k_logits->clear();
  top_k_indices->clear();
  top_k_logits->resize(K, -INFINITY);
  top_k_indices->resize(K, -1);
  for (int j = 0; j < num_classes_ - 1; ++j) {
    const float logit = input(j);
    if (logit > (*top_k_logits)[K - 1]) {
      int k = K - 1;
      while (k > 0 && logit > (*top_k_logits)[k - 1]) {
        (*top_k_logits)[k] = (*top_k_logits)[k - 1];
        (*top_k_indices)[k] = (*top_k_indices)[k - 1];
        k--;
      }
      (*top_k_logits)[k] = logit;
      (*top_k_indices)[k] = j;
    }
  }
  // Return the max value, which is either the largest top-k logit (stored at
  // index 0) or the blank-character logit.
  return std::max((*top_k_logits)[0], input(num_classes_ - 1));
}

template <typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
    const Vector& raw_input) {
  std::vector<float> top_k_logits;
  std::vector<int> top_k_indices;
  const bool top_k =
      (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
  // Number of character classes to consider in each step.
  const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
  // Get max coefficient and remove it from raw_input later.
  float max_coeff;
  if (top_k) {
    max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
                        &top_k_indices);
  } else {
    max_coeff = raw_input.maxCoeff();
  }

  // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
  float logsumexp = 0.0;
  for (int j = 0; j < raw_input.size(); ++j) {
    logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
  }
  logsumexp = Eigen::numext::log(logsumexp);
  // Final normalization offset to get correct log probabilities.
  float norm_offset = max_coeff + logsumexp;

  const float label_selection_input_min =
      (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
                                     : -std::numeric_limits<float>::infinity();

  // Extract the beams sorted in decreasing new probability
  CHECK_EQ(num_classes_, raw_input.size());

  std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
  leaves_.Reset();

  for (BeamEntry* b : *branches) {
    // P(.. @ t) becomes the new P(.. @ t-1)
    b->oldp = b->newp;
  }

  for (BeamEntry* b : *branches) {
    if (b->parent != nullptr) {  // if not the root
      if (b->parent->Active()) {
        // If last two sequence characters are identical:
        //   Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
        //                          + Pblank(l=ac @ t=5))
        // else:
        //   Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
        //                          + P(l=ab @ t=5))
        float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
                                                        : b->parent->oldp.total;
        b->newp.label =
            LogSumExp(b->newp.label,
                      beam_scorer_->GetStateExpansionScore(b->state, previous));
      }
      // Plabel(l=abc @ t=6) *= P(c @ 6)
      b->newp.label += raw_input(b->label) - norm_offset;
    }
    // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
    b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
    // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
    b->newp.total = LogSumExp(b->newp.blank, b->newp.label);

    // Push the entry back to the top paths list.
    // Note, this will always fill leaves back up in sorted order.
    leaves_.push(b);
  }

  // The expansion step below requires branches in descending oldp order.
  // branches is already in that order: it was extracted in descending newp
  // order, and newp was copied to oldp above, so no re-sort is needed.

  // Grow new leaves
  for (BeamEntry* b : *branches) {
    // A new leaf (represented by its BeamProbability) is a candidate
    // iff its total probability is nonzero and either the beam list
    // isn't full, or the lowest probability entry in the beam has a
    // lower probability than the leaf.
    auto is_candidate = [this](const BeamProbability& prob) {
      return (prob.total > kLogZero &&
              (leaves_.size() < beam_width_ ||
               prob.total > leaves_.peek_bottom()->newp.total));
    };

    if (!is_candidate(b->oldp)) {
      continue;
    }

    for (int ind = 0; ind < max_classes; ind++) {
      const int label = top_k ? top_k_indices[ind] : ind;
      const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
      // Perform label selection: if input for this label looks very
      // unpromising, never evaluate it with a scorer.
      // We may compare logits instead of log probabilities,
      // since the difference is the same in both cases.
      if (logit < label_selection_input_min) {
        continue;
      }
      BeamEntry& c = b->GetChild(label);
      if (!c.Active()) {
        //   Pblank(l=abcd @ t=6) = 0
        c.newp.blank = kLogZero;
        // If new child label is identical to beam label:
        //   Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
        // Otherwise:
        //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
        beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
        float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
        c.newp.label = logit - norm_offset +
                       beam_scorer_->GetStateExpansionScore(c.state, previous);
        // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
        c.newp.total = c.newp.label;

        if (is_candidate(c.newp)) {
          // Before adding the new node to the beam, check if the beam
          // is already at maximum width.
          if (leaves_.size() == beam_width_) {
            // Bottom is no longer in the beam search.  Reset
            // its probability; signal it's no longer in the beam search.
            BeamEntry* bottom = leaves_.peek_bottom();
            bottom->newp.Reset();
          }
          leaves_.push(&c);
        } else {
          // Deactivate child.
          c.oldp.Reset();
          c.newp.Reset();
        }
      }
    }
  }  // for (BeamEntry* b...
}

template <typename CTCBeamState, typename CTCBeamComparer>
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
  leaves_.Reset();

  // This beam root, and all of its children, will be in memory until
  // the next reset.
  beam_root_.reset(new BeamRoot(nullptr, -1));
  beam_root_->RootEntry()->newp.total = 0.0;  // ln(1)
  beam_root_->RootEntry()->newp.blank = 0.0;  // ln(1)

  // Add the root as the initial leaf.
  leaves_.push(beam_root_->RootEntry());

  // Call initialize state on the root object.
  beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
}

template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
    int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
    bool merge_repeated) const {
  CHECK_NOTNULL(paths)->clear();
  CHECK_NOTNULL(log_probs)->clear();
  if (n > beam_width_) {
    return errors::InvalidArgument("requested more paths than the beam width.");
  }
  if (n > leaves_.size()) {
    return errors::InvalidArgument(
        "Fewer leaves in the beam search than requested.");
  }

  gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);

  // O(beam_width_ * log(n)), space complexity is O(n)
  for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
    top_branches.push(*it);
  }
  // O(n * log(n))
  std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());

  for (int i = 0; i < n; ++i) {
    BeamEntry* e((*branches)[i]);
    paths->push_back(e->LabelSeq(merge_repeated));
    log_probs->push_back(e->newp.total);
  }
  return Status::OK();
}

}  // namespace ctc
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
// LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_search.h)