1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
// Contains a codec for generating and applying Huffman coding schemes.
16 
17 #ifndef SOURCE_COMP_HUFFMAN_CODEC_H_
18 #define SOURCE_COMP_HUFFMAN_CODEC_H_
19 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
35 
36 namespace spvtools {
37 namespace comp {
38 
// Used to generate and apply a Huffman coding scheme.
// |Val| is the type of variable being encoded (for example a string or a
// literal). The code below default-constructs and copies |Val|, uses it as a
// key of std::map and std::unordered_map, and streams it to std::ostream in
// the text/debug output helpers.
template <class Val>
class HuffmanCodec {
 public:
  // Huffman tree node. Nodes refer to each other through handles, which are
  // indices into the codec's node array. Handle 0 denotes NIL ("no child").
  struct Node {
    Node() {}

    // Creates Node from serialization. Weight and id keep their default value
    // of zero.
    Node(const Val& in_value, uint32_t in_left, uint32_t in_right)
        : value(in_value), left(in_left), right(in_right) {}

    Val value = Val();
    uint32_t weight = 0;
    // Ids are issued sequentially starting from 1. Ids are used as an ordering
    // tie-breaker, to make sure that the ordering (and resulting coding scheme)
    // are consistent across multiple platforms.
    uint32_t id = 0;
    // Handles of children.
    uint32_t left = 0;
    uint32_t right = 0;
  };

  // Creates Huffman codec from a histogram.
  // Histogram counts must not be zero, and their sum is expected to fit in
  // uint32_t (both are checked with assert only).
  // An empty histogram produces an empty codec: no nodes, root handle 0 and an
  // empty encoding table.
  explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
    if (hist.empty()) return;

    // One NIL node, one leaf per histogram entry and one internal node per
    // merge (hist.size() - 1 merges): 2 * hist.size() nodes in total.
    nodes_.reserve(2 * hist.size());

    // Create NIL.
    CreateNode();

    // The queue is sorted in ascending order by weight (or by node id if
    // weights are equal), i.e. top() is the lightest subtree.
    auto left_is_bigger = [this](uint32_t lhs, uint32_t rhs) {
      return LeftIsBigger(lhs, rhs);
    };
    std::vector<uint32_t> queue_vector;
    queue_vector.reserve(hist.size());
    std::priority_queue<uint32_t, std::vector<uint32_t>,
                        decltype(left_is_bigger)>
        queue(left_is_bigger, std::move(queue_vector));

    // Put all leaves in the queue.
    for (const auto& entry : hist) {
      const uint32_t leaf = CreateNode();
      MutableValueOf(leaf) = entry.first;
      MutableWeightOf(leaf) = entry.second;
      assert(WeightOf(leaf));
      queue.push(leaf);
    }

    // Form the tree by repeatedly combining the two subtrees with the least
    // weight and pushing the root of the combined tree back in the queue.
    // The lightest subtree becomes the right child.
    while (queue.size() > 1) {
      const uint32_t right = queue.top();
      queue.pop();
      const uint32_t left = queue.top();
      queue.pop();

      const uint32_t combined_weight = WeightOf(right) + WeightOf(left);
      // Unsigned addition wraps around on overflow; a wrapped sum would be
      // smaller than either operand.
      assert(combined_weight >= WeightOf(right));

      const uint32_t parent = CreateNode();
      MutableWeightOf(parent) = combined_weight;
      MutableLeftOf(parent) = left;
      MutableRightOf(parent) = right;
      queue.push(parent);
    }

    // The histogram is not empty, so exactly one node is left in the queue:
    // the root of the complete Huffman tree.
    assert(!queue.empty());
    root_ = queue.top();

    // Traverse the tree and form encoding table.
    CreateEncodingTable();
  }

  // Creates Huffman codec from saved tree structure.
  // |nodes| is the list of nodes of the tree, nodes[0] being NIL.
  // |root_handle| is the index of the root node.
  // Preconditions (checked with assert only): |nodes| is not empty,
  // 0 < |root_handle| < nodes.size(), and nodes[0] has no children.
  HuffmanCodec(uint32_t root_handle, std::vector<Node>&& nodes)
      : root_(root_handle), nodes_(std::move(nodes)) {
    assert(!nodes_.empty());
    assert(root_handle > 0 && root_handle < nodes_.size());
    assert(!LeftOf(0) && !RightOf(0));

    // Traverse the tree and form encoding table.
    CreateEncodingTable();
  }

  // Serializes the codec in the following text format:
  // (<root_handle>, {
  //   {0, 0, 0},
  //   {val1, left1, right1},
  //   {val2, left2, right2},
  //   ...
  // })
  // Values are wrapped in double quotes when |Val| is std::string.
  // |indent_num_whitespaces| is the indentation of the closing line; node
  // lines are indented by two more spaces. Negative values are treated as 0.
  std::string SerializeToText(int indent_num_whitespaces) const {
    const bool value_is_text = std::is_same<Val, std::string>::value;

    // A negative indent would turn into a huge length when converted to
    // size_t, so clamp it first.
    const size_t indent = indent_num_whitespaces > 0
                              ? static_cast<size_t>(indent_num_whitespaces)
                              : 0;
    const std::string indent1(indent, ' ');
    const std::string indent2(indent + 2, ' ');

    std::stringstream code;
    code << "(" << root_ << ", {\n";

    for (const Node& node : nodes_) {
      code << indent2 << "{";
      if (value_is_text) code << "\"";
      code << node.value;
      if (value_is_text) code << "\"";
      code << ", " << node.left << ", " << node.right << "},\n";
    }

    code << indent1 << "})";

    return code.str();
  }

  // Prints the Huffman tree in the following format:
  // w------w------'x'
  //        w------'y'
  // Where w stands for the weight of the node.
  // Right tree branches appear above left branches. Taking the right path
  // adds 1 to the code, taking the left adds 0.
  // Prints nothing if the codec is empty.
  void PrintTree(std::ostream& out) const { PrintTreeInternal(out, root_, 0); }

  // Traverses the tree breadth-first and prints the Huffman table: value,
  // optionally node weight, and code for every leaf, one leaf per line.
  // The code is printed as a string of '0' (left) and '1' (right) characters
  // in root-to-leaf order. Prints nothing if the codec is empty.
  void PrintTable(std::ostream& out, bool print_weights = true) const {
    // An empty codec has no nodes at all, not even NIL.
    if (!root_) return;

    std::queue<std::pair<uint32_t, std::string>> queue;
    queue.emplace(root_, "");

    while (!queue.empty()) {
      const uint32_t node = queue.front().first;
      const std::string code = std::move(queue.front().second);
      queue.pop();

      const uint32_t left = LeftOf(node);
      const uint32_t right = RightOf(node);
      if (!right && !left) {
        out << ValueOf(node);
        if (print_weights) out << " " << WeightOf(node);
        out << " " << code << '\n';
      } else {
        if (left) queue.emplace(left, code + "0");

        if (right) queue.emplace(right, code + "1");
      }
    }
  }

  // Returns the Huffman table. The table was built at construction time,
  // this function just returns a const reference.
  const std::unordered_map<Val, std::pair<uint64_t, size_t>>& GetEncodingTable()
      const {
    return encoding_table_;
  }

  // Encodes |val| and stores its Huffman code in the lower |num_bits| of
  // |bits|. The least significant bit is the first step taken from the root.
  // Returns false if |val| is not in the Huffman table, in which case |bits|
  // and |num_bits| are left untouched.
  bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
    const auto it = encoding_table_.find(val);
    if (it == encoding_table_.end()) return false;
    *bits = it->second.first;
    *num_bits = it->second.second;
    return true;
  }

  // Reads bits one-by-one using callback |read_bit| until a match is found.
  // Matching value is stored in |val|. Returns false if |read_bit| terminates
  // before a code was matched, if the codec is empty, or if the bits lead to
  // a missing child; |val| is left untouched in these cases.
  // |read_bit| has type bool func(bool* bit). When called, the next bit is
  // stored in |bit|. |read_bit| returns false if the stream terminates
  // prematurely.
  bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
                        Val* val) const {
    // Handle 0 is NIL: it is the root of an empty codec and the "no child"
    // marker inside a tree, so reaching it means nothing can be decoded.
    uint32_t node = root_;
    while (node) {
      if (!RightOf(node) && !LeftOf(node)) {
        *val = ValueOf(node);
        return true;
      }

      bool go_right = false;
      if (!read_bit(&go_right)) return false;

      node = go_right ? RightOf(node) : LeftOf(node);
    }

    return false;
  }

 private:
  // Returns value of |node|.
  const Val& ValueOf(uint32_t node) const { return nodes_.at(node).value; }

  // Returns left child of |node|.
  uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }

  // Returns right child of |node|.
  uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }

  // Returns weight of |node|.
  uint32_t WeightOf(uint32_t node) const { return nodes_.at(node).weight; }

  // Returns id of |node|.
  uint32_t IdOf(uint32_t node) const { return nodes_.at(node).id; }

  // Returns mutable reference to value of |node|. |node| must not be NIL.
  Val& MutableValueOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).value;
  }

  // Returns mutable reference to handle of left child of |node|.
  // |node| must not be NIL.
  uint32_t& MutableLeftOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).left;
  }

  // Returns mutable reference to handle of right child of |node|.
  // |node| must not be NIL.
  uint32_t& MutableRightOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).right;
  }

  // Returns mutable reference to weight of |node|.
  uint32_t& MutableWeightOf(uint32_t node) { return nodes_.at(node).weight; }

  // Returns mutable reference to id of |node|.
  uint32_t& MutableIdOf(uint32_t node) { return nodes_.at(node).id; }

  // Returns true if |left| has bigger weight than |right|. Node ids are
  // used as tie-breaker.
  bool LeftIsBigger(uint32_t left, uint32_t right) const {
    if (WeightOf(left) == WeightOf(right)) {
      assert(IdOf(left) != IdOf(right));
      return IdOf(left) > IdOf(right);
    }
    return WeightOf(left) > WeightOf(right);
  }

  // Prints subtree rooted at |node| (helper function used by PrintTree).
  // |depth| is the distance of |node| from the root; it determines how far
  // left branches are indented.
  void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const {
    if (!node) return;

    const size_t kTextFieldWidth = 7;

    const uint32_t right = RightOf(node);
    const uint32_t left = LeftOf(node);

    if (!right && !left) {
      out << ValueOf(node) << '\n';
      return;
    }

    // Prints the weight of |child| left-aligned in a field padded with '-',
    // followed by the subtree of |child|.
    const auto print_branch = [&](uint32_t child) {
      std::stringstream label;
      label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
            << WeightOf(child);
      out << label.str();
      PrintTreeInternal(out, child, depth + 1);
    };

    // The right branch continues the current output line.
    if (right) print_branch(right);

    // The left branch starts a new line and has to be indented up to the
    // column of this node.
    if (left) {
      out << std::string(depth * kTextFieldWidth, ' ');
      print_branch(left);
    }
  }

  // Traverses the Huffman tree breadth-first and saves paths to the leaves as
  // bit sequences to encoding_table_. Bit number d of a code is the step
  // taken at depth d: 0 for left, 1 for right.
  void CreateEncodingTable() {
    struct Context {
      Context(uint32_t in_node, uint64_t in_bits, size_t in_depth)
          : node(in_node), bits(in_bits), depth(in_depth) {}
      uint32_t node;
      // Huffman tree depth cannot exceed 64 as histogram counts are expected
      // to be positive and limited by numeric_limits<uint32_t>::max().
      // For practical applications tree depth would be much smaller than 64.
      uint64_t bits;
      size_t depth;
    };

    std::queue<Context> queue;
    queue.emplace(root_, 0, 0);

    while (!queue.empty()) {
      // Copy the front element: pop() destroys it.
      const Context context = queue.front();
      queue.pop();

      const uint32_t left = LeftOf(context.node);
      const uint32_t right = RightOf(context.node);

      if (!right && !left) {
        auto insertion_result = encoding_table_.emplace(
            ValueOf(context.node),
            std::pair<uint64_t, size_t>(context.bits, context.depth));
        // Every value is expected to appear in at most one leaf.
        assert(insertion_result.second);
        (void)insertion_result;
      } else {
        // Codes are stored in 64 bits; shifting by 64 or more is undefined.
        // Trees built from a histogram satisfy this, deserialized trees are
        // expected to.
        assert(context.depth < 64);

        if (left) queue.emplace(left, context.bits, context.depth + 1);

        if (right)
          queue.emplace(right, context.bits | (1ULL << context.depth),
                        context.depth + 1);
      }
    }
  }

  // Creates new Huffman tree node, appends it to nodes_ and returns its
  // handle (its index in nodes_). The node gets the next sequential id.
  uint32_t CreateNode() {
    const uint32_t handle = static_cast<uint32_t>(nodes_.size());
    nodes_.emplace_back();
    nodes_.back().id = next_node_id_++;
    return handle;
  }

  // Huffman tree root handle. Zero if the codec is empty.
  uint32_t root_ = 0;

  // Storage for all tree nodes. Handles are indices into this vector,
  // nodes_[0] being NIL. Empty if the codec is empty.
  std::vector<Node> nodes_;

  // Encoding table value -> {bits, num_bits}.
  // Huffman codes are expected to never exceed 64 bit length (this is in fact
  // impossible if frequencies are stored as uint32_t).
  std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;

  // Next node id issued by CreateNode();
  uint32_t next_node_id_ = 1;
};
385 
386 }  // namespace comp
387 }  // namespace spvtools
388 
389 #endif  // SOURCE_COMP_HUFFMAN_CODEC_H_
390