1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5
6 http://www.apache.org/licenses/LICENSE-2.0
7
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14
#include "tensorflow/core/platform/numbers.h"

#include <ctype.h>
#include <errno.h>
#include <float.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <locale>
#include <unordered_map>

#include "double-conversion/double-conversion.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/platform/types.h"
34
35 namespace tensorflow {
36
37 namespace {
38
39 template <typename T>
GetSpecialNumsSingleton()40 const std::unordered_map<string, T>* GetSpecialNumsSingleton() {
41 static const std::unordered_map<string, T>* special_nums =
42 CHECK_NOTNULL((new const std::unordered_map<string, T>{
43 {"inf", std::numeric_limits<T>::infinity()},
44 {"+inf", std::numeric_limits<T>::infinity()},
45 {"-inf", -std::numeric_limits<T>::infinity()},
46 {"infinity", std::numeric_limits<T>::infinity()},
47 {"+infinity", std::numeric_limits<T>::infinity()},
48 {"-infinity", -std::numeric_limits<T>::infinity()},
49 {"nan", std::numeric_limits<T>::quiet_NaN()},
50 {"+nan", std::numeric_limits<T>::quiet_NaN()},
51 {"-nan", -std::numeric_limits<T>::quiet_NaN()},
52 }));
53 return special_nums;
54 }
55
56 template <typename T>
locale_independent_strtonum(const char * str,const char ** endptr)57 T locale_independent_strtonum(const char* str, const char** endptr) {
58 auto special_nums = GetSpecialNumsSingleton<T>();
59 std::stringstream s(str);
60
61 // Check if str is one of the special numbers.
62 string special_num_str;
63 s >> special_num_str;
64
65 for (size_t i = 0; i < special_num_str.length(); ++i) {
66 special_num_str[i] =
67 std::tolower(special_num_str[i], std::locale::classic());
68 }
69
70 auto entry = special_nums->find(special_num_str);
71 if (entry != special_nums->end()) {
72 *endptr = str + (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str))
73 : s.tellg());
74 return entry->second;
75 } else {
76 // Perhaps it's a hex number
77 if (special_num_str.compare(0, 2, "0x") == 0 ||
78 special_num_str.compare(0, 3, "-0x") == 0) {
79 return strtol(str, const_cast<char**>(endptr), 16);
80 }
81 }
82 // Reset the stream
83 s.str(str);
84 s.clear();
85 // Use the "C" locale
86 s.imbue(std::locale::classic());
87
88 T result;
89 s >> result;
90
91 // Set to result to what strto{f,d} functions would have returned. If the
92 // number was outside the range, the stringstream sets the fail flag, but
93 // returns the +/-max() value, whereas strto{f,d} functions return +/-INF.
94 if (s.fail()) {
95 if (result == std::numeric_limits<T>::max() ||
96 result == std::numeric_limits<T>::infinity()) {
97 result = std::numeric_limits<T>::infinity();
98 s.clear(s.rdstate() & ~std::ios::failbit);
99 } else if (result == -std::numeric_limits<T>::max() ||
100 result == -std::numeric_limits<T>::infinity()) {
101 result = -std::numeric_limits<T>::infinity();
102 s.clear(s.rdstate() & ~std::ios::failbit);
103 }
104 }
105
106 if (endptr) {
107 *endptr =
108 str +
109 (s.fail() ? static_cast<std::iostream::pos_type>(0)
110 : (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str))
111 : s.tellg()));
112 }
113 return result;
114 }
115
116 static inline const double_conversion::StringToDoubleConverter&
StringToFloatConverter()117 StringToFloatConverter() {
118 static const double_conversion::StringToDoubleConverter converter(
119 double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES |
120 double_conversion::StringToDoubleConverter::ALLOW_HEX |
121 double_conversion::StringToDoubleConverter::ALLOW_TRAILING_SPACES |
122 double_conversion::StringToDoubleConverter::ALLOW_CASE_INSENSIBILITY,
123 0., 0., "inf", "nan");
124 return converter;
125 }
126
127 } // namespace
128
129 namespace strings {
130
FastInt32ToBufferLeft(int32 i,char * buffer)131 size_t FastInt32ToBufferLeft(int32 i, char* buffer) {
132 uint32 u = i;
133 size_t length = 0;
134 if (i < 0) {
135 *buffer++ = '-';
136 ++length;
137 // We need to do the negation in modular (i.e., "unsigned")
138 // arithmetic; MSVC++ apparently warns for plain "-u", so
139 // we write the equivalent expression "0 - u" instead.
140 u = 0 - u;
141 }
142 length += FastUInt32ToBufferLeft(u, buffer);
143 return length;
144 }
145
FastUInt32ToBufferLeft(uint32 i,char * buffer)146 size_t FastUInt32ToBufferLeft(uint32 i, char* buffer) {
147 char* start = buffer;
148 do {
149 *buffer++ = ((i % 10) + '0');
150 i /= 10;
151 } while (i > 0);
152 *buffer = 0;
153 std::reverse(start, buffer);
154 return buffer - start;
155 }
156
FastInt64ToBufferLeft(int64 i,char * buffer)157 size_t FastInt64ToBufferLeft(int64 i, char* buffer) {
158 uint64 u = i;
159 size_t length = 0;
160 if (i < 0) {
161 *buffer++ = '-';
162 ++length;
163 u = 0 - u;
164 }
165 length += FastUInt64ToBufferLeft(u, buffer);
166 return length;
167 }
168
FastUInt64ToBufferLeft(uint64 i,char * buffer)169 size_t FastUInt64ToBufferLeft(uint64 i, char* buffer) {
170 char* start = buffer;
171 do {
172 *buffer++ = ((i % 10) + '0');
173 i /= 10;
174 } while (i > 0);
175 *buffer = 0;
176 std::reverse(start, buffer);
177 return buffer - start;
178 }
179
180 static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001;
181
DoubleToBuffer(double value,char * buffer)182 size_t DoubleToBuffer(double value, char* buffer) {
183 // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
184 // platforms these days. Just in case some system exists where DBL_DIG
185 // is significantly larger -- and risks overflowing our buffer -- we have
186 // this assert.
187 static_assert(DBL_DIG < 20, "DBL_DIG is too big");
188
189 if (std::isnan(value)) {
190 int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan",
191 std::signbit(value) ? "-" : "");
192 // Paranoid check to ensure we don't overflow the buffer.
193 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
194 return snprintf_result;
195 }
196
197 if (std::abs(value) <= kDoublePrecisionCheckMax) {
198 int snprintf_result =
199 snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG, value);
200
201 // The snprintf should never overflow because the buffer is significantly
202 // larger than the precision we asked for.
203 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
204
205 if (locale_independent_strtonum<double>(buffer, nullptr) == value) {
206 // Round-tripping the string to double works; we're done.
207 return snprintf_result;
208 }
209 // else: full precision formatting needed. Fall through.
210 }
211
212 int snprintf_result =
213 snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG + 2, value);
214
215 // Should never overflow; see above.
216 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
217
218 return snprintf_result;
219 }
220
221 namespace {
SafeFirstChar(StringPiece str)222 char SafeFirstChar(StringPiece str) {
223 if (str.empty()) return '\0';
224 return str[0];
225 }
SkipSpaces(StringPiece * str)226 void SkipSpaces(StringPiece* str) {
227 while (isspace(SafeFirstChar(*str))) str->remove_prefix(1);
228 }
229 } // namespace
230
safe_strto64(StringPiece str,int64 * value)231 bool safe_strto64(StringPiece str, int64* value) {
232 SkipSpaces(&str);
233
234 int64 vlimit = kint64max;
235 int sign = 1;
236 if (absl::ConsumePrefix(&str, "-")) {
237 sign = -1;
238 // Different limit for positive and negative integers.
239 vlimit = kint64min;
240 }
241
242 if (!isdigit(SafeFirstChar(str))) return false;
243
244 int64 result = 0;
245 if (sign == 1) {
246 do {
247 int digit = SafeFirstChar(str) - '0';
248 if ((vlimit - digit) / 10 < result) {
249 return false;
250 }
251 result = result * 10 + digit;
252 str.remove_prefix(1);
253 } while (isdigit(SafeFirstChar(str)));
254 } else {
255 do {
256 int digit = SafeFirstChar(str) - '0';
257 if ((vlimit + digit) / 10 > result) {
258 return false;
259 }
260 result = result * 10 - digit;
261 str.remove_prefix(1);
262 } while (isdigit(SafeFirstChar(str)));
263 }
264
265 SkipSpaces(&str);
266 if (!str.empty()) return false;
267
268 *value = result;
269 return true;
270 }
271
safe_strtou64(StringPiece str,uint64 * value)272 bool safe_strtou64(StringPiece str, uint64* value) {
273 SkipSpaces(&str);
274 if (!isdigit(SafeFirstChar(str))) return false;
275
276 uint64 result = 0;
277 do {
278 int digit = SafeFirstChar(str) - '0';
279 if ((kuint64max - digit) / 10 < result) {
280 return false;
281 }
282 result = result * 10 + digit;
283 str.remove_prefix(1);
284 } while (isdigit(SafeFirstChar(str)));
285
286 SkipSpaces(&str);
287 if (!str.empty()) return false;
288
289 *value = result;
290 return true;
291 }
292
safe_strto32(StringPiece str,int32 * value)293 bool safe_strto32(StringPiece str, int32* value) {
294 SkipSpaces(&str);
295
296 int64 vmax = kint32max;
297 int sign = 1;
298 if (absl::ConsumePrefix(&str, "-")) {
299 sign = -1;
300 // Different max for positive and negative integers.
301 ++vmax;
302 }
303
304 if (!isdigit(SafeFirstChar(str))) return false;
305
306 int64 result = 0;
307 do {
308 result = result * 10 + SafeFirstChar(str) - '0';
309 if (result > vmax) {
310 return false;
311 }
312 str.remove_prefix(1);
313 } while (isdigit(SafeFirstChar(str)));
314
315 SkipSpaces(&str);
316
317 if (!str.empty()) return false;
318
319 *value = static_cast<int32>(result * sign);
320 return true;
321 }
322
safe_strtou32(StringPiece str,uint32 * value)323 bool safe_strtou32(StringPiece str, uint32* value) {
324 SkipSpaces(&str);
325 if (!isdigit(SafeFirstChar(str))) return false;
326
327 int64 result = 0;
328 do {
329 result = result * 10 + SafeFirstChar(str) - '0';
330 if (result > kuint32max) {
331 return false;
332 }
333 str.remove_prefix(1);
334 } while (isdigit(SafeFirstChar(str)));
335
336 SkipSpaces(&str);
337 if (!str.empty()) return false;
338
339 *value = static_cast<uint32>(result);
340 return true;
341 }
342
safe_strtof(StringPiece str,float * value)343 bool safe_strtof(StringPiece str, float* value) {
344 int processed_characters_count = -1;
345 auto len = str.size();
346
347 // If string length exceeds buffer size or int max, fail.
348 if (len >= kFastToBufferSize) return false;
349 if (len > std::numeric_limits<int>::max()) return false;
350
351 *value = StringToFloatConverter().StringToFloat(
352 str.data(), static_cast<int>(len), &processed_characters_count);
353 return processed_characters_count > 0;
354 }
355
safe_strtod(StringPiece str,double * value)356 bool safe_strtod(StringPiece str, double* value) {
357 int processed_characters_count = -1;
358 auto len = str.size();
359
360 // If string length exceeds buffer size or int max, fail.
361 if (len >= kFastToBufferSize) return false;
362 if (len > std::numeric_limits<int>::max()) return false;
363
364 *value = StringToFloatConverter().StringToDouble(
365 str.data(), static_cast<int>(len), &processed_characters_count);
366 return processed_characters_count > 0;
367 }
368
FloatToBuffer(float value,char * buffer)369 size_t FloatToBuffer(float value, char* buffer) {
370 // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
371 // platforms these days. Just in case some system exists where FLT_DIG
372 // is significantly larger -- and risks overflowing our buffer -- we have
373 // this assert.
374 static_assert(FLT_DIG < 10, "FLT_DIG is too big");
375
376 if (std::isnan(value)) {
377 int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan",
378 std::signbit(value) ? "-" : "");
379 // Paranoid check to ensure we don't overflow the buffer.
380 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
381 return snprintf_result;
382 }
383
384 int snprintf_result =
385 snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG, value);
386
387 // The snprintf should never overflow because the buffer is significantly
388 // larger than the precision we asked for.
389 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
390
391 float parsed_value;
392 if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) {
393 snprintf_result =
394 snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 3, value);
395
396 // Should never overflow; see above.
397 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
398 }
399 return snprintf_result;
400 }
401
FpToString(Fprint fp)402 string FpToString(Fprint fp) {
403 char buf[17];
404 snprintf(buf, sizeof(buf), "%016llx", static_cast<long long>(fp));
405 return string(buf);
406 }
407
StringToFp(const string & s,Fprint * fp)408 bool StringToFp(const string& s, Fprint* fp) {
409 char junk;
410 uint64_t result;
411 if (sscanf(s.c_str(), "%" SCNx64 "%c", &result, &junk) == 1) {
412 *fp = result;
413 return true;
414 } else {
415 return false;
416 }
417 }
418
Uint64ToHexString(uint64 v,char * buf)419 StringPiece Uint64ToHexString(uint64 v, char* buf) {
420 static const char* hexdigits = "0123456789abcdef";
421 const int num_byte = 16;
422 buf[num_byte] = '\0';
423 for (int i = num_byte - 1; i >= 0; i--) {
424 buf[i] = hexdigits[v & 0xf];
425 v >>= 4;
426 }
427 return StringPiece(buf, num_byte);
428 }
429
HexStringToUint64(const StringPiece & s,uint64 * result)430 bool HexStringToUint64(const StringPiece& s, uint64* result) {
431 uint64 v = 0;
432 if (s.empty()) {
433 return false;
434 }
435 for (size_t i = 0; i < s.size(); i++) {
436 char c = s[i];
437 if (c >= '0' && c <= '9') {
438 v = (v << 4) + (c - '0');
439 } else if (c >= 'a' && c <= 'f') {
440 v = (v << 4) + 10 + (c - 'a');
441 } else if (c >= 'A' && c <= 'F') {
442 v = (v << 4) + 10 + (c - 'A');
443 } else {
444 return false;
445 }
446 }
447 *result = v;
448 return true;
449 }
450
HumanReadableNum(int64 value)451 string HumanReadableNum(int64 value) {
452 string s;
453 if (value < 0) {
454 s += "-";
455 value = -value;
456 }
457 if (value < 1000) {
458 Appendf(&s, "%lld", static_cast<long long>(value));
459 } else if (value >= static_cast<int64>(1e15)) {
460 // Number bigger than 1E15; use that notation.
461 Appendf(&s, "%0.3G", static_cast<double>(value));
462 } else {
463 static const char units[] = "kMBT";
464 const char* unit = units;
465 while (value >= static_cast<int64>(1000000)) {
466 value /= static_cast<int64>(1000);
467 ++unit;
468 CHECK(unit < units + TF_ARRAYSIZE(units));
469 }
470 Appendf(&s, "%.2f%c", value / 1000.0, *unit);
471 }
472 return s;
473 }
474
HumanReadableNumBytes(int64 num_bytes)475 string HumanReadableNumBytes(int64 num_bytes) {
476 if (num_bytes == kint64min) {
477 // Special case for number with not representable negation.
478 return "-8E";
479 }
480
481 const char* neg_str = (num_bytes < 0) ? "-" : "";
482 if (num_bytes < 0) {
483 num_bytes = -num_bytes;
484 }
485
486 // Special case for bytes.
487 if (num_bytes < 1024) {
488 // No fractions for bytes.
489 char buf[8]; // Longest possible string is '-XXXXB'
490 snprintf(buf, sizeof(buf), "%s%lldB", neg_str,
491 static_cast<long long>(num_bytes));
492 return string(buf);
493 }
494
495 static const char units[] = "KMGTPE"; // int64 only goes up to E.
496 const char* unit = units;
497 while (num_bytes >= static_cast<int64>(1024) * 1024) {
498 num_bytes /= 1024;
499 ++unit;
500 CHECK(unit < units + TF_ARRAYSIZE(units));
501 }
502
503 // We use SI prefixes.
504 char buf[16];
505 snprintf(buf, sizeof(buf), ((*unit == 'K') ? "%s%.1f%ciB" : "%s%.2f%ciB"),
506 neg_str, num_bytes / 1024.0, *unit);
507 return string(buf);
508 }
509
HumanReadableElapsedTime(double seconds)510 string HumanReadableElapsedTime(double seconds) {
511 string human_readable;
512
513 if (seconds < 0) {
514 human_readable = "-";
515 seconds = -seconds;
516 }
517
518 // Start with us and keep going up to years.
519 // The comparisons must account for rounding to prevent the format breaking
520 // the tested condition and returning, e.g., "1e+03 us" instead of "1 ms".
521 const double microseconds = seconds * 1.0e6;
522 if (microseconds < 999.5) {
523 strings::Appendf(&human_readable, "%0.3g us", microseconds);
524 return human_readable;
525 }
526 double milliseconds = seconds * 1e3;
527 if (milliseconds >= .995 && milliseconds < 1) {
528 // Round half to even in Appendf would convert this to 0.999 ms.
529 milliseconds = 1.0;
530 }
531 if (milliseconds < 999.5) {
532 strings::Appendf(&human_readable, "%0.3g ms", milliseconds);
533 return human_readable;
534 }
535 if (seconds < 60.0) {
536 strings::Appendf(&human_readable, "%0.3g s", seconds);
537 return human_readable;
538 }
539 seconds /= 60.0;
540 if (seconds < 60.0) {
541 strings::Appendf(&human_readable, "%0.3g min", seconds);
542 return human_readable;
543 }
544 seconds /= 60.0;
545 if (seconds < 24.0) {
546 strings::Appendf(&human_readable, "%0.3g h", seconds);
547 return human_readable;
548 }
549 seconds /= 24.0;
550 if (seconds < 30.0) {
551 strings::Appendf(&human_readable, "%0.3g days", seconds);
552 return human_readable;
553 }
554 if (seconds < 365.2425) {
555 strings::Appendf(&human_readable, "%0.3g months", seconds / 30.436875);
556 return human_readable;
557 }
558 seconds /= 365.2425;
559 strings::Appendf(&human_readable, "%0.3g years", seconds);
560 return human_readable;
561 }
562
563 } // namespace strings
564 } // namespace tensorflow
565