1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17 #include <limits>
18 #include <memory>
19 #include <vector>
20
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
25 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
26 #include "tensorflow/compiler/xla/tests/test_macros.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
29 #include "tensorflow/core/platform/test.h"
30 #include "tensorflow/core/platform/types.h"
31
32 namespace xla {
33 namespace {
34
35 class BitcastConvertTest : public ClientLibraryTestBase {
36 public:
BitcastConvertTest(se::Platform * platform=nullptr)37 explicit BitcastConvertTest(se::Platform* platform = nullptr)
38 : ClientLibraryTestBase(platform) {
39 mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
40 mutable_debug_options()->add_xla_disable_hlo_passes("inline");
41 }
42 };
43
// Bitcasting an S32 vector to S32 must return the input values unchanged.
TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) {
  const std::vector<int32> values = {42, 64};

  XlaBuilder builder(TestName());
  BitcastConvertType(ConstantR1<int32>(&builder, values), S32);

  // Same-type bitcast: the input doubles as the expected output.
  ComputeAndCompareR1<int32>(&builder, values, {});
}
52
// Bitcasting an F32 vector to F32 must return the input values unchanged.
TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) {
  const std::vector<float> values = {42.0f, 64.0f};

  XlaBuilder builder(TestName());
  BitcastConvertType(ConstantR1<float>(&builder, values), F32);

  // Same-type bitcast: the input doubles as the expected output.
  ComputeAndCompareR1<float>(&builder, values, {});
}
61
// Reinterprets IEEE-754 single-precision bit patterns stored as S32 as F32.
TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) {
  // Element i of `encodings` is the F32 bit pattern of element i of
  // `expected`. Hex literals with the sign bit set are unsigned, so they are
  // cast to int32 explicitly.
  const std::vector<int32> encodings = {
      0,          static_cast<int32>(0x80000000),  // +0.0f, -0.0f
      0x3F800000, static_cast<int32>(0xBF800000),  // +1.0f, -1.0f
      0x3F000000, static_cast<int32>(0xBF000000),  // +0.5f, -0.5f
  };
  const std::vector<float> expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f};

  XlaBuilder builder(TestName());
  BitcastConvertType(ConstantR1<int32>(&builder, encodings), F32);
  ComputeAndCompareR1<float>(&builder, expected, {});
}
73
// A zero-element S32 vector bitcasts to a zero-element F32 vector.
XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) {
  XlaBuilder builder(TestName());

  const std::vector<int32> no_encodings;
  BitcastConvertType(ConstantR1<int32>(&builder, no_encodings), F32);

  const std::vector<float> no_values;
  ComputeAndCompareR1<float>(&builder, no_values, {});
}
82
// Bitcasting F32 to S32 exposes the IEEE-754 single-precision encoding of
// each element.
TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) {
  XlaBuilder builder(TestName());
  // Float literals make the constants F32 values directly, instead of double
  // literals that are implicitly narrowed to float. This matches the other
  // tests in this file.
  auto a = ConstantR1<float>(&builder, {42.6f, 64.4f});
  BitcastConvertType(a, S32);

  // Bit patterns of 42.6f and 64.4f (round-to-nearest of the decimal values).
  std::vector<int32> expected = {0x422a6666, 0x4280cccd};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
91
// Bitcasts the extreme S32 values to F32.
//   INT32_MIN = 0x80000000: only the sign bit is set, i.e. -0.0f.
//   INT32_MAX = 0x7FFFFFFF: all-ones exponent with a non-zero mantissa, i.e. a
//   NaN.
TEST_F(BitcastConvertTest, ConvertS32Extremes) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<int32>(&builder, {std::numeric_limits<int32>::min(),
                                        std::numeric_limits<int32>::max()});
  BitcastConvertType(a, F32);

  // std::numeric_limits (from the already-included <limits>) replaces the NAN
  // macro, which is only declared by <cmath>. This file does not include
  // <cmath>. NOTE(review): the produced NaN's payload differs from
  // quiet_NaN()'s, so the test presumably relies on the comparison matching
  // NaN against NaN, as it did with NAN -- confirm.
  std::vector<float> expected = {-0.0f,
                                 std::numeric_limits<float>::quiet_NaN()};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0, 0));
}
101
// Applies a scalar F32->S32 bitcast sub-computation element-wise via Map.
TEST_F(BitcastConvertTest, ConvertMapToS32) {
  XlaBuilder builder(TestName());

  // Sub-computation: reinterpret one scalar F32 parameter as S32.
  auto bitcast_fn = builder.CreateSubBuilder("convert");
  BitcastConvertType(
      Parameter(bitcast_fn.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"),
      S32);

  auto values = ConstantR1<float>(&builder, {42.0f, 64.0f});
  Map(&builder, {values}, bitcast_fn->BuildAndNoteError(), /*dimensions=*/{0});

  // IEEE-754 single-precision bit patterns of 42.0f and 64.0f.
  const std::vector<int32> expected_bits = {0x42280000, 0x42800000};
  ComputeAndCompareR1<int32>(&builder, expected_bits, {});
}
113
// Applies a scalar S32->F32 bitcast sub-computation element-wise via Map.
TEST_F(BitcastConvertTest, ConvertMapToF32) {
  XlaBuilder builder(TestName());

  // Sub-computation: reinterpret one scalar S32 parameter as F32.
  auto bitcast_fn = builder.CreateSubBuilder("convert");
  BitcastConvertType(
      Parameter(bitcast_fn.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"),
      F32);

  // IEEE-754 single-precision bit patterns of 42.0f and 64.0f.
  auto encodings = ConstantR1<int32>(&builder, {0x42280000, 0x42800000});
  Map(&builder, {encodings}, bitcast_fn->BuildAndNoteError(),
      /*dimensions=*/{0});

  const std::vector<float> expected_values = {42.0f, 64.0f};
  ComputeAndCompareR1<float>(&builder, expected_values, {});
}
125
126 // Regression test for b/31758660. When ReshapeMover transforms
127 // input -> reshape -> convert
128 // to
129 // input -> convert -> reshape
130 // the new convert should have the same element type as the old convert.
TEST_F(BitcastConvertTest,ConvertReshape)131 TEST_F(BitcastConvertTest, ConvertReshape) {
132 XlaBuilder builder(TestName());
133 auto input = ConstantR1<int32>(&builder, {0x42280000});
134 auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
135 BitcastConvertType(reshape, F32);
136
137 ComputeAndCompareR0<float>(&builder, 42.0f, {});
138 }
139
140 } // namespace
141 } // namespace xla
142