1 // Copyright (c) 2015-2016 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <iomanip>
#include <limits>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include <gmock/gmock.h>
#include "SPIRV/hex_float.h"
24 
25 namespace {
26 using ::testing::Eq;
27 using spvutils::BitwiseCast;
28 using spvutils::Float16;
29 using spvutils::FloatProxy;
30 using spvutils::HexFloat;
31 using spvutils::ParseNormalFloat;
32 
// In this file "encode" means converting a number into a string,
// and "decode" means converting a string into a number.

// Value-parameterized fixtures.
// HexFloatTest / HexDoubleTest pair a value with its expected hex-float text;
// the same pairs drive both the EncodeCorrectly and DecodeCorrectly tests.
// DecodeHexFloatTest / DecodeHexDoubleTest list the string first and are used
// only for decoding, e.g. for inputs that do not encode back to the same text.
using HexFloatTest =
    ::testing::TestWithParam<std::pair<FloatProxy<float>, std::string>>;
using DecodeHexFloatTest =
    ::testing::TestWithParam<std::pair<std::string, FloatProxy<float>>>;
using HexDoubleTest =
    ::testing::TestWithParam<std::pair<FloatProxy<double>, std::string>>;
using DecodeHexDoubleTest =
    ::testing::TestWithParam<std::pair<std::string, FloatProxy<double>>>;
44 
45 // Hex-encodes a float value.
46 template <typename T>
EncodeViaHexFloat(const T & value)47 std::string EncodeViaHexFloat(const T& value) {
48   std::stringstream ss;
49   ss << spvutils::HexFloat<T>(value);
50   return ss.str();
51 }
52 
53 // The following two tests can't be DRY because they take different parameter
54 // types.
55 
TEST_P(HexFloatTest,EncodeCorrectly)56 TEST_P(HexFloatTest, EncodeCorrectly) {
57   EXPECT_THAT(EncodeViaHexFloat(GetParam().first), Eq(GetParam().second));
58 }
59 
TEST_P(HexDoubleTest,EncodeCorrectly)60 TEST_P(HexDoubleTest, EncodeCorrectly) {
61   EXPECT_THAT(EncodeViaHexFloat(GetParam().first), Eq(GetParam().second));
62 }
63 
64 // Decodes a hex-float string.
65 template <typename T>
Decode(const std::string & str)66 FloatProxy<T> Decode(const std::string& str) {
67   spvutils::HexFloat<FloatProxy<T>> decoded(0.f);
68   EXPECT_TRUE((std::stringstream(str) >> decoded).eof());
69   return decoded.value();
70 }
71 
// The expected text (second) must decode back to the value (first).
TEST_P(HexFloatTest, DecodeCorrectly) {
  const auto& expectation = GetParam();
  EXPECT_THAT(Decode<float>(expectation.second), Eq(expectation.first));
}

// Same as above, for 64-bit doubles.
TEST_P(HexDoubleTest, DecodeCorrectly) {
  const auto& expectation = GetParam();
  EXPECT_THAT(Decode<double>(expectation.second), Eq(expectation.first));
}
79 
// 32-bit float values paired with their canonical hex-float text. Each pair is
// checked in both directions (HexFloatTest.EncodeCorrectly / DecodeCorrectly).
INSTANTIATE_TEST_CASE_P(
    Float32Tests, HexFloatTest,
    ::testing::ValuesIn(std::vector<std::pair<FloatProxy<float>, std::string>>({
        {0.f, "0x0p+0"},
        {1.f, "0x1p+0"},
        {2.f, "0x1p+1"},
        {3.f, "0x1.8p+1"},
        {0.5f, "0x1p-1"},
        {0.25f, "0x1p-2"},
        {0.75f, "0x1.8p-1"},
        {-0.f, "-0x0p+0"},
        {-1.f, "-0x1p+0"},
        {-0.5f, "-0x1p-1"},
        {-0.25f, "-0x1p-2"},
        {-0.75f, "-0x1.8p-1"},

        // Larger numbers
        {512.f, "0x1p+9"},
        {-512.f, "-0x1p+9"},
        {1024.f, "0x1p+10"},
        {-1024.f, "-0x1p+10"},
        {1024.f + 8.f, "0x1.02p+10"},
        {-1024.f - 8.f, "-0x1.02p+10"},

        // Small numbers
        {1.0f / 512.f, "0x1p-9"},
        {1.0f / -512.f, "-0x1p-9"},
        {1.0f / 1024.f, "0x1p-10"},
        {1.0f / -1024.f, "-0x1p-10"},
        {1.0f / 1024.f + 1.0f / 8.f, "0x1.02p-3"},
        {1.0f / -1024.f - 1.0f / 8.f, "-0x1.02p-3"},

        // Smallest-magnitude normal values (+/- 2^-126).
        {float(ldexp(1.0f, -126)), "0x1p-126"},
        {float(ldexp(-1.0f, -126)), "-0x1p-126"},

        // Denormalized values. They are still written in normalized
        // 0x1.xxxp-N form, with exponents below -126.
        {float(ldexp(1.0f, -127)), "0x1p-127"},
        {float(ldexp(1.0f, -127) / 2.0f), "0x1p-128"},
        {float(ldexp(1.0f, -127) / 4.0f), "0x1p-129"},
        {float(ldexp(1.0f, -127) / 8.0f), "0x1p-130"},
        {float(ldexp(-1.0f, -127)), "-0x1p-127"},
        {float(ldexp(-1.0f, -127) / 2.0f), "-0x1p-128"},
        {float(ldexp(-1.0f, -127) / 4.0f), "-0x1p-129"},
        {float(ldexp(-1.0f, -127) / 8.0f), "-0x1p-130"},

        // Denormals with a non-zero fraction after normalization.
        {float(ldexp(1.0, -127) + (ldexp(1.0, -127) / 2.0f)), "0x1.8p-127"},
        {float(ldexp(1.0, -127) / 2.0 + (ldexp(1.0, -127) / 4.0f)),
         "0x1.8p-128"},

    })),);
131 
// Infinity and NaN bit patterns for 32-bit floats. All of them have an all-ones
// exponent field, which is written as p+128. A zero mantissa is infinity and a
// non-zero mantissa is NaN; the mantissa bits appear verbatim in the fraction.
INSTANTIATE_TEST_CASE_P(
    Float32NanTests, HexFloatTest,
    ::testing::ValuesIn(std::vector<std::pair<FloatProxy<float>, std::string>>({
        // Various NAN and INF cases
        {uint32_t(0xFF800000), "-0x1p+128"},         // -inf
        {uint32_t(0x7F800000), "0x1p+128"},          // inf
        {uint32_t(0xFFC00000), "-0x1.8p+128"},       // -nan
        {uint32_t(0xFF800100), "-0x1.0002p+128"},    // -nan
        {uint32_t(0xFF800c00), "-0x1.0018p+128"},    // -nan
        {uint32_t(0xFF80F000), "-0x1.01ep+128"},     // -nan
        {uint32_t(0xFFFFFFFF), "-0x1.fffffep+128"},  // -nan
        {uint32_t(0x7FC00000), "0x1.8p+128"},        // +nan
        {uint32_t(0x7F800100), "0x1.0002p+128"},     // +nan
        {uint32_t(0x7f800c00), "0x1.0018p+128"},     // +nan
        {uint32_t(0x7F80F000), "0x1.01ep+128"},      // +nan
        {uint32_t(0x7FFFFFFF), "0x1.fffffep+128"},   // +nan
    })),);
149 
// 64-bit double values paired with their canonical hex-float text. Each pair is
// checked in both directions (HexDoubleTest.EncodeCorrectly / DecodeCorrectly).
INSTANTIATE_TEST_CASE_P(
    Float64Tests, HexDoubleTest,
    ::testing::ValuesIn(
        std::vector<std::pair<FloatProxy<double>, std::string>>({
            {0., "0x0p+0"},
            {1., "0x1p+0"},
            {2., "0x1p+1"},
            {3., "0x1.8p+1"},
            {0.5, "0x1p-1"},
            {0.25, "0x1p-2"},
            {0.75, "0x1.8p-1"},
            {-0., "-0x0p+0"},
            {-1., "-0x1p+0"},
            {-0.5, "-0x1p-1"},
            {-0.25, "-0x1p-2"},
            {-0.75, "-0x1.8p-1"},

            // Larger numbers
            {512., "0x1p+9"},
            {-512., "-0x1p+9"},
            {1024., "0x1p+10"},
            {-1024., "-0x1p+10"},
            {1024. + 8., "0x1.02p+10"},
            {-1024. - 8., "-0x1.02p+10"},

            // Magnitudes beyond the range of 32-bit floats but still normal
            // doubles.
            {ldexp(1.0, 128), "0x1p+128"},
            {ldexp(1.0, 129), "0x1p+129"},
            {ldexp(-1.0, 128), "-0x1p+128"},
            {ldexp(-1.0, 129), "-0x1p+129"},
            {ldexp(1.0, 128) + ldexp(1.0, 90), "0x1.0000000004p+128"},
            {ldexp(1.0, 129) + ldexp(1.0, 120), "0x1.008p+129"},
            {ldexp(-1.0, 128) + ldexp(1.0, 90), "-0x1.fffffffff8p+127"},
            {ldexp(-1.0, 129) + ldexp(1.0, 120), "-0x1.ffp+128"},

            // Small numbers
            {1.0 / 512., "0x1p-9"},
            {1.0 / -512., "-0x1p-9"},
            {1.0 / 1024., "0x1p-10"},
            {1.0 / -1024., "-0x1p-10"},
            {1.0 / 1024. + 1.0 / 8., "0x1.02p-3"},
            {1.0 / -1024. - 1.0 / 8., "-0x1.02p-3"},

            // Magnitudes below the normal range of 32-bit floats but still
            // normal doubles.
            {ldexp(1.0, -128), "0x1p-128"},
            {ldexp(1.0, -129), "0x1p-129"},
            {ldexp(-1.0, -128), "-0x1p-128"},
            {ldexp(-1.0, -129), "-0x1p-129"},
            {ldexp(1.0, -128) + ldexp(1.0, -90), "0x1.0000000004p-90"},
            {ldexp(1.0, -129) + ldexp(1.0, -120), "0x1.008p-120"},
            {ldexp(-1.0, -128) + ldexp(1.0, -90), "0x1.fffffffff8p-91"},
            {ldexp(-1.0, -129) + ldexp(1.0, -120), "0x1.ffp-121"},

            // Smallest-magnitude normal doubles (+/- 2^-1022).
            {ldexp(1.0, -1022), "0x1p-1022"},
            {ldexp(-1.0, -1022), "-0x1p-1022"},

            // Denormalized values, written in normalized form.
            // Note: the negative cases start at 2^-1024 rather than
            // 2^-1023, so they are not exact mirrors of the positive ones.
            {ldexp(1.0, -1023), "0x1p-1023"},
            {ldexp(1.0, -1023) / 2.0, "0x1p-1024"},
            {ldexp(1.0, -1023) / 4.0, "0x1p-1025"},
            {ldexp(1.0, -1023) / 8.0, "0x1p-1026"},
            {ldexp(-1.0, -1024), "-0x1p-1024"},
            {ldexp(-1.0, -1024) / 2.0, "-0x1p-1025"},
            {ldexp(-1.0, -1024) / 4.0, "-0x1p-1026"},
            {ldexp(-1.0, -1024) / 8.0, "-0x1p-1027"},

            // Denormals with a non-zero fraction after normalization.
            {ldexp(1.0, -1023) + (ldexp(1.0, -1023) / 2.0), "0x1.8p-1023"},
            {ldexp(1.0, -1023) / 2.0 + (ldexp(1.0, -1023) / 4.0),
             "0x1.8p-1024"},

        })),);
222 
// Infinity and NaN bit patterns for 64-bit doubles. All of them have an
// all-ones exponent field, written as p+1024. A zero mantissa is infinity and a
// non-zero mantissa is NaN; the sign bit (0x8000000000000000) selects -/+.
INSTANTIATE_TEST_CASE_P(
    Float64NanTests, HexDoubleTest,
    ::testing::ValuesIn(std::vector<
                        std::pair<FloatProxy<double>, std::string>>({
        // Various NAN and INF cases
        {uint64_t(0xFFF0000000000000LL), "-0x1p+1024"},                // -inf
        {uint64_t(0x7FF0000000000000LL), "0x1p+1024"},                 // +inf
        {uint64_t(0xFFF8000000000000LL), "-0x1.8p+1024"},              // -nan
        {uint64_t(0xFFF0F00000000000LL), "-0x1.0fp+1024"},             // -nan
        {uint64_t(0xFFF0000000000001LL), "-0x1.0000000000001p+1024"},  // -nan
        {uint64_t(0xFFF0000300000000LL), "-0x1.00003p+1024"},          // -nan
        {uint64_t(0xFFFFFFFFFFFFFFFFLL), "-0x1.fffffffffffffp+1024"},  // -nan
        {uint64_t(0x7FF8000000000000LL), "0x1.8p+1024"},               // +nan
        {uint64_t(0x7FF0F00000000000LL), "0x1.0fp+1024"},              // +nan
        {uint64_t(0x7FF0000000000001LL), "0x1.0000000000001p+1024"},   // +nan
        {uint64_t(0x7FF0000300000000LL), "0x1.00003p+1024"},           // +nan
        {uint64_t(0x7FFFFFFFFFFFFFFFLL), "0x1.fffffffffffffp+1024"},   // +nan
    })),);
241 
// Streams an integer, a FloatProxy<float>, and another integer through one
// stream configured with octal base and 'x' fill. The expected text shows:
//  - the FloatProxy (a NaN pattern) is printed as a hex float, unaffected by
//    the octal/fill settings;
//  - the octal base and fill are still in effect for the trailing setw(4)
//    integer (9 -> "xx11").
TEST(HexFloatStreamTest, OperatorLeftShiftPreservesFloatAndFill) {
  std::stringstream s;
  s << std::setw(4) << std::oct << std::setfill('x') << 8 << " "
    << FloatProxy<float>(uint32_t(0xFF800100)) << " " << std::setw(4) << 9;
  EXPECT_THAT(s.str(), Eq(std::string("xx10 -0x1.0002p+128 xx11")));
}

// Same as above for FloatProxy<double>.
TEST(HexDoubleStreamTest, OperatorLeftShiftPreservesFloatAndFill) {
  std::stringstream s;
  s << std::setw(4) << std::oct << std::setfill('x') << 8 << " "
    << FloatProxy<double>(uint64_t(0x7FF0F00000000000LL)) << " " << std::setw(4)
    << 9;
  EXPECT_THAT(s.str(), Eq(std::string("xx10 0x1.0fp+1024 xx11")));
}
256 
// Decode-only checks: the input string (first) must decode to the expected
// value (second).
TEST_P(DecodeHexFloatTest, DecodeCorrectly) {
  const auto& expectation = GetParam();
  EXPECT_THAT(Decode<float>(expectation.first), Eq(expectation.second));
}

// Same as above, for 64-bit doubles.
TEST_P(DecodeHexDoubleTest, DecodeCorrectly) {
  const auto& expectation = GetParam();
  EXPECT_THAT(Decode<double>(expectation.first), Eq(expectation.second));
}
264 
// Strings that decode to a 32-bit float but are not the canonical encoding of
// that value: alternative zero spellings, inputs too small to represent, and
// non-normalized mantissas.
INSTANTIATE_TEST_CASE_P(
    Float32DecodeTests, DecodeHexFloatTest,
    ::testing::ValuesIn(std::vector<std::pair<std::string, FloatProxy<float>>>({
        {"0x0p+000", 0.f},
        {"0x0p0", 0.f},
        {"0x0p-0", 0.f},

        // Flush-to-zero cases. The sign of the input is kept on the zero.
        {"0x1p-500", 0.f},  // Exponent underflows.
        {"-0x1p-500", -0.f},
        {"0x0.00000000001p-126", 0.f},  // Fraction causes underflow.
        {"-0x0.0000000001p-127", -0.f},
        {"-0x0.01p-142", -0.f},  // Fraction causes additional underflow.
        {"0x0.01p-142", 0.f},

        // Some floats that do not encode the same way as they decode.
        {"0x2p+0", 2.f},
        {"0xFFp+0", 255.f},
        {"0x0.8p+0", 0.5f},
        {"0x0.4p+0", 0.25f},
    })),);
286 
// Strings whose magnitude exceeds the 32-bit float range decode to the
// correspondingly signed infinity bit pattern.
INSTANTIATE_TEST_CASE_P(
    Float32DecodeInfTests, DecodeHexFloatTest,
    ::testing::ValuesIn(std::vector<std::pair<std::string, FloatProxy<float>>>({
        // inf cases
        {"-0x1p+128", uint32_t(0xFF800000)},   // -inf
        {"0x32p+127", uint32_t(0x7F800000)},   // inf
        {"0x32p+500", uint32_t(0x7F800000)},   // inf
        {"-0x32p+127", uint32_t(0xFF800000)},  // -inf
    })),);
296 
297 INSTANTIATE_TEST_CASE_P(
298     Float64DecodeTests, DecodeHexDoubleTest,
299     ::testing::ValuesIn(
300         std::vector<std::pair<std::string, FloatProxy<double>>>({
301             {"0x0p+000", 0.},
302             {"0x0p0", 0.},
303             {"0x0p-0", 0.},
304 
305             // flush to zero cases
306             {"0x1p-5000", 0.},  // Exponent underflows.
307             {"-0x1p-5000", -0.},
308             {"0x0.0000000000000001p-1023", 0.},  // Fraction causes underflow.
309             {"-0x0.000000000000001p-1024", -0.},
310             {"-0x0.01p-1090", -0.f},  // Fraction causes additional underflow.
311             {"0x0.01p-1090", 0.},
312 
313             // Some floats that do not encode the same way as they decode.
314             {"0x2p+0", 2.},
315             {"0xFFp+0", 255.},
316             {"0x0.8p+0", 0.5},
317             {"0x0.4p+0", 0.25},
318         })),);
319 
// Strings whose magnitude exceeds the 64-bit double range decode to the
// correspondingly signed infinity bit pattern.
INSTANTIATE_TEST_CASE_P(
    Float64DecodeInfTests, DecodeHexDoubleTest,
    ::testing::ValuesIn(
        std::vector<std::pair<std::string, FloatProxy<double>>>({
            // inf cases
            {"-0x1p+1024", uint64_t(0xFFF0000000000000)},   // -inf
            {"0x32p+1023", uint64_t(0x7FF0000000000000)},   // inf
            {"0x32p+5000", uint64_t(0x7FF0000000000000)},   // inf
            {"-0x32p+1023", uint64_t(0xFFF0000000000000)},  // -inf
        })),);
330 
// Checks three things about FloatProxy<float>:
//  1. ordinary values come back unchanged from getAsFloat();
//  2. infinity and NaN bit patterns are classified correctly by getAsFloat();
//  3. data() returns those bit patterns exactly as they were given.
TEST(FloatProxy, ValidConversion) {
  for (const float ordinary : {1.f, 32.f, -1.f, 0.f, -0.f, 1.2e32f}) {
    EXPECT_THAT(FloatProxy<float>(ordinary).getAsFloat(), Eq(ordinary));
  }

  const uint32_t kInfinityPatterns[] = {0xFF800000u, 0x7F800000u};
  const uint32_t kNanPatterns[] = {0xFFC00000u, 0xFF800100u, 0xFF800c00u,
                                   0xFF80F000u, 0xFFFFFFFFu, 0x7FC00000u,
                                   0x7F800100u, 0x7f800c00u, 0x7F80F000u,
                                   0x7FFFFFFFu};

  for (const uint32_t pattern : kInfinityPatterns) {
    EXPECT_TRUE(std::isinf(FloatProxy<float>(pattern).getAsFloat()));
  }
  for (const uint32_t pattern : kNanPatterns) {
    EXPECT_TRUE(std::isnan(FloatProxy<float>(pattern).getAsFloat()));
  }

  // The raw bits must survive unchanged, including NaN payload bits.
  for (const uint32_t pattern : kInfinityPatterns) {
    EXPECT_THAT(FloatProxy<float>(pattern).data(), Eq(pattern));
  }
  for (const uint32_t pattern : kNanPatterns) {
    EXPECT_THAT(FloatProxy<float>(pattern).data(), Eq(pattern));
  }
}
365 
// Each pattern below has an all-ones exponent field and a non-zero mantissa,
// with the sign bit either set or clear, so isNan() must report true.
TEST(FloatProxy, Nan) {
  const uint32_t kNanPatterns[] = {0xFFC00000u, 0xFF800100u, 0xFF800c00u,
                                   0xFF80F000u, 0xFFFFFFFFu, 0x7FC00000u,
                                   0x7F800100u, 0x7f800c00u, 0x7F80F000u,
                                   0x7FFFFFFFu};
  for (const uint32_t pattern : kNanPatterns) {
    EXPECT_TRUE(FloatProxy<float>(pattern).isNan());
  }
}
378 
// Unary minus on a FloatProxy<float> must match negating the underlying float.
// Cases cover zero, small and large magnitudes, and both infinities, each
// with both signs.
TEST(FloatProxy, Negation) {
  const float kInf = std::numeric_limits<float>::infinity();
  const float kInputs[] = {1.f,     0.f,      -1.f, -0.f, 32.f,
                           -32.f,   1.2e32f, -1.2e32f, kInf, -kInf};
  for (const float input : kInputs) {
    EXPECT_THAT((-FloatProxy<float>(input)).getAsFloat(), Eq(-input));
  }
}
399 
// Test conversion of FloatProxy values to strings.
//
// In the earlier cases, the FloatProxy value was always wrapped in a HexFloat
// before conversion to a string. In the following cases, the FloatProxy is
// streamed directly and decides for itself whether to print a regular decimal
// number or a hex float.

// Parameters: a value and the text it is expected to stream as (after
// exponent normalization).
using FloatProxyFloatTest =
    ::testing::TestWithParam<std::pair<FloatProxy<float>, std::string>>;
using FloatProxyDoubleTest =
    ::testing::TestWithParam<std::pair<FloatProxy<double>, std::string>>;
410 
// Returns whatever streaming |value| into a default-configured output stream
// produces. For a FloatProxy, the proxy's own operator<< chooses the format.
template <typename T>
std::string EncodeViaFloatProxy(const T& value) {
  std::ostringstream rendered;
  rendered << value;
  return rendered.str();
}
418 
// Rewrites a decimal floating-point string so that the exponent marker is a
// lowercase 'e' and the exponent has no leading zeros. For example, the
// Microsoft runtime writes "2.5E+010", which becomes "2.5e+10".
// A string is rewritten only if it is a run of [-+.0-9], then 'e' or 'E',
// then an explicit '+' or '-', then a decimal integer. Anything else
// (hex floats, plain integers, empty strings) is returned unchanged.
std::string NormalizeExponentInFloatString(std::string in) {
  // Large enough for any match, and zero-filled so it is always terminated.
  std::vector<char> mantissa(in.size() + 1);
  char marker = 0;
  char sign = 0;
  int exponent = 0;  // in base 10
  const int converted =
      std::sscanf(in.c_str(), "%[-+.0123456789]%c%c%d", mantissa.data(),
                  &marker, &sign, &exponent);

  const bool has_marker = (marker == 'e' || marker == 'E');
  const bool has_sign = (sign == '-' || sign == '+');
  if (converted != 4 || !has_marker || !has_sign) {
    // Not a float with an explicitly signed exponent; leave it alone.
    return in;
  }

  std::ostringstream normalized;
  normalized << mantissa.data() << 'e' << sign << exponent;
  return normalized.str();
}
445 
// Spot-checks NormalizeExponentInFloatString on empty input, already-normal
// input, an uppercase marker, and leading zeros in the exponent.
TEST(NormalizeFloat, Sample) {
  const std::pair<const char*, const char*> kCases[] = {
      {"", ""},
      {"1e-12", "1e-12"},
      {"1E+14", "1e+14"},
      {"1e-0012", "1e-12"},
      {"1.263E+014", "1.263e+14"},
  };
  for (const auto& sample : kCases) {
    EXPECT_THAT(NormalizeExponentInFloatString(sample.first),
                Eq(sample.second));
  }
}
453 
454 // The following two tests can't be DRY because they take different parameter
455 // types.
TEST_P(FloatProxyFloatTest,EncodeCorrectly)456 TEST_P(FloatProxyFloatTest, EncodeCorrectly) {
457   EXPECT_THAT(
458       NormalizeExponentInFloatString(EncodeViaFloatProxy(GetParam().first)),
459       Eq(GetParam().second));
460 }
461 
TEST_P(FloatProxyDoubleTest,EncodeCorrectly)462 TEST_P(FloatProxyDoubleTest, EncodeCorrectly) {
463   EXPECT_THAT(
464       NormalizeExponentInFloatString(EncodeViaFloatProxy(GetParam().first)),
465       Eq(GetParam().second));
466 }
467 
// Expected direct-stream output of FloatProxy<float>. Normal values print in
// decimal; denormals, NaNs and infinities print as hex floats.
INSTANTIATE_TEST_CASE_P(
    Float32Tests, FloatProxyFloatTest,
    ::testing::ValuesIn(std::vector<std::pair<FloatProxy<float>, std::string>>({
        // Zero
        {0.f, "0"},
        // Normal numbers
        {1.f, "1"},
        {-0.25f, "-0.25"},
        {1000.0f, "1000"},

        // Still normal numbers, but with large magnitude exponents.
        {float(ldexp(1.f, 126)), "8.50706e+37"},
        {float(ldexp(-1.f, -126)), "-1.17549e-38"},

        // denormalized values are printed as hex floats.
        {float(ldexp(1.0f, -127)), "0x1p-127"},
        {float(ldexp(1.5f, -128)), "0x1.8p-128"},
        {float(ldexp(1.25, -129)), "0x1.4p-129"},
        {float(ldexp(1.125, -130)), "0x1.2p-130"},
        {float(ldexp(-1.0f, -127)), "-0x1p-127"},
        {float(ldexp(-1.0f, -128)), "-0x1p-128"},
        {float(ldexp(-1.0f, -129)), "-0x1p-129"},
        {float(ldexp(-1.5f, -130)), "-0x1.8p-130"},

        // NaNs
        {FloatProxy<float>(uint32_t(0xFFC00000)), "-0x1.8p+128"},
        {FloatProxy<float>(uint32_t(0xFF800100)), "-0x1.0002p+128"},

        // Infinities
        {std::numeric_limits<float>::infinity(), "0x1p+128"},
        {-std::numeric_limits<float>::infinity(), "-0x1p+128"},
    })),);
499 
// Expected direct-stream output of FloatProxy<double>. Normal values print in
// decimal (the expected strings show 15 significant digits); denormals, NaNs
// and infinities print as hex floats.
INSTANTIATE_TEST_CASE_P(
    Float64Tests, FloatProxyDoubleTest,
    ::testing::ValuesIn(
        std::vector<std::pair<FloatProxy<double>, std::string>>({
            {0., "0"},
            {1., "1"},
            {-0.25, "-0.25"},
            {1000.0, "1000"},

            // Magnitudes beyond the range of 32-bit floats (normal doubles).
            {ldexp(1.0, 128), "3.40282366920938e+38"},
            {ldexp(1.5, 129), "1.02084710076282e+39"},
            {ldexp(-1.0, 128), "-3.40282366920938e+38"},
            {ldexp(-1.5, 129), "-1.02084710076282e+39"},

            // Magnitudes below the normal range of 32-bit floats (normal
            // doubles).
            {ldexp(1.5, -129), "2.20405190779179e-39"},
            {ldexp(-1.5, -129), "-2.20405190779179e-39"},

            // Smallest-magnitude normal doubles (+/- 2^-1022).
            {ldexp(1.0, -1022), "2.2250738585072e-308"},
            {ldexp(-1.0, -1022), "-2.2250738585072e-308"},

            // Denormalized values
            {ldexp(1.125, -1023), "0x1.2p-1023"},
            {ldexp(-1.375, -1024), "-0x1.6p-1024"},

            // NaNs
            {uint64_t(0x7FF8000000000000LL), "0x1.8p+1024"},
            {uint64_t(0xFFF0F00000000000LL), "-0x1.0fp+1024"},

            // Infinity
            {std::numeric_limits<double>::infinity(), "0x1p+1024"},
            {-std::numeric_limits<double>::infinity(), "-0x1p+1024"},

        })),);
536 
537 // double is used so that unbiased_exponent can be used with the output
538 // of ldexp directly.
unbiased_exponent(double f)539 int32_t unbiased_exponent(double f) {
540   return spvutils::HexFloat<spvutils::FloatProxy<float>>(
541       static_cast<float>(f)).getUnbiasedNormalizedExponent();
542 }
543 
unbiased_half_exponent(uint16_t f)544 int16_t unbiased_half_exponent(uint16_t f) {
545   return spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>(f)
546       .getUnbiasedNormalizedExponent();
547 }
548 
// getUnbiasedNormalizedExponent() for normal, denormal, out-of-range, and
// underflowing inputs, for both float and Float16.
TEST(HexFloatOperationTest, UnbiasedExponent) {
  // Float cases
  EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, 0)));
  EXPECT_EQ(-32, unbiased_exponent(ldexp(1.0f, -32)));
  EXPECT_EQ(42, unbiased_exponent(ldexp(1.0f, 42)));
  EXPECT_EQ(125, unbiased_exponent(ldexp(1.0f, 125)));
  // Saturates to 128: 2^256 becomes float infinity, which reports 128.
  // NOTE(review): this relies on out-of-range conversion to float producing
  // infinity, as it does on IEEE-754 platforms.
  EXPECT_EQ(128, unbiased_exponent(ldexp(1.0f, 256)));

  EXPECT_EQ(-100, unbiased_exponent(ldexp(1.0f, -100)));
  // 2^-127 is the largest power of two below the normal range (2^-126).
  EXPECT_EQ(-127, unbiased_exponent(ldexp(1.0f, -127)));
  EXPECT_EQ(-128, unbiased_exponent(ldexp(1.0f, -128)));
  EXPECT_EQ(-129, unbiased_exponent(ldexp(1.0f, -129)));
  EXPECT_EQ(-140, unbiased_exponent(ldexp(1.0f, -140)));
  // Smallest representable number
  EXPECT_EQ(-126 - 23, unbiased_exponent(ldexp(1.0f, -126 - 23)));
  // Half of the smallest denormal: becomes 0 when narrowed to float, and zero
  // reports an exponent of 0.
  EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, -127 - 23)));

  // Float16 cases
  // The exponent is stored in the bits 0x7C00 with a bias of 15.
  EXPECT_EQ(0, unbiased_half_exponent(0x3C00));
  EXPECT_EQ(3, unbiased_half_exponent(0x4800));
  EXPECT_EQ(-1, unbiased_half_exponent(0x3800));
  EXPECT_EQ(-14, unbiased_half_exponent(0x0400));
  EXPECT_EQ(16, unbiased_half_exponent(0x7C00));
  EXPECT_EQ(10, unbiased_half_exponent(0x6400));

  // Smallest representable number
  EXPECT_EQ(-24, unbiased_half_exponent(0x0001));
}
581 
// Returns the float sum of 2^-k for every k in |fractions|, accumulated in
// order. For example, {0, 1} yields 1.0 + 0.5 = 1.5.
float float_fractions(const std::vector<uint32_t>& fractions) {
  float total = 0.0f;
  for (const uint32_t k : fractions) {
    total += std::ldexp(1.0f, -static_cast<int32_t>(k));
  }
  return total;
}
590 
591 // Returns the normalized significand of a HexFloat<FloatProxy<float>>
592 // that was created by calling float_fractions with the input fractions,
593 // raised to the power of exp.
normalized_significand(const std::vector<uint32_t> & fractions,uint32_t exp)594 uint32_t normalized_significand(const std::vector<uint32_t>& fractions, uint32_t exp) {
595   return spvutils::HexFloat<spvutils::FloatProxy<float>>(
596              static_cast<float>(ldexp(float_fractions(fractions), exp)))
597       .getNormalizedSignificand();
598 }
599 
// Builds a 23-bit float significand mask. Each entry in |bits| is an offset
// counted from the most significant significand bit: 0 sets bit 22
// (0x400000), 1 sets bit 21 (0x200000), ..., 22 sets bit 0.
uint32_t bits_set(const std::vector<uint32_t>& bits) {
  const uint32_t kTopSignificandBit = 1u << 22u;
  uint32_t mask = 0;
  for (const uint32_t offset : bits) {
    mask |= kTopSignificandBit >> offset;
  }
  return mask;
}
611 
// Float16 analogue of bits_set: builds a 10-bit significand mask in which
// offset 0 sets bit 9 (0x200) and offset 9 sets bit 0.
uint16_t half_bits_set(const std::vector<uint32_t>& bits) {
  const uint32_t kTopHalfSignificandBit = 1u << 9u;
  uint32_t mask = 0;
  for (const uint32_t offset : bits) {
    mask |= kTopHalfSignificandBit >> offset;
  }
  return static_cast<uint16_t>(mask);
}
622 
// getNormalizedSignificand() for normal and denormal floats. Each expectation
// compares against a bits_set() mask built from the fraction offsets.
TEST(HexFloatOperationTest, NormalizedSignificand) {
  // For normalized numbers (the following) it is simply a matter of dropping
  // the implicit leading bit.
  EXPECT_EQ(bits_set({}), normalized_significand({0}, 0));
  EXPECT_EQ(bits_set({0}), normalized_significand({0, 1}, 0));
  EXPECT_EQ(bits_set({0, 1}), normalized_significand({0, 1, 2}, 0));
  EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 0));
  EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 32));
  EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 126));

  // For denormalized numbers (negative exponents below the normal range) we
  // expect the significand to be shifted as if the value were normalized. In
  // practice this means the topmost set bit is cut off. These expectations
  // deliberately mirror the ones above.
  EXPECT_EQ(bits_set({}), normalized_significand({0}, -127));
  EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -128));
  EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -127));
  EXPECT_EQ(bits_set({}), normalized_significand({22}, -127));
  EXPECT_EQ(bits_set({0}), normalized_significand({21, 22}, -127));
}
642 
643 // Returns the 32-bit floating point value created by
644 // calling setFromSignUnbiasedExponentAndNormalizedSignificand
645 // on a HexFloat<FloatProxy<float>>
set_from_sign(bool negative,int32_t unbiased_exponent,uint32_t significand,bool round_denorm_up)646 float set_from_sign(bool negative, int32_t unbiased_exponent,
647                    uint32_t significand, bool round_denorm_up) {
648   spvutils::HexFloat<spvutils::FloatProxy<float>>  f(0.f);
649   f.setFromSignUnbiasedExponentAndNormalizedSignificand(
650       negative, unbiased_exponent, significand, round_denorm_up);
651   return f.value().getAsFloat();
652 }
653 
// Reassembles floats from sign, unbiased exponent and normalized significand,
// covering normal values, denormals, and the denormal round-up flag.
TEST(HexFloatOperationTests,
     SetFromSignUnbiasedExponentAndNormalizedSignificand) {

  EXPECT_EQ(1.f, set_from_sign(false, 0, 0, false));

  // Tests insertion of various denormalized numbers with and without round up.
  // 2^-149 is the smallest float denormal. Exponent -150 with a non-zero
  // significand lies below it: it becomes 0 without round-up and 2^-149 with
  // round-up.
  EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -149, 0, false));
  EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -149, 0, true));
  EXPECT_EQ(0.f, set_from_sign(false, -150, 1, false));
  EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -150, 1, true));

  EXPECT_EQ(ldexp(1.0f, -127), set_from_sign(false, -127, 0, false));
  EXPECT_EQ(ldexp(1.0f, -128), set_from_sign(false, -128, 0, false));
  EXPECT_EQ(float_fractions({0, 1, 2, 5}),
            set_from_sign(false, 0, bits_set({0, 1, 4}), false));
  EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -32),
            set_from_sign(false, -32, bits_set({0, 1, 4}), false));
  EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -128),
            set_from_sign(false, -128, bits_set({0, 1, 4}), false));

  // The negative cases from above.
  EXPECT_EQ(-1.f, set_from_sign(true, 0, 0, false));
  EXPECT_EQ(-ldexp(1.0, -127), set_from_sign(true, -127, 0, false));
  EXPECT_EQ(-ldexp(1.0, -128), set_from_sign(true, -128, 0, false));
  EXPECT_EQ(-float_fractions({0, 1, 2, 5}),
            set_from_sign(true, 0, bits_set({0, 1, 4}), false));
  EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -32),
            set_from_sign(true, -32, bits_set({0, 1, 4}), false));
  EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -128),
            set_from_sign(true, -128, bits_set({0, 1, 4}), false));
}
685 
// Rounding a float significand to the same float type must leave it unchanged
// and must not set the carry bit, under every rounding mode.
TEST(HexFloatOperationTests, NonRounding) {
  // Rounding from 32-bit hex-float to 32-bit hex-float should be trivial,
  // except in the denorm case which is a bit more complex.
  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
  // Output flag filled in by getRoundedNormalizedSignificand.
  bool carry_bit = false;

  spvutils::round_direction rounding[] = {
      spvutils::kRoundToZero,
      spvutils::kRoundToNearestEven,
      spvutils::kRoundToPositiveInfinity,
      spvutils::kRoundToNegativeInfinity};

  // Everything fits, so this should be straightforward.
  for (spvutils::round_direction round : rounding) {
    // Zero.
    EXPECT_EQ(bits_set({}), HF(0.f).getRoundedNormalizedSignificand<HF>(
                                round, &carry_bit));
    EXPECT_FALSE(carry_bit);

    // Normal values with a few fraction bits.
    EXPECT_EQ(bits_set({0}),
              HF(float_fractions({0, 1}))
                  .getRoundedNormalizedSignificand<HF>(round, &carry_bit));
    EXPECT_FALSE(carry_bit);

    EXPECT_EQ(bits_set({1, 3}),
              HF(float_fractions({0, 2, 4}))
                  .getRoundedNormalizedSignificand<HF>(round, &carry_bit));
    EXPECT_FALSE(carry_bit);

    // A negative denormal (scaled by 2^-128).
    EXPECT_EQ(
        bits_set({0, 1, 4}),
        HF(static_cast<float>(-ldexp(float_fractions({0, 1, 2, 5}), -128)))
            .getRoundedNormalizedSignificand<HF>(round, &carry_bit));
    EXPECT_FALSE(carry_bit);

    // Uses the lowest significand bit (offset 22), so nothing is dropped.
    EXPECT_EQ(
        bits_set({0, 1, 4, 22}),
        HF(static_cast<float>(float_fractions({0, 1, 2, 5, 23})))
            .getRoundedNormalizedSignificand<HF>(round, &carry_bit));
    EXPECT_FALSE(carry_bit);
  }
}
727 
728 struct RoundSignificandCase {
729   float source_float;
730   std::pair<int16_t, bool> expected_results;
731   spvutils::round_direction round;
732 };
733 
734 using HexFloatRoundTest =
735     ::testing::TestWithParam<RoundSignificandCase>;
736 
TEST_P(HexFloatRoundTest,RoundDownToFP16)737 TEST_P(HexFloatRoundTest, RoundDownToFP16) {
738   using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
739   using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
740 
741   HF input_value(GetParam().source_float);
742   bool carry_bit = false;
743   EXPECT_EQ(GetParam().expected_results.first,
744             input_value.getRoundedNormalizedSignificand<HF16>(
745                 GetParam().round, &carry_bit));
746   EXPECT_EQ(carry_bit, GetParam().expected_results.second);
747 }
748 
// clang-format off
// Rows are {source float, {expected half significand, expected carry bit},
// rounding direction}. A half keeps 10 fraction bits: float_fractions index 10
// is the last bit that survives (it maps to half_bits_set index 9) and index
// 11 is the first bit that is rounded away.
INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatRoundTest,
  ::testing::ValuesIn(std::vector<RoundSignificandCase>(
  {
    {float_fractions({0}), std::make_pair(half_bits_set({}), false), spvutils::kRoundToZero},
    {float_fractions({0}), std::make_pair(half_bits_set({}), false), spvutils::kRoundToNearestEven},
    {float_fractions({0}), std::make_pair(half_bits_set({}), false), spvutils::kRoundToPositiveInfinity},
    {float_fractions({0}), std::make_pair(half_bits_set({}), false), spvutils::kRoundToNegativeInfinity},
    {float_fractions({0, 1}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToZero},

    // Exactly half-way with a 0 in the lowest surviving bit: nearest-even
    // keeps the truncated value.
    {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToZero},
    {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToPositiveInfinity},
    {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToNegativeInfinity},
    {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToNearestEven},

    // Exactly half-way with a 1 in the lowest surviving bit: nearest-even
    // rounds up to make that bit even.
    {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToZero},
    {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), spvutils::kRoundToPositiveInfinity},
    {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToNegativeInfinity},
    {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), spvutils::kRoundToNearestEven},

    // Just above half-way: nearest-even rounds up.
    {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToZero},
    {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToPositiveInfinity},
    {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToNegativeInfinity},
    {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToNearestEven},

    // The same magnitude negated: the directed roundings swap roles, since
    // rounding toward -infinity now increases the magnitude.
    {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToZero},
    {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToPositiveInfinity},
    {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToNegativeInfinity},
    {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToNearestEven},

    // A set bit far below the half-way bit still makes the value
    // "above half-way".
    {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToZero},
    {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToPositiveInfinity},
    {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToNegativeInfinity},
    {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToNearestEven},

    // Carries: all 10 surviving bits are set, so rounding up overflows the
    // significand, leaving it all zeros and setting the carry flag.
    {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), spvutils::kRoundToZero},
    {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), spvutils::kRoundToPositiveInfinity},
    {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), spvutils::kRoundToNegativeInfinity},
    {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), spvutils::kRoundToNearestEven},

    // Cases where the original number was a float denormal. The rounding
    // result should be unaffected because the significand is normalized
    // before it is rounded.
    {static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -128)), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToZero},
    {static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -129)), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToPositiveInfinity},
    {static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -131)), std::make_pair(half_bits_set({0}), false), spvutils::kRoundToNegativeInfinity},
    {static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -130)), std::make_pair(half_bits_set({0, 9}), false), spvutils::kRoundToNearestEven},
  })),);
