// float-weight.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Float weight set and associated semiring operation definitions.
//
21
22 #ifndef FST_LIB_FLOAT_WEIGHT_H__
23 #define FST_LIB_FLOAT_WEIGHT_H__
24
#include <climits>
#include <cmath>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
29
30 #include <fst/util.h>
31 #include <fst/weight.h>
32
33
34 namespace fst {
35
36 // numeric limits class
37 template <class T>
38 class FloatLimits {
39 public:
PosInfinity()40 static const T PosInfinity() {
41 static const T pos_infinity = numeric_limits<T>::infinity();
42 return pos_infinity;
43 }
44
NegInfinity()45 static const T NegInfinity() {
46 static const T neg_infinity = -PosInfinity();
47 return neg_infinity;
48 }
49
NumberBad()50 static const T NumberBad() {
51 static const T number_bad = numeric_limits<T>::quiet_NaN();
52 return number_bad;
53 }
54
55 };
56
57 // weight class to be templated on floating-points types
58 template <class T = float>
59 class FloatWeightTpl {
60 public:
FloatWeightTpl()61 FloatWeightTpl() {}
62
FloatWeightTpl(T f)63 FloatWeightTpl(T f) : value_(f) {}
64
FloatWeightTpl(const FloatWeightTpl<T> & w)65 FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}
66
67 FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
68 value_ = w.value_;
69 return *this;
70 }
71
Read(istream & strm)72 istream &Read(istream &strm) {
73 return ReadType(strm, &value_);
74 }
75
Write(ostream & strm)76 ostream &Write(ostream &strm) const {
77 return WriteType(strm, value_);
78 }
79
Hash()80 size_t Hash() const {
81 union {
82 T f;
83 size_t s;
84 } u;
85 u.s = 0;
86 u.f = value_;
87 return u.s;
88 }
89
Value()90 const T &Value() const { return value_; }
91
92 protected:
SetValue(const T & f)93 void SetValue(const T &f) { value_ = f; }
94
GetPrecisionString()95 inline static string GetPrecisionString() {
96 int64 size = sizeof(T);
97 if (size == sizeof(float)) return "";
98 size *= CHAR_BIT;
99
100 string result;
101 Int64ToStr(size, &result);
102 return result;
103 }
104
105 private:
106 T value_;
107 };
108
109 // Single-precision float weight
110 typedef FloatWeightTpl<float> FloatWeight;
111
112 template <class T>
113 inline bool operator==(const FloatWeightTpl<T> &w1,
114 const FloatWeightTpl<T> &w2) {
115 // Volatile qualifier thwarts over-aggressive compiler optimizations
116 // that lead to problems esp. with NaturalLess().
117 volatile T v1 = w1.Value();
118 volatile T v2 = w2.Value();
119 return v1 == v2;
120 }
121
122 inline bool operator==(const FloatWeightTpl<double> &w1,
123 const FloatWeightTpl<double> &w2) {
124 return operator==<double>(w1, w2);
125 }
126
127 inline bool operator==(const FloatWeightTpl<float> &w1,
128 const FloatWeightTpl<float> &w2) {
129 return operator==<float>(w1, w2);
130 }
131
132 template <class T>
133 inline bool operator!=(const FloatWeightTpl<T> &w1,
134 const FloatWeightTpl<T> &w2) {
135 return !(w1 == w2);
136 }
137
138 inline bool operator!=(const FloatWeightTpl<double> &w1,
139 const FloatWeightTpl<double> &w2) {
140 return operator!=<double>(w1, w2);
141 }
142
143 inline bool operator!=(const FloatWeightTpl<float> &w1,
144 const FloatWeightTpl<float> &w2) {
145 return operator!=<float>(w1, w2);
146 }
147
148 template <class T>
149 inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
150 const FloatWeightTpl<T> &w2,
151 float delta = kDelta) {
152 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
153 }
154
155 template <class T>
156 inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
157 if (w.Value() == FloatLimits<T>::PosInfinity())
158 return strm << "Infinity";
159 else if (w.Value() == FloatLimits<T>::NegInfinity())
160 return strm << "-Infinity";
161 else if (w.Value() != w.Value()) // Fails for NaN
162 return strm << "BadNumber";
163 else
164 return strm << w.Value();
165 }
166
167 template <class T>
168 inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
169 string s;
170 strm >> s;
171 if (s == "Infinity") {
172 w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity());
173 } else if (s == "-Infinity") {
174 w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity());
175 } else {
176 char *p;
177 T f = strtod(s.c_str(), &p);
178 if (p < s.c_str() + s.size())
179 strm.clear(std::ios::badbit);
180 else
181 w = FloatWeightTpl<T>(f);
182 }
183 return strm;
184 }
185
186
187 // Tropical semiring: (min, +, inf, 0)
188 template <class T>
189 class TropicalWeightTpl : public FloatWeightTpl<T> {
190 public:
191 using FloatWeightTpl<T>::Value;
192
193 typedef TropicalWeightTpl<T> ReverseWeight;
194
TropicalWeightTpl()195 TropicalWeightTpl() : FloatWeightTpl<T>() {}
196
TropicalWeightTpl(T f)197 TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
198
TropicalWeightTpl(const TropicalWeightTpl<T> & w)199 TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
200
Zero()201 static const TropicalWeightTpl<T> Zero() {
202 return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); }
203
One()204 static const TropicalWeightTpl<T> One() {
205 return TropicalWeightTpl<T>(0.0F); }
206
NoWeight()207 static const TropicalWeightTpl<T> NoWeight() {
208 return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); }
209
Type()210 static const string &Type() {
211 static const string type = "tropical" +
212 FloatWeightTpl<T>::GetPrecisionString();
213 return type;
214 }
215
Member()216 bool Member() const {
217 // First part fails for IEEE NaN
218 return Value() == Value() && Value() != FloatLimits<T>::NegInfinity();
219 }
220
221 TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
222 if (Value() == FloatLimits<T>::NegInfinity() ||
223 Value() == FloatLimits<T>::PosInfinity() ||
224 Value() != Value())
225 return *this;
226 else
227 return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
228 }
229
Reverse()230 TropicalWeightTpl<T> Reverse() const { return *this; }
231
Properties()232 static uint64 Properties() {
233 return kLeftSemiring | kRightSemiring | kCommutative |
234 kPath | kIdempotent;
235 }
236 };
237
238 // Single precision tropical weight
239 typedef TropicalWeightTpl<float> TropicalWeight;
240
241 template <class T>
Plus(const TropicalWeightTpl<T> & w1,const TropicalWeightTpl<T> & w2)242 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
243 const TropicalWeightTpl<T> &w2) {
244 if (!w1.Member() || !w2.Member())
245 return TropicalWeightTpl<T>::NoWeight();
246 return w1.Value() < w2.Value() ? w1 : w2;
247 }
248
Plus(const TropicalWeightTpl<float> & w1,const TropicalWeightTpl<float> & w2)249 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
250 const TropicalWeightTpl<float> &w2) {
251 return Plus<float>(w1, w2);
252 }
253
Plus(const TropicalWeightTpl<double> & w1,const TropicalWeightTpl<double> & w2)254 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
255 const TropicalWeightTpl<double> &w2) {
256 return Plus<double>(w1, w2);
257 }
258
259 template <class T>
Times(const TropicalWeightTpl<T> & w1,const TropicalWeightTpl<T> & w2)260 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
261 const TropicalWeightTpl<T> &w2) {
262 if (!w1.Member() || !w2.Member())
263 return TropicalWeightTpl<T>::NoWeight();
264 T f1 = w1.Value(), f2 = w2.Value();
265 if (f1 == FloatLimits<T>::PosInfinity())
266 return w1;
267 else if (f2 == FloatLimits<T>::PosInfinity())
268 return w2;
269 else
270 return TropicalWeightTpl<T>(f1 + f2);
271 }
272
Times(const TropicalWeightTpl<float> & w1,const TropicalWeightTpl<float> & w2)273 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
274 const TropicalWeightTpl<float> &w2) {
275 return Times<float>(w1, w2);
276 }
277
Times(const TropicalWeightTpl<double> & w1,const TropicalWeightTpl<double> & w2)278 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
279 const TropicalWeightTpl<double> &w2) {
280 return Times<double>(w1, w2);
281 }
282
283 template <class T>
284 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
285 const TropicalWeightTpl<T> &w2,
286 DivideType typ = DIVIDE_ANY) {
287 if (!w1.Member() || !w2.Member())
288 return TropicalWeightTpl<T>::NoWeight();
289 T f1 = w1.Value(), f2 = w2.Value();
290 if (f2 == FloatLimits<T>::PosInfinity())
291 return FloatLimits<T>::NumberBad();
292 else if (f1 == FloatLimits<T>::PosInfinity())
293 return FloatLimits<T>::PosInfinity();
294 else
295 return TropicalWeightTpl<T>(f1 - f2);
296 }
297
298 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
299 const TropicalWeightTpl<float> &w2,
300 DivideType typ = DIVIDE_ANY) {
301 return Divide<float>(w1, w2, typ);
302 }
303
304 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
305 const TropicalWeightTpl<double> &w2,
306 DivideType typ = DIVIDE_ANY) {
307 return Divide<double>(w1, w2, typ);
308 }
309
310
311 // Log semiring: (log(e^-x + e^y), +, inf, 0)
312 template <class T>
313 class LogWeightTpl : public FloatWeightTpl<T> {
314 public:
315 using FloatWeightTpl<T>::Value;
316
317 typedef LogWeightTpl ReverseWeight;
318
LogWeightTpl()319 LogWeightTpl() : FloatWeightTpl<T>() {}
320
LogWeightTpl(T f)321 LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
322
LogWeightTpl(const LogWeightTpl<T> & w)323 LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
324
Zero()325 static const LogWeightTpl<T> Zero() {
326 return LogWeightTpl<T>(FloatLimits<T>::PosInfinity());
327 }
328
One()329 static const LogWeightTpl<T> One() {
330 return LogWeightTpl<T>(0.0F);
331 }
332
NoWeight()333 static const LogWeightTpl<T> NoWeight() {
334 return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); }
335
Type()336 static const string &Type() {
337 static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString();
338 return type;
339 }
340
Member()341 bool Member() const {
342 // First part fails for IEEE NaN
343 return Value() == Value() && Value() != FloatLimits<T>::NegInfinity();
344 }
345
346 LogWeightTpl<T> Quantize(float delta = kDelta) const {
347 if (Value() == FloatLimits<T>::NegInfinity() ||
348 Value() == FloatLimits<T>::PosInfinity() ||
349 Value() != Value())
350 return *this;
351 else
352 return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
353 }
354
Reverse()355 LogWeightTpl<T> Reverse() const { return *this; }
356
Properties()357 static uint64 Properties() {
358 return kLeftSemiring | kRightSemiring | kCommutative;
359 }
360 };
361
362 // Single-precision log weight
363 typedef LogWeightTpl<float> LogWeight;
364 // Double-precision log weight
365 typedef LogWeightTpl<double> Log64Weight;
366
367 template <class T>
LogExp(T x)368 inline T LogExp(T x) { return log(1.0F + exp(-x)); }
369
370 template <class T>
Plus(const LogWeightTpl<T> & w1,const LogWeightTpl<T> & w2)371 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
372 const LogWeightTpl<T> &w2) {
373 T f1 = w1.Value(), f2 = w2.Value();
374 if (f1 == FloatLimits<T>::PosInfinity())
375 return w2;
376 else if (f2 == FloatLimits<T>::PosInfinity())
377 return w1;
378 else if (f1 > f2)
379 return LogWeightTpl<T>(f2 - LogExp(f1 - f2));
380 else
381 return LogWeightTpl<T>(f1 - LogExp(f2 - f1));
382 }
383
Plus(const LogWeightTpl<float> & w1,const LogWeightTpl<float> & w2)384 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
385 const LogWeightTpl<float> &w2) {
386 return Plus<float>(w1, w2);
387 }
388
Plus(const LogWeightTpl<double> & w1,const LogWeightTpl<double> & w2)389 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
390 const LogWeightTpl<double> &w2) {
391 return Plus<double>(w1, w2);
392 }
393
394 template <class T>
Times(const LogWeightTpl<T> & w1,const LogWeightTpl<T> & w2)395 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
396 const LogWeightTpl<T> &w2) {
397 if (!w1.Member() || !w2.Member())
398 return LogWeightTpl<T>::NoWeight();
399 T f1 = w1.Value(), f2 = w2.Value();
400 if (f1 == FloatLimits<T>::PosInfinity())
401 return w1;
402 else if (f2 == FloatLimits<T>::PosInfinity())
403 return w2;
404 else
405 return LogWeightTpl<T>(f1 + f2);
406 }
407
Times(const LogWeightTpl<float> & w1,const LogWeightTpl<float> & w2)408 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
409 const LogWeightTpl<float> &w2) {
410 return Times<float>(w1, w2);
411 }
412
Times(const LogWeightTpl<double> & w1,const LogWeightTpl<double> & w2)413 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
414 const LogWeightTpl<double> &w2) {
415 return Times<double>(w1, w2);
416 }
417
418 template <class T>
419 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
420 const LogWeightTpl<T> &w2,
421 DivideType typ = DIVIDE_ANY) {
422 if (!w1.Member() || !w2.Member())
423 return LogWeightTpl<T>::NoWeight();
424 T f1 = w1.Value(), f2 = w2.Value();
425 if (f2 == FloatLimits<T>::PosInfinity())
426 return FloatLimits<T>::NumberBad();
427 else if (f1 == FloatLimits<T>::PosInfinity())
428 return FloatLimits<T>::PosInfinity();
429 else
430 return LogWeightTpl<T>(f1 - f2);
431 }
432
433 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
434 const LogWeightTpl<float> &w2,
435 DivideType typ = DIVIDE_ANY) {
436 return Divide<float>(w1, w2, typ);
437 }
438
439 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
440 const LogWeightTpl<double> &w2,
441 DivideType typ = DIVIDE_ANY) {
442 return Divide<double>(w1, w2, typ);
443 }
444
445 // MinMax semiring: (min, max, inf, -inf)
446 template <class T>
447 class MinMaxWeightTpl : public FloatWeightTpl<T> {
448 public:
449 using FloatWeightTpl<T>::Value;
450
451 typedef MinMaxWeightTpl<T> ReverseWeight;
452
MinMaxWeightTpl()453 MinMaxWeightTpl() : FloatWeightTpl<T>() {}
454
MinMaxWeightTpl(T f)455 MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
456
MinMaxWeightTpl(const MinMaxWeightTpl<T> & w)457 MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
458
Zero()459 static const MinMaxWeightTpl<T> Zero() {
460 return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity());
461 }
462
One()463 static const MinMaxWeightTpl<T> One() {
464 return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity());
465 }
466
NoWeight()467 static const MinMaxWeightTpl<T> NoWeight() {
468 return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); }
469
Type()470 static const string &Type() {
471 static const string type = "minmax" +
472 FloatWeightTpl<T>::GetPrecisionString();
473 return type;
474 }
475
Member()476 bool Member() const {
477 // Fails for IEEE NaN
478 return Value() == Value();
479 }
480
481 MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
482 // If one of infinities, or a NaN
483 if (Value() == FloatLimits<T>::NegInfinity() ||
484 Value() == FloatLimits<T>::PosInfinity() ||
485 Value() != Value())
486 return *this;
487 else
488 return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
489 }
490
Reverse()491 MinMaxWeightTpl<T> Reverse() const { return *this; }
492
Properties()493 static uint64 Properties() {
494 return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
495 }
496 };
497
498 // Single-precision min-max weight
499 typedef MinMaxWeightTpl<float> MinMaxWeight;
500
501 // Min
502 template <class T>
Plus(const MinMaxWeightTpl<T> & w1,const MinMaxWeightTpl<T> & w2)503 inline MinMaxWeightTpl<T> Plus(
504 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
505 if (!w1.Member() || !w2.Member())
506 return MinMaxWeightTpl<T>::NoWeight();
507 return w1.Value() < w2.Value() ? w1 : w2;
508 }
509
Plus(const MinMaxWeightTpl<float> & w1,const MinMaxWeightTpl<float> & w2)510 inline MinMaxWeightTpl<float> Plus(
511 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
512 return Plus<float>(w1, w2);
513 }
514
Plus(const MinMaxWeightTpl<double> & w1,const MinMaxWeightTpl<double> & w2)515 inline MinMaxWeightTpl<double> Plus(
516 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
517 return Plus<double>(w1, w2);
518 }
519
520 // Max
521 template <class T>
Times(const MinMaxWeightTpl<T> & w1,const MinMaxWeightTpl<T> & w2)522 inline MinMaxWeightTpl<T> Times(
523 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
524 if (!w1.Member() || !w2.Member())
525 return MinMaxWeightTpl<T>::NoWeight();
526 return w1.Value() >= w2.Value() ? w1 : w2;
527 }
528
Times(const MinMaxWeightTpl<float> & w1,const MinMaxWeightTpl<float> & w2)529 inline MinMaxWeightTpl<float> Times(
530 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
531 return Times<float>(w1, w2);
532 }
533
Times(const MinMaxWeightTpl<double> & w1,const MinMaxWeightTpl<double> & w2)534 inline MinMaxWeightTpl<double> Times(
535 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
536 return Times<double>(w1, w2);
537 }
538
539 // Defined only for special cases
540 template <class T>
541 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
542 const MinMaxWeightTpl<T> &w2,
543 DivideType typ = DIVIDE_ANY) {
544 if (!w1.Member() || !w2.Member())
545 return MinMaxWeightTpl<T>::NoWeight();
546 // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2
547 return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad();
548 }
549
550 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
551 const MinMaxWeightTpl<float> &w2,
552 DivideType typ = DIVIDE_ANY) {
553 return Divide<float>(w1, w2, typ);
554 }
555
556 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
557 const MinMaxWeightTpl<double> &w2,
558 DivideType typ = DIVIDE_ANY) {
559 return Divide<double>(w1, w2, typ);
560 }
561
562 //
563 // WEIGHT CONVERTER SPECIALIZATIONS.
564 //
565
566 // Convert to tropical
567 template <>
568 struct WeightConvert<LogWeight, TropicalWeight> {
569 TropicalWeight operator()(LogWeight w) const { return w.Value(); }
570 };
571
572 template <>
573 struct WeightConvert<Log64Weight, TropicalWeight> {
574 TropicalWeight operator()(Log64Weight w) const { return w.Value(); }
575 };
576
577 // Convert to log
578 template <>
579 struct WeightConvert<TropicalWeight, LogWeight> {
580 LogWeight operator()(TropicalWeight w) const { return w.Value(); }
581 };
582
583 template <>
584 struct WeightConvert<Log64Weight, LogWeight> {
585 LogWeight operator()(Log64Weight w) const { return w.Value(); }
586 };
587
588 // Convert to log64
589 template <>
590 struct WeightConvert<TropicalWeight, Log64Weight> {
591 Log64Weight operator()(TropicalWeight w) const { return w.Value(); }
592 };
593
594 template <>
595 struct WeightConvert<LogWeight, Log64Weight> {
596 Log64Weight operator()(LogWeight w) const { return w.Value(); }
597 };
598
599 } // namespace fst
600
601 #endif // FST_LIB_FLOAT_WEIGHT_H__
602