1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Copied from tensorflow/core/util/ctc/ctc_beam_entry.h
17 // TODO(b/111524997): Remove this file.
18 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_
19 #define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_
20 
21 #include <algorithm>
22 #include <memory>
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "third_party/eigen3/Eigen/Core"
27 #include "tensorflow/lite/experimental/kernels/ctc_loss_util.h"
28 
29 namespace tflite {
30 namespace experimental {
31 namespace ctc {
32 
33 // The ctc_beam_search namespace holds several classes meant to be accessed only
34 // in case of extending the CTCBeamSearch decoder to allow custom scoring
35 // functions.
36 //
37 // BeamEntry is exposed through template arguments BeamScorer and BeamComparer
38 // of CTCBeamSearch (ctc_beam_search.h).
39 namespace ctc_beam_search {
40 
// Empty placeholder state. It is the default CTCBeamState template argument of
// BeamEntry, BeamRoot and BeamComparer below, for use when no extra per-entry
// state is required.
struct EmptyBeamState {};
42 
43 struct BeamProbability {
BeamProbabilityBeamProbability44   BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {}
ResetBeamProbability45   void Reset() {
46     total = kLogZero;
47     blank = kLogZero;
48     label = kLogZero;
49   }
50   float total;
51   float blank;
52   float label;
53 };
54 
// Forward declaration: BeamEntry (below) befriends BeamRoot::AddEntry() and
// stores a BeamRoot pointer before the full BeamRoot definition appears.
template <class CTCBeamState>
class BeamRoot;
57 
58 template <class CTCBeamState = EmptyBeamState>
59 struct BeamEntry {
60   // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
61   friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
62       BeamEntry<CTCBeamState>* p, int l);
ActiveBeamEntry63   inline bool Active() const { return newp.total != kLogZero; }
64   // Return the child at the given index, or construct a new one in-place if
65   // none was found.
GetChildBeamEntry66   BeamEntry& GetChild(int ind) {
67     auto entry = children.emplace(ind, nullptr);
68     auto& child_entry = entry.first->second;
69     // If this is a new child, populate the BeamEntry<CTCBeamState>*.
70     if (entry.second) {
71       child_entry = beam_root->AddEntry(this, ind);
72     }
73     return *child_entry;
74   }
LabelSeqBeamEntry75   std::vector<int> LabelSeq(bool merge_repeated) const {
76     std::vector<int> labels;
77     int prev_label = -1;
78     const BeamEntry* c = this;
79     while (c->parent != nullptr) {  // Checking c->parent to skip root leaf.
80       if (!merge_repeated || c->label != prev_label) {
81         labels.push_back(c->label);
82       }
83       prev_label = c->label;
84       c = c->parent;
85     }
86     std::reverse(labels.begin(), labels.end());
87     return labels;
88   }
89 
90   BeamEntry<CTCBeamState>* parent;
91   int label;
92   // All instances of child BeamEntry are owned by *beam_root.
93   std::unordered_map<int, BeamEntry<CTCBeamState>*> children;
94   BeamProbability oldp;
95   BeamProbability newp;
96   CTCBeamState state;
97 
98  private:
99   // Constructor giving parent, label, and the beam_root.
100   // The object pointed to by p cannot be copied and should not be moved,
101   // otherwise parent will become invalid.
102   // This private constructor is only called through the factory method
103   // BeamRoot<CTCBeamState>::AddEntry().
BeamEntryBeamEntry104   BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
105       : parent(p), label(l), beam_root(beam_root) {}
106   BeamRoot<CTCBeamState>* beam_root;
107 
108   BeamEntry(const BeamEntry&) = delete;
109   void operator=(const BeamEntry&) = delete;
110 };
111 
112 // This class owns all instances of BeamEntry.  This is used to avoid recursive
113 // destructor call during destruction.
114 template <class CTCBeamState = EmptyBeamState>
115 class BeamRoot {
116  public:
BeamRoot(BeamEntry<CTCBeamState> * p,int l)117   BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
118   BeamRoot(const BeamRoot&) = delete;
119   BeamRoot& operator=(const BeamRoot&) = delete;
120 
AddEntry(BeamEntry<CTCBeamState> * p,int l)121   BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
122     auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
123     beam_entries_.emplace_back(new_entry);
124     return new_entry;
125   }
RootEntry()126   BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
127 
128  private:
129   BeamEntry<CTCBeamState>* root_entry_ = nullptr;
130   std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
131 };
132 
133 // BeamComparer is the default beam comparer provided in CTCBeamSearch.
134 template <class CTCBeamState = EmptyBeamState>
135 class BeamComparer {
136  public:
~BeamComparer()137   virtual ~BeamComparer() {}
operator()138   virtual bool inline operator()(const BeamEntry<CTCBeamState>* a,
139                                  const BeamEntry<CTCBeamState>* b) const {
140     return a->newp.total > b->newp.total;
141   }
142 };
143 
144 }  // namespace ctc_beam_search
145 
146 }  // namespace ctc
147 }  // namespace experimental
148 }  // namespace tflite
149 
150 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_
151