1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Copied from tensorflow/core/util/ctc/ctc_beam_entry.h 17 // TODO(b/111524997): Remove this file. 18 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ 19 #define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ 20 21 #include <algorithm> 22 #include <memory> 23 #include <unordered_map> 24 #include <vector> 25 26 #include "third_party/eigen3/Eigen/Core" 27 #include "tensorflow/lite/experimental/kernels/ctc_loss_util.h" 28 29 namespace tflite { 30 namespace experimental { 31 namespace ctc { 32 33 // The ctc_beam_search namespace holds several classes meant to be accessed only 34 // in case of extending the CTCBeamSearch decoder to allow custom scoring 35 // functions. 36 // 37 // BeamEntry is exposed through template arguments BeamScorer and BeamComparer 38 // of CTCBeamSearch (ctc_beam_search.h). 39 namespace ctc_beam_search { 40 41 struct EmptyBeamState {}; 42 43 struct BeamProbability { BeamProbabilityBeamProbability44 BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {} ResetBeamProbability45 void Reset() { 46 total = kLogZero; 47 blank = kLogZero; 48 label = kLogZero; 49 } 50 float total; 51 float blank; 52 float label; 53 }; 54 55 template <class CTCBeamState> 56 class BeamRoot; 57 58 template <class CTCBeamState = EmptyBeamState> 59 struct BeamEntry { 60 // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method. 61 friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry( 62 BeamEntry<CTCBeamState>* p, int l); ActiveBeamEntry63 inline bool Active() const { return newp.total != kLogZero; } 64 // Return the child at the given index, or construct a new one in-place if 65 // none was found. GetChildBeamEntry66 BeamEntry& GetChild(int ind) { 67 auto entry = children.emplace(ind, nullptr); 68 auto& child_entry = entry.first->second; 69 // If this is a new child, populate the BeamEntry<CTCBeamState>*. 70 if (entry.second) { 71 child_entry = beam_root->AddEntry(this, ind); 72 } 73 return *child_entry; 74 } LabelSeqBeamEntry75 std::vector<int> LabelSeq(bool merge_repeated) const { 76 std::vector<int> labels; 77 int prev_label = -1; 78 const BeamEntry* c = this; 79 while (c->parent != nullptr) { // Checking c->parent to skip root leaf. 80 if (!merge_repeated || c->label != prev_label) { 81 labels.push_back(c->label); 82 } 83 prev_label = c->label; 84 c = c->parent; 85 } 86 std::reverse(labels.begin(), labels.end()); 87 return labels; 88 } 89 90 BeamEntry<CTCBeamState>* parent; 91 int label; 92 // All instances of child BeamEntry are owned by *beam_root. 93 std::unordered_map<int, BeamEntry<CTCBeamState>*> children; 94 BeamProbability oldp; 95 BeamProbability newp; 96 CTCBeamState state; 97 98 private: 99 // Constructor giving parent, label, and the beam_root. 100 // The object pointed to by p cannot be copied and should not be moved, 101 // otherwise parent will become invalid. 102 // This private constructor is only called through the factory method 103 // BeamRoot<CTCBeamState>::AddEntry(). BeamEntryBeamEntry104 BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root) 105 : parent(p), label(l), beam_root(beam_root) {} 106 BeamRoot<CTCBeamState>* beam_root; 107 108 BeamEntry(const BeamEntry&) = delete; 109 void operator=(const BeamEntry&) = delete; 110 }; 111 112 // This class owns all instances of BeamEntry. This is used to avoid recursive 113 // destructor call during destruction. 114 template <class CTCBeamState = EmptyBeamState> 115 class BeamRoot { 116 public: BeamRoot(BeamEntry<CTCBeamState> * p,int l)117 BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); } 118 BeamRoot(const BeamRoot&) = delete; 119 BeamRoot& operator=(const BeamRoot&) = delete; 120 AddEntry(BeamEntry<CTCBeamState> * p,int l)121 BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) { 122 auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this); 123 beam_entries_.emplace_back(new_entry); 124 return new_entry; 125 } RootEntry()126 BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; } 127 128 private: 129 BeamEntry<CTCBeamState>* root_entry_ = nullptr; 130 std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_; 131 }; 132 133 // BeamComparer is the default beam comparer provided in CTCBeamSearch. 134 template <class CTCBeamState = EmptyBeamState> 135 class BeamComparer { 136 public: ~BeamComparer()137 virtual ~BeamComparer() {} operator()138 virtual bool inline operator()(const BeamEntry<CTCBeamState>* a, 139 const BeamEntry<CTCBeamState>* b) const { 140 return a->newp.total > b->newp.total; 141 } 142 }; 143 144 } // namespace ctc_beam_search 145 146 } // namespace ctc 147 } // namespace experimental 148 } // namespace tflite 149 150 #endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ 151