// clang-format on
798 
// A single significand-widening case: a raw half bit pattern and the 23-bit
// float significand its fraction should widen to.
struct UpCastSignificandCase {
  uint16_t source_half;      // Raw binary16 bits; only the fraction is used.
  uint32_t expected_result;  // Expected float significand bits.
};
803 
804 using HexFloatRoundUpSignificandTest =
805     ::testing::TestWithParam<UpCastSignificandCase>;
TEST_P(HexFloatRoundUpSignificandTest,Widening)806 TEST_P(HexFloatRoundUpSignificandTest, Widening) {
807   using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
808   using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
809   bool carry_bit = false;
810 
811   spvutils::round_direction rounding[] = {
812       spvutils::kRoundToZero,
813       spvutils::kRoundToNearestEven,
814       spvutils::kRoundToPositiveInfinity,
815       spvutils::kRoundToNegativeInfinity};
816 
817   // Everything fits, so everything should just be bit-shifts.
818   for (spvutils::round_direction round : rounding) {
819     carry_bit = false;
820     HF16 input_value(GetParam().source_half);
821     EXPECT_EQ(
822         GetParam().expected_result,
823         input_value.getRoundedNormalizedSignificand<HF>(round, &carry_bit))
824         << std::hex << "0x"
825         << input_value.getRoundedNormalizedSignificand<HF>(round, &carry_bit)
826         << "  0x" << GetParam().expected_result;
827     EXPECT_FALSE(carry_bit);
828   }
829 }
830 
// Rows are {half bit pattern, expected float significand}. The expected value
// is the 10-bit half fraction shifted left by 13 (= 23 - 10) bits, e.g.
// 0x301 << 13 == 0x602000.
INSTANTIATE_TEST_CASE_P(F16toF32, HexFloatRoundUpSignificandTest,
  // The 0xFC00 bits of the source 16-bit value hold the sign and the exponent.
  // Only the low 10 fraction bits matter for this test.
  ::testing::ValuesIn(std::vector<UpCastSignificandCase>(
  {
    {0x3F00, 0x600000},
    {0x0F00, 0x600000},
    {0x0F01, 0x602000},
    {0x0FFF, 0x7FE000},
  })),);
841 
842 struct DownCastTest {
843   float source_float;
844   uint16_t expected_half;
845   std::vector<spvutils::round_direction> directions;
846 };
847 
get_round_text(spvutils::round_direction direction)848 std::string get_round_text(spvutils::round_direction direction) {
849 #define CASE(round_direction) \
850   case round_direction:      \
851     return #round_direction
852 
853   switch (direction) {
854     CASE(spvutils::kRoundToZero);
855     CASE(spvutils::kRoundToPositiveInfinity);
856     CASE(spvutils::kRoundToNegativeInfinity);
857     CASE(spvutils::kRoundToNearestEven);
858   }
859 #undef CASE
860   return "";
861 }
862 
863 using HexFloatFP32To16Tests = ::testing::TestWithParam<DownCastTest>;
864 
TEST_P(HexFloatFP32To16Tests,NarrowingCasts)865 TEST_P(HexFloatFP32To16Tests, NarrowingCasts) {
866   using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
867   using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
868   HF f(GetParam().source_float);
869   for (auto round : GetParam().directions) {
870     HF16 half(0);
871     f.castTo(half, round);
872     EXPECT_EQ(GetParam().expected_half, half.value().getAsFloat().get_value())
873         << get_round_text(round) << "  " << std::hex
874         << spvutils::BitwiseCast<uint32_t>(GetParam().source_float)
875         << " cast to: " << half.value().getAsFloat().get_value();
876   }
877 }
878 
// Raw binary16 bit patterns: all-ones exponent, zero fraction, either sign.
constexpr uint16_t positive_infinity = 0x7C00;
constexpr uint16_t negative_infinity = 0xFC00;
881 
// Rows are {source float, expected raw half bits, rounding directions for
// which that result is expected}.
INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatFP32To16Tests,
  ::testing::ValuesIn(std::vector<DownCastTest>(
  {
    // Exactly representable as half.
    {0.f, 0x0, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {-0.f, 0x8000, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {1.0f, 0x3C00, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {-1.0f, 0xBC00, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},

    {float_fractions({0, 1, 10}) , 0x3E01, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {-float_fractions({0, 1, 10}) , 0xBE01, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(ldexp(float_fractions({0, 1, 10}), 3)), 0x4A01, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(-ldexp(float_fractions({0, 1, 10}), 3)), 0xCA01, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},


    // Underflow: 2^-25 is exactly half of the smallest half denormal (2^-24,
    // raw 0x1), so it only becomes non-zero when rounding away from zero.
    {static_cast<float>(ldexp(1.0f, -25)), 0x0, {spvutils::kRoundToZero, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(ldexp(1.0f, -25)), 0x1, {spvutils::kRoundToPositiveInfinity}},
    {static_cast<float>(-ldexp(1.0f, -25)), 0x8000, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(-ldexp(1.0f, -25)), 0x8001, {spvutils::kRoundToNegativeInfinity}},
    {static_cast<float>(ldexp(1.0f, -24)), 0x1, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},

    // Overflow: magnitudes of 2^16 and above do not fit in a half.
    // NOTE(review): in IEEE 754, overflow under round-toward-zero (and
    // rounding toward the opposite infinity) yields the largest finite half
    // (0x7BFF / 0xFBFF), not infinity. These rows pin the current
    // implementation, which produces infinity for every direction.
    {static_cast<float>(ldexp(1.0f, 16)), positive_infinity, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(ldexp(1.0f, 18)), positive_infinity, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(ldexp(1.3f, 16)), positive_infinity, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(-ldexp(1.0f, 16)), negative_infinity, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(-ldexp(1.0f, 18)), negative_infinity, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {static_cast<float>(-ldexp(1.3f, 16)), negative_infinity, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},

    // Transfer of Infinities
    {std::numeric_limits<float>::infinity(), positive_infinity, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},
    {-std::numeric_limits<float>::infinity(), negative_infinity, {spvutils::kRoundToZero, spvutils::kRoundToPositiveInfinity, spvutils::kRoundToNegativeInfinity, spvutils::kRoundToNearestEven}},

    // NaNs are tested separately (HexFloatOperationTests.NanTests) because
    // NaN never compares equal, even to itself.
  })),);
918 
// A widening case: a raw half bit pattern and the float it must widen to.
struct UpCastCase {
  uint16_t source_half;  // Raw binary16 bits of the input.
  float expected_float;  // Exact binary32 result of the cast.
};
923 
924 using HexFloatFP16To32Tests = ::testing::TestWithParam<UpCastCase>;
TEST_P(HexFloatFP16To32Tests,WideningCasts)925 TEST_P(HexFloatFP16To32Tests, WideningCasts) {
926   using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
927   using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
928   HF16 f(GetParam().source_half);
929 
930   spvutils::round_direction rounding[] = {
931       spvutils::kRoundToZero,
932       spvutils::kRoundToNearestEven,
933       spvutils::kRoundToPositiveInfinity,
934       spvutils::kRoundToNegativeInfinity};
935 
936   // Everything fits, so everything should just be bit-shifts.
937   for (spvutils::round_direction round : rounding) {
938     HF flt(0.f);
939     f.castTo(flt, round);
940     EXPECT_EQ(GetParam().expected_float, flt.value().getAsFloat())
941         << get_round_text(round) << "  " << std::hex
942         << spvutils::BitwiseCast<uint16_t>(GetParam().source_half)
943         << " cast to: " << flt.value().getAsFloat();
944   }
945 }
946 
// Rows are {raw half bits, exact float the half must widen to}.
INSTANTIATE_TEST_CASE_P(F16ToF32, HexFloatFP16To32Tests,
  ::testing::ValuesIn(std::vector<UpCastCase>(
  {
    {0x0000, 0.f},
    {0x8000, -0.f},
    {0x3C00, 1.0f},
    {0xBC00, -1.0f},
    {0x3F00, float_fractions({0, 1, 2})},
    {0xBF00, -float_fractions({0, 1, 2})},
    {0x3F01, float_fractions({0, 1, 2, 10})},
    {0xBF01, -float_fractions({0, 1, 2, 10})},

    // Half denormals (zero exponent field) widen to normal floats: the value
    // is the 10-bit fraction times 2^-24, e.g. 0x11 -> 2^-20 + 2^-24.
    {0x0001, static_cast<float>(ldexp(1.0, -24))},
    {0x0002, static_cast<float>(ldexp(1.0, -23))},
    {0x8001, static_cast<float>(-ldexp(1.0, -24))},
    {0x8011, static_cast<float>(-ldexp(1.0, -20) + -ldexp(1.0, -24))},

    // inf
    {0x7C00, std::numeric_limits<float>::infinity()},
    {0xFC00, -std::numeric_limits<float>::infinity()},
  })),);
969 
// NaNs must stay NaNs when narrowing (float -> half) and when widening
// (half -> float), in every rounding direction. NaN never compares equal, so
// these checks use isNan() instead of a value comparison.
TEST(HexFloatOperationTests, NanTests) {
  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
  using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
  spvutils::round_direction rounding[] = {
      spvutils::kRoundToZero,
      spvutils::kRoundToNearestEven,
      spvutils::kRoundToPositiveInfinity,
      spvutils::kRoundToNegativeInfinity};

  // Raw half NaN patterns: all-ones exponent with a non-zero fraction, of
  // either sign.
  const uint16_t half_nans[] = {0x7C01, 0x7C11, 0xFC01, 0x7C10, 0xFF00};

  for (spvutils::round_direction round : rounding) {
    // Each cast writes into a fresh zero-valued destination. Reusing one
    // object would let a NaN left behind by an earlier cast satisfy a later
    // check even if that later cast wrote nothing.
    {
      HF16 f16(0);
      HF(std::numeric_limits<float>::quiet_NaN()).castTo(f16, round);
      EXPECT_TRUE(f16.value().isNan())
          << get_round_text(round) << " quiet_NaN";
    }
    {
      HF16 f16(0);
      HF(std::numeric_limits<float>::signaling_NaN()).castTo(f16, round);
      EXPECT_TRUE(f16.value().isNan())
          << get_round_text(round) << " signaling_NaN";
    }

    for (uint16_t nan_bits : half_nans) {
      HF f(0.f);
      HF16(nan_bits).castTo(f, round);
      EXPECT_TRUE(f.value().isNan())
          << get_round_text(round) << " 0x" << std::hex << nan_bits;
    }
  }
}
1000 
1001 // A test case for parsing good and bad HexFloat<FloatProxy<T>> literals.
1002 template <typename T>
1003 struct FloatParseCase {
1004   std::string literal;
1005   bool negate_value;
1006   bool expect_success;
1007   HexFloat<FloatProxy<T>> expected_value;
1008 };
1009 
1010 using ParseNormalFloatTest = ::testing::TestWithParam<FloatParseCase<float>>;
1011 
TEST_P(ParseNormalFloatTest,Samples)1012 TEST_P(ParseNormalFloatTest, Samples) {
1013   std::stringstream input(GetParam().literal);
1014   HexFloat<FloatProxy<float>> parsed_value(0.0f);
1015   ParseNormalFloat(input, GetParam().negate_value, parsed_value);
1016   EXPECT_NE(GetParam().expect_success, input.fail())
1017       << " literal: " << GetParam().literal
1018       << " negate: " << GetParam().negate_value;
1019   if (GetParam().expect_success) {
1020     EXPECT_THAT(parsed_value.value(), Eq(GetParam().expected_value.value()))
1021         << " literal: " << GetParam().literal
1022         << " negate: " << GetParam().negate_value;
1023   }
1024 }
1025 
1026 // Returns a FloatParseCase with expected failure.
1027 template <typename T>
BadFloatParseCase(std::string literal,bool negate_value,T expected_value)1028 FloatParseCase<T> BadFloatParseCase(std::string literal, bool negate_value,
1029                                     T expected_value) {
1030   HexFloat<FloatProxy<T>> proxy_expected_value(expected_value);
1031   return FloatParseCase<T>{literal, negate_value, false, proxy_expected_value};
1032 }
1033 
1034 // Returns a FloatParseCase that should successfully parse to a given value.
1035 template <typename T>
GoodFloatParseCase(std::string literal,bool negate_value,T expected_value)1036 FloatParseCase<T> GoodFloatParseCase(std::string literal, bool negate_value,
1037                                      T expected_value) {
1038   HexFloat<FloatProxy<T>> proxy_expected_value(expected_value);
1039   return FloatParseCase<T>{literal, negate_value, true, proxy_expected_value};
1040 }
1041 
// Each row is (literal, negate_value, expected value). Expected values on Bad*
// rows are not compared, because ParseNormalFloatTest.Samples only checks the
// value when expect_success is true.
INSTANTIATE_TEST_CASE_P(
    FloatParse, ParseNormalFloatTest,
    ::testing::ValuesIn(std::vector<FloatParseCase<float>>{
        // Failing cases due to trivially incorrect syntax.
        BadFloatParseCase("abc", false, 0.0f),
        BadFloatParseCase("abc", true, 0.0f),

        // Valid cases.
        GoodFloatParseCase("0", false, 0.0f),
        GoodFloatParseCase("0.0", false, 0.0f),
        GoodFloatParseCase("-0.0", false, -0.0f),
        GoodFloatParseCase("2.0", false, 2.0f),
        GoodFloatParseCase("-2.0", false, -2.0f),
        GoodFloatParseCase("+2.0", false, 2.0f),
        // Cases with negate_value being true.
        GoodFloatParseCase("0.0", true, -0.0f),
        GoodFloatParseCase("2.0", true, -2.0f),

        // When negate_value is true, we should not accept a
        // leading minus or plus.
        BadFloatParseCase("-0.0", true, 0.0f),
        BadFloatParseCase("-2.0", true, 0.0f),
        BadFloatParseCase("+0.0", true, 0.0f),
        BadFloatParseCase("+2.0", true, 0.0f),

        // Overflow is an error for 32-bit float parsing.
        BadFloatParseCase("1e40", false, FLT_MAX),
        BadFloatParseCase("1e40", true, -FLT_MAX),
        BadFloatParseCase("-1e40", false, -FLT_MAX),
        // We can't have -1e40 and negate_value == true since
        // that represents an original case of "--1e40" which
        // is invalid.
  }),);
1075 
1076 using ParseNormalFloat16Test =
1077     ::testing::TestWithParam<FloatParseCase<Float16>>;
1078 
TEST_P(ParseNormalFloat16Test,Samples)1079 TEST_P(ParseNormalFloat16Test, Samples) {
1080   std::stringstream input(GetParam().literal);
1081   HexFloat<FloatProxy<Float16>> parsed_value(0);
1082   ParseNormalFloat(input, GetParam().negate_value, parsed_value);
1083   EXPECT_NE(GetParam().expect_success, input.fail())
1084       << " literal: " << GetParam().literal
1085       << " negate: " << GetParam().negate_value;
1086   if (GetParam().expect_success) {
1087     EXPECT_THAT(parsed_value.value(), Eq(GetParam().expected_value.value()))
1088         << " literal: " << GetParam().literal
1089         << " negate: " << GetParam().negate_value;
1090   }
1091 }
1092 
// Each row is (literal, negate_value, expected raw half bits as a Float16).
// For example, 0x4000 is +2.0 and 0xc000 is -2.0; 0x8000 is -0.0.
INSTANTIATE_TEST_CASE_P(
    Float16Parse, ParseNormalFloat16Test,
    ::testing::ValuesIn(std::vector<FloatParseCase<Float16>>{
        // Failing cases due to trivially incorrect syntax.
        BadFloatParseCase<Float16>("abc", false, uint16_t{0}),
        BadFloatParseCase<Float16>("abc", true, uint16_t{0}),

        // Valid cases.
        GoodFloatParseCase<Float16>("0", false, uint16_t{0}),
        GoodFloatParseCase<Float16>("0.0", false, uint16_t{0}),
        GoodFloatParseCase<Float16>("-0.0", false, uint16_t{0x8000}),
        GoodFloatParseCase<Float16>("2.0", false, uint16_t{0x4000}),
        GoodFloatParseCase<Float16>("-2.0", false, uint16_t{0xc000}),
        GoodFloatParseCase<Float16>("+2.0", false, uint16_t{0x4000}),
        // Cases with negate_value being true.
        GoodFloatParseCase<Float16>("0.0", true, uint16_t{0x8000}),
        GoodFloatParseCase<Float16>("2.0", true, uint16_t{0xc000}),

        // When negate_value is true, we should not accept a leading minus or
        // plus.
        BadFloatParseCase<Float16>("-0.0", true, uint16_t{0}),
        BadFloatParseCase<Float16>("-2.0", true, uint16_t{0}),
        BadFloatParseCase<Float16>("+0.0", true, uint16_t{0}),
        BadFloatParseCase<Float16>("+2.0", true, uint16_t{0}),
    }),);
1118 
// A test case for detecting overflow (values too large for the target type)
// when streaming text into a HexFloat.
template <typename T>
struct OverflowParseCase {
  std::string input;    // Text streamed into the HexFloat.
  bool expect_success;  // False if the stream must end up failed.
  T expected_value;     // Compared only when expect_success is true.
};
1126 
1127 using FloatProxyParseOverflowFloatTest =
1128     ::testing::TestWithParam<OverflowParseCase<float>>;
1129 
TEST_P(FloatProxyParseOverflowFloatTest,Sample)1130 TEST_P(FloatProxyParseOverflowFloatTest, Sample) {
1131   std::istringstream input(GetParam().input);
1132   HexFloat<FloatProxy<float>> value(0.0f);
1133   input >> value;
1134   EXPECT_NE(GetParam().expect_success, input.fail());
1135   if (GetParam().expect_success) {
1136     EXPECT_THAT(value.value().getAsFloat(), GetParam().expected_value);
1137   }
1138 }
1139 
// Values beyond the float range (1e40, 1e400) must fail to parse. The
// FLT_MAX / -FLT_MAX values on failing rows are not asserted, because the
// Sample test only compares values when expect_success is true.
INSTANTIATE_TEST_CASE_P(
    FloatOverflow, FloatProxyParseOverflowFloatTest,
    ::testing::ValuesIn(std::vector<OverflowParseCase<float>>({
        {"0", true, 0.0f},
        {"0.0", true, 0.0f},
        {"1.0", true, 1.0f},
        {"1e38", true, 1e38f},
        {"-1e38", true, -1e38f},
        {"1e40", false, FLT_MAX},
        {"-1e40", false, -FLT_MAX},
        {"1e400", false, FLT_MAX},
        {"-1e400", false, -FLT_MAX},
    })),);
1153 
1154 using FloatProxyParseOverflowDoubleTest =
1155     ::testing::TestWithParam<OverflowParseCase<double>>;
1156 
TEST_P(FloatProxyParseOverflowDoubleTest,Sample)1157 TEST_P(FloatProxyParseOverflowDoubleTest, Sample) {
1158   std::istringstream input(GetParam().input);
1159   HexFloat<FloatProxy<double>> value(0.0);
1160   input >> value;
1161   EXPECT_NE(GetParam().expect_success, input.fail());
1162   if (GetParam().expect_success) {
1163     EXPECT_THAT(value.value().getAsFloat(), Eq(GetParam().expected_value));
1164   }
1165 }
1166 
// 1e40 overflows float but fits in double, so only 1e400 fails here. The
// DBL_MAX / -DBL_MAX values on failing rows are not asserted, because the
// Sample test only compares values when expect_success is true.
INSTANTIATE_TEST_CASE_P(
    DoubleOverflow, FloatProxyParseOverflowDoubleTest,
    ::testing::ValuesIn(std::vector<OverflowParseCase<double>>({
        {"0", true, 0.0},
        {"0.0", true, 0.0},
        {"1.0", true, 1.0},
        {"1e38", true, 1e38},
        {"-1e38", true, -1e38},
        {"1e40", true, 1e40},
        {"-1e40", true, -1e40},
        {"1e400", false, DBL_MAX},
        {"-1e400", false, -DBL_MAX},
    })),);
1180 
1181 using FloatProxyParseOverflowFloat16Test =
1182     ::testing::TestWithParam<OverflowParseCase<uint16_t>>;
1183 
TEST_P(FloatProxyParseOverflowFloat16Test,Sample)1184 TEST_P(FloatProxyParseOverflowFloat16Test, Sample) {
1185   std::istringstream input(GetParam().input);
1186   HexFloat<FloatProxy<Float16>> value(0);
1187   input >> value;
1188   EXPECT_NE(GetParam().expect_success, input.fail()) << " literal: "
1189                                                      << GetParam().input;
1190   if (GetParam().expect_success) {
1191     EXPECT_THAT(value.value().data(), Eq(GetParam().expected_value))
1192         << " literal: " << GetParam().input;
1193   }
1194 }
1195 
// Expected values are raw half bits: 0x3c00 is 1.0, 0x7bff / 0xfbff are the
// largest / lowest finite halves.
INSTANTIATE_TEST_CASE_P(
    Float16Overflow, FloatProxyParseOverflowFloat16Test,
    ::testing::ValuesIn(std::vector<OverflowParseCase<uint16_t>>({
        {"0", true, uint16_t{0}},
        {"0.0", true, uint16_t{0}},
        {"1.0", true, uint16_t{0x3c00}},
        // Overflow for 16-bit float is an error, and returns max or
        // lowest value. Even 1e38, which fits in a float, overflows a half.
        // NOTE: the Sample test only compares values when expect_success is
        // true, so the 0x7bff / 0xfbff values below are not asserted.
        {"1e38", false, uint16_t{0x7bff}},
        {"1e40", false, uint16_t{0x7bff}},
        {"1e400", false, uint16_t{0x7bff}},
        {"-1e38", false, uint16_t{0xfbff}},
        {"-1e40", false, uint16_t{0xfbff}},
        {"-1e400", false, uint16_t{0xfbff}},
    })),);
1211 
// FloatProxy<T>::max() must be the largest finite value of T: raw 0x7bff for
// half precision, and std::numeric_limits<T>::max() for float and double.
TEST(FloatProxy, Max) {
  EXPECT_EQ(uint16_t{0x7bff},
            FloatProxy<Float16>::max().getAsFloat().get_value());
  EXPECT_EQ(std::numeric_limits<float>::max(),
            FloatProxy<float>::max().getAsFloat());
  EXPECT_EQ(std::numeric_limits<double>::max(),
            FloatProxy<double>::max().getAsFloat());
}
1220 
// FloatProxy<T>::lowest() must be the most negative finite value of T: raw
// 0xfbff for half precision, and std::numeric_limits<T>::lowest() for float
// and double.
TEST(FloatProxy, Lowest) {
  EXPECT_EQ(uint16_t{0xfbff},
            FloatProxy<Float16>::lowest().getAsFloat().get_value());
  EXPECT_EQ(std::numeric_limits<float>::lowest(),
            FloatProxy<float>::lowest().getAsFloat());
  EXPECT_EQ(std::numeric_limits<double>::lowest(),
            FloatProxy<double>::lowest().getAsFloat());
}
1229 
1230 // TODO(awoloszyn): Add fp16 tests and HexFloatTraits.
1231 }  // anonymous namespace
1232