1 // float-weight.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Float weight set and associated semiring operation definitions.
20 //
21 
22 #ifndef FST_LIB_FLOAT_WEIGHT_H__
23 #define FST_LIB_FLOAT_WEIGHT_H__
24 
#include <climits>
#include <cmath>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>

#include <fst/util.h>
#include <fst/weight.h>
32 
33 
34 namespace fst {
35 
// Special values (infinities, NaN) of the floating-point type T that
// underlies a weight. Each accessor computes its value once and caches it in
// a function-local static.
template <class T>
class FloatLimits {
 public:
  // +infinity for T.
  static const T PosInfinity() {
    static const T kPosInf = std::numeric_limits<T>::infinity();
    return kPosInf;
  }

  // -infinity for T, obtained by negating PosInfinity().
  static const T NegInfinity() {
    static const T kNegInf = -PosInfinity();
    return kNegInf;
  }

  // A quiet NaN for T; the weight classes use it as their NoWeight() value.
  static const T NumberBad() {
    static const T kBad = std::numeric_limits<T>::quiet_NaN();
    return kBad;
  }
};
56 
57 // weight class to be templated on floating-points types
58 template <class T = float>
59 class FloatWeightTpl {
60  public:
FloatWeightTpl()61   FloatWeightTpl() {}
62 
FloatWeightTpl(T f)63   FloatWeightTpl(T f) : value_(f) {}
64 
FloatWeightTpl(const FloatWeightTpl<T> & w)65   FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}
66 
67   FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
68     value_ = w.value_;
69     return *this;
70   }
71 
Read(istream & strm)72   istream &Read(istream &strm) {
73     return ReadType(strm, &value_);
74   }
75 
Write(ostream & strm)76   ostream &Write(ostream &strm) const {
77     return WriteType(strm, value_);
78   }
79 
Hash()80   size_t Hash() const {
81     union {
82       T f;
83       size_t s;
84     } u;
85     u.s = 0;
86     u.f = value_;
87     return u.s;
88   }
89 
Value()90   const T &Value() const { return value_; }
91 
92  protected:
SetValue(const T & f)93   void SetValue(const T &f) { value_ = f; }
94 
GetPrecisionString()95   inline static string GetPrecisionString() {
96     int64 size = sizeof(T);
97     if (size == sizeof(float)) return "";
98     size *= CHAR_BIT;
99 
100     string result;
101     Int64ToStr(size, &result);
102     return result;
103   }
104 
105  private:
106   T value_;
107 };
108 
109 // Single-precision float weight
110 typedef FloatWeightTpl<float> FloatWeight;
111 
112 template <class T>
113 inline bool operator==(const FloatWeightTpl<T> &w1,
114                        const FloatWeightTpl<T> &w2) {
115   // Volatile qualifier thwarts over-aggressive compiler optimizations
116   // that lead to problems esp. with NaturalLess().
117   volatile T v1 = w1.Value();
118   volatile T v2 = w2.Value();
119   return v1 == v2;
120 }
121 
122 inline bool operator==(const FloatWeightTpl<double> &w1,
123                        const FloatWeightTpl<double> &w2) {
124   return operator==<double>(w1, w2);
125 }
126 
127 inline bool operator==(const FloatWeightTpl<float> &w1,
128                        const FloatWeightTpl<float> &w2) {
129   return operator==<float>(w1, w2);
130 }
131 
132 template <class T>
133 inline bool operator!=(const FloatWeightTpl<T> &w1,
134                        const FloatWeightTpl<T> &w2) {
135   return !(w1 == w2);
136 }
137 
138 inline bool operator!=(const FloatWeightTpl<double> &w1,
139                        const FloatWeightTpl<double> &w2) {
140   return operator!=<double>(w1, w2);
141 }
142 
143 inline bool operator!=(const FloatWeightTpl<float> &w1,
144                        const FloatWeightTpl<float> &w2) {
145   return operator!=<float>(w1, w2);
146 }
147 
148 template <class T>
149 inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
150                         const FloatWeightTpl<T> &w2,
151                         float delta = kDelta) {
152   return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
153 }
154 
155 template <class T>
156 inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
157   if (w.Value() == FloatLimits<T>::PosInfinity())
158     return strm << "Infinity";
159   else if (w.Value() == FloatLimits<T>::NegInfinity())
160     return strm << "-Infinity";
161   else if (w.Value() != w.Value())   // Fails for NaN
162     return strm << "BadNumber";
163   else
164     return strm << w.Value();
165 }
166 
167 template <class T>
168 inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
169   string s;
170   strm >> s;
171   if (s == "Infinity") {
172     w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity());
173   } else if (s == "-Infinity") {
174     w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity());
175   } else {
176     char *p;
177     T f = strtod(s.c_str(), &p);
178     if (p < s.c_str() + s.size())
179       strm.clear(std::ios::badbit);
180     else
181       w = FloatWeightTpl<T>(f);
182   }
183   return strm;
184 }
185 
186 
187 // Tropical semiring: (min, +, inf, 0)
188 template <class T>
189 class TropicalWeightTpl : public FloatWeightTpl<T> {
190  public:
191   using FloatWeightTpl<T>::Value;
192 
193   typedef TropicalWeightTpl<T> ReverseWeight;
194 
TropicalWeightTpl()195   TropicalWeightTpl() : FloatWeightTpl<T>() {}
196 
TropicalWeightTpl(T f)197   TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
198 
TropicalWeightTpl(const TropicalWeightTpl<T> & w)199   TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
200 
Zero()201   static const TropicalWeightTpl<T> Zero() {
202     return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); }
203 
One()204   static const TropicalWeightTpl<T> One() {
205     return TropicalWeightTpl<T>(0.0F); }
206 
NoWeight()207   static const TropicalWeightTpl<T> NoWeight() {
208     return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); }
209 
Type()210   static const string &Type() {
211     static const string type = "tropical" +
212         FloatWeightTpl<T>::GetPrecisionString();
213     return type;
214   }
215 
Member()216   bool Member() const {
217     // First part fails for IEEE NaN
218     return Value() == Value() && Value() != FloatLimits<T>::NegInfinity();
219   }
220 
221   TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
222     if (Value() == FloatLimits<T>::NegInfinity() ||
223         Value() == FloatLimits<T>::PosInfinity() ||
224         Value() != Value())
225       return *this;
226     else
227       return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
228   }
229 
Reverse()230   TropicalWeightTpl<T> Reverse() const { return *this; }
231 
Properties()232   static uint64 Properties() {
233     return kLeftSemiring | kRightSemiring | kCommutative |
234         kPath | kIdempotent;
235   }
236 };
237 
238 // Single precision tropical weight
239 typedef TropicalWeightTpl<float> TropicalWeight;
240 
241 template <class T>
Plus(const TropicalWeightTpl<T> & w1,const TropicalWeightTpl<T> & w2)242 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
243                                  const TropicalWeightTpl<T> &w2) {
244   if (!w1.Member() || !w2.Member())
245     return TropicalWeightTpl<T>::NoWeight();
246   return w1.Value() < w2.Value() ? w1 : w2;
247 }
248 
Plus(const TropicalWeightTpl<float> & w1,const TropicalWeightTpl<float> & w2)249 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
250                                      const TropicalWeightTpl<float> &w2) {
251   return Plus<float>(w1, w2);
252 }
253 
Plus(const TropicalWeightTpl<double> & w1,const TropicalWeightTpl<double> & w2)254 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
255                                       const TropicalWeightTpl<double> &w2) {
256   return Plus<double>(w1, w2);
257 }
258 
259 template <class T>
Times(const TropicalWeightTpl<T> & w1,const TropicalWeightTpl<T> & w2)260 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
261                                   const TropicalWeightTpl<T> &w2) {
262   if (!w1.Member() || !w2.Member())
263     return TropicalWeightTpl<T>::NoWeight();
264   T f1 = w1.Value(), f2 = w2.Value();
265   if (f1 == FloatLimits<T>::PosInfinity())
266     return w1;
267   else if (f2 == FloatLimits<T>::PosInfinity())
268     return w2;
269   else
270     return TropicalWeightTpl<T>(f1 + f2);
271 }
272 
Times(const TropicalWeightTpl<float> & w1,const TropicalWeightTpl<float> & w2)273 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
274                                       const TropicalWeightTpl<float> &w2) {
275   return Times<float>(w1, w2);
276 }
277 
Times(const TropicalWeightTpl<double> & w1,const TropicalWeightTpl<double> & w2)278 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
279                                        const TropicalWeightTpl<double> &w2) {
280   return Times<double>(w1, w2);
281 }
282 
283 template <class T>
284 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
285                                    const TropicalWeightTpl<T> &w2,
286                                    DivideType typ = DIVIDE_ANY) {
287   if (!w1.Member() || !w2.Member())
288     return TropicalWeightTpl<T>::NoWeight();
289   T f1 = w1.Value(), f2 = w2.Value();
290   if (f2 == FloatLimits<T>::PosInfinity())
291     return FloatLimits<T>::NumberBad();
292   else if (f1 == FloatLimits<T>::PosInfinity())
293     return FloatLimits<T>::PosInfinity();
294   else
295     return TropicalWeightTpl<T>(f1 - f2);
296 }
297 
298 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
299                                        const TropicalWeightTpl<float> &w2,
300                                        DivideType typ = DIVIDE_ANY) {
301   return Divide<float>(w1, w2, typ);
302 }
303 
304 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
305                                         const TropicalWeightTpl<double> &w2,
306                                         DivideType typ = DIVIDE_ANY) {
307   return Divide<double>(w1, w2, typ);
308 }
309 
310 
311 // Log semiring: (log(e^-x + e^y), +, inf, 0)
312 template <class T>
313 class LogWeightTpl : public FloatWeightTpl<T> {
314  public:
315   using FloatWeightTpl<T>::Value;
316 
317   typedef LogWeightTpl ReverseWeight;
318 
LogWeightTpl()319   LogWeightTpl() : FloatWeightTpl<T>() {}
320 
LogWeightTpl(T f)321   LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
322 
LogWeightTpl(const LogWeightTpl<T> & w)323   LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
324 
Zero()325   static const LogWeightTpl<T> Zero() {
326     return LogWeightTpl<T>(FloatLimits<T>::PosInfinity());
327   }
328 
One()329   static const LogWeightTpl<T> One() {
330     return LogWeightTpl<T>(0.0F);
331   }
332 
NoWeight()333   static const LogWeightTpl<T> NoWeight() {
334     return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); }
335 
Type()336   static const string &Type() {
337     static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString();
338     return type;
339   }
340 
Member()341   bool Member() const {
342     // First part fails for IEEE NaN
343     return Value() == Value() && Value() != FloatLimits<T>::NegInfinity();
344   }
345 
346   LogWeightTpl<T> Quantize(float delta = kDelta) const {
347     if (Value() == FloatLimits<T>::NegInfinity() ||
348         Value() == FloatLimits<T>::PosInfinity() ||
349         Value() != Value())
350       return *this;
351     else
352       return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
353   }
354 
Reverse()355   LogWeightTpl<T> Reverse() const { return *this; }
356 
Properties()357   static uint64 Properties() {
358     return kLeftSemiring | kRightSemiring | kCommutative;
359   }
360 };
361 
362 // Single-precision log weight
363 typedef LogWeightTpl<float> LogWeight;
364 // Double-precision log weight
365 typedef LogWeightTpl<double> Log64Weight;
366 
// Returns log(1 + exp(-x)), the correction term used by log-semiring Plus
// (which always passes a non-negative difference). log1p keeps precision
// when exp(-x) is tiny: in log(1 + exp(-x)) the sum rounds to exactly 1 once
// exp(-x) drops below machine epsilon, and the result collapses to 0.
// For x == +infinity the result is 0.
template <class T>
inline T LogExp(T x) { return log1p(exp(-x)); }
369 
370 template <class T>
Plus(const LogWeightTpl<T> & w1,const LogWeightTpl<T> & w2)371 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
372                             const LogWeightTpl<T> &w2) {
373   T f1 = w1.Value(), f2 = w2.Value();
374   if (f1 == FloatLimits<T>::PosInfinity())
375     return w2;
376   else if (f2 == FloatLimits<T>::PosInfinity())
377     return w1;
378   else if (f1 > f2)
379     return LogWeightTpl<T>(f2 - LogExp(f1 - f2));
380   else
381     return LogWeightTpl<T>(f1 - LogExp(f2 - f1));
382 }
383 
Plus(const LogWeightTpl<float> & w1,const LogWeightTpl<float> & w2)384 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
385                                 const LogWeightTpl<float> &w2) {
386   return Plus<float>(w1, w2);
387 }
388 
Plus(const LogWeightTpl<double> & w1,const LogWeightTpl<double> & w2)389 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
390                                  const LogWeightTpl<double> &w2) {
391   return Plus<double>(w1, w2);
392 }
393 
394 template <class T>
Times(const LogWeightTpl<T> & w1,const LogWeightTpl<T> & w2)395 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
396                              const LogWeightTpl<T> &w2) {
397   if (!w1.Member() || !w2.Member())
398     return LogWeightTpl<T>::NoWeight();
399   T f1 = w1.Value(), f2 = w2.Value();
400   if (f1 == FloatLimits<T>::PosInfinity())
401     return w1;
402   else if (f2 == FloatLimits<T>::PosInfinity())
403     return w2;
404   else
405     return LogWeightTpl<T>(f1 + f2);
406 }
407 
Times(const LogWeightTpl<float> & w1,const LogWeightTpl<float> & w2)408 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
409                                  const LogWeightTpl<float> &w2) {
410   return Times<float>(w1, w2);
411 }
412 
Times(const LogWeightTpl<double> & w1,const LogWeightTpl<double> & w2)413 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
414                                   const LogWeightTpl<double> &w2) {
415   return Times<double>(w1, w2);
416 }
417 
418 template <class T>
419 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
420                               const LogWeightTpl<T> &w2,
421                               DivideType typ = DIVIDE_ANY) {
422   if (!w1.Member() || !w2.Member())
423     return LogWeightTpl<T>::NoWeight();
424   T f1 = w1.Value(), f2 = w2.Value();
425   if (f2 == FloatLimits<T>::PosInfinity())
426     return FloatLimits<T>::NumberBad();
427   else if (f1 == FloatLimits<T>::PosInfinity())
428     return FloatLimits<T>::PosInfinity();
429   else
430     return LogWeightTpl<T>(f1 - f2);
431 }
432 
433 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
434                                   const LogWeightTpl<float> &w2,
435                                   DivideType typ = DIVIDE_ANY) {
436   return Divide<float>(w1, w2, typ);
437 }
438 
439 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
440                                    const LogWeightTpl<double> &w2,
441                                    DivideType typ = DIVIDE_ANY) {
442   return Divide<double>(w1, w2, typ);
443 }
444 
445 // MinMax semiring: (min, max, inf, -inf)
446 template <class T>
447 class MinMaxWeightTpl : public FloatWeightTpl<T> {
448  public:
449   using FloatWeightTpl<T>::Value;
450 
451   typedef MinMaxWeightTpl<T> ReverseWeight;
452 
MinMaxWeightTpl()453   MinMaxWeightTpl() : FloatWeightTpl<T>() {}
454 
MinMaxWeightTpl(T f)455   MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
456 
MinMaxWeightTpl(const MinMaxWeightTpl<T> & w)457   MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
458 
Zero()459   static const MinMaxWeightTpl<T> Zero() {
460     return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity());
461   }
462 
One()463   static const MinMaxWeightTpl<T> One() {
464     return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity());
465   }
466 
NoWeight()467   static const MinMaxWeightTpl<T> NoWeight() {
468     return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); }
469 
Type()470   static const string &Type() {
471     static const string type = "minmax" +
472         FloatWeightTpl<T>::GetPrecisionString();
473     return type;
474   }
475 
Member()476   bool Member() const {
477     // Fails for IEEE NaN
478     return Value() == Value();
479   }
480 
481   MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
482     // If one of infinities, or a NaN
483     if (Value() == FloatLimits<T>::NegInfinity() ||
484         Value() == FloatLimits<T>::PosInfinity() ||
485         Value() != Value())
486       return *this;
487     else
488       return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
489   }
490 
Reverse()491   MinMaxWeightTpl<T> Reverse() const { return *this; }
492 
Properties()493   static uint64 Properties() {
494     return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
495   }
496 };
497 
498 // Single-precision min-max weight
499 typedef MinMaxWeightTpl<float> MinMaxWeight;
500 
501 // Min
502 template <class T>
Plus(const MinMaxWeightTpl<T> & w1,const MinMaxWeightTpl<T> & w2)503 inline MinMaxWeightTpl<T> Plus(
504     const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
505   if (!w1.Member() || !w2.Member())
506     return MinMaxWeightTpl<T>::NoWeight();
507   return w1.Value() < w2.Value() ? w1 : w2;
508 }
509 
Plus(const MinMaxWeightTpl<float> & w1,const MinMaxWeightTpl<float> & w2)510 inline MinMaxWeightTpl<float> Plus(
511     const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
512   return Plus<float>(w1, w2);
513 }
514 
Plus(const MinMaxWeightTpl<double> & w1,const MinMaxWeightTpl<double> & w2)515 inline MinMaxWeightTpl<double> Plus(
516     const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
517   return Plus<double>(w1, w2);
518 }
519 
520 // Max
521 template <class T>
Times(const MinMaxWeightTpl<T> & w1,const MinMaxWeightTpl<T> & w2)522 inline MinMaxWeightTpl<T> Times(
523     const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
524   if (!w1.Member() || !w2.Member())
525     return MinMaxWeightTpl<T>::NoWeight();
526   return w1.Value() >= w2.Value() ? w1 : w2;
527 }
528 
Times(const MinMaxWeightTpl<float> & w1,const MinMaxWeightTpl<float> & w2)529 inline MinMaxWeightTpl<float> Times(
530     const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
531   return Times<float>(w1, w2);
532 }
533 
Times(const MinMaxWeightTpl<double> & w1,const MinMaxWeightTpl<double> & w2)534 inline MinMaxWeightTpl<double> Times(
535     const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
536   return Times<double>(w1, w2);
537 }
538 
539 // Defined only for special cases
540 template <class T>
541 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
542                                  const MinMaxWeightTpl<T> &w2,
543                                  DivideType typ = DIVIDE_ANY) {
544   if (!w1.Member() || !w2.Member())
545     return MinMaxWeightTpl<T>::NoWeight();
546   // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2
547   return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad();
548 }
549 
550 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
551                                      const MinMaxWeightTpl<float> &w2,
552                                      DivideType typ = DIVIDE_ANY) {
553   return Divide<float>(w1, w2, typ);
554 }
555 
556 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
557                                       const MinMaxWeightTpl<double> &w2,
558                                       DivideType typ = DIVIDE_ANY) {
559   return Divide<double>(w1, w2, typ);
560 }
561 
562 //
563 // WEIGHT CONVERTER SPECIALIZATIONS.
564 //
565 
566 // Convert to tropical
567 template <>
568 struct WeightConvert<LogWeight, TropicalWeight> {
569   TropicalWeight operator()(LogWeight w) const { return w.Value(); }
570 };
571 
572 template <>
573 struct WeightConvert<Log64Weight, TropicalWeight> {
574   TropicalWeight operator()(Log64Weight w) const { return w.Value(); }
575 };
576 
577 // Convert to log
578 template <>
579 struct WeightConvert<TropicalWeight, LogWeight> {
580   LogWeight operator()(TropicalWeight w) const { return w.Value(); }
581 };
582 
583 template <>
584 struct WeightConvert<Log64Weight, LogWeight> {
585   LogWeight operator()(Log64Weight w) const { return w.Value(); }
586 };
587 
588 // Convert to log64
589 template <>
590 struct WeightConvert<TropicalWeight, Log64Weight> {
591   Log64Weight operator()(TropicalWeight w) const { return w.Value(); }
592 };
593 
594 template <>
595 struct WeightConvert<LogWeight, Log64Weight> {
596   Log64Weight operator()(LogWeight w) const { return w.Value(); }
597 };
598 
599 }  // namespace fst
600 
601 #endif  // FST_LIB_FLOAT_WEIGHT_H__
602