/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
// Tests that constants in program memory round trip as expected.
17
18 #include <memory>
19 #include <vector>
20
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/array3d.h"
23 #include "tensorflow/compiler/xla/array4d.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/types.h"
35
36 namespace xla {
37 namespace {
38
39 class ConstantsTest : public ClientLibraryTestBase {
40 protected:
41 const ErrorSpec error_spec_{1e-3, 1e-5};
42 };
43
// Builds a rank-1 float constant with no elements and compares the computed
// result against an empty expected vector.
TEST_F(ConstantsTest, ZeroCellF32) {
  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, {});

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
50
// Builds a single-element rank-1 float constant and compares the computed
// result against the same vector.
TEST_F(ConstantsTest, OneCellF32) {
  std::vector<float> constant = {2.0};

  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, constant);

  ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
}
59
// Builds a single-element rank-1 int32 constant and compares the computed
// result against the same vector (no error spec is passed for integers).
TEST_F(ConstantsTest, OneCellS32) {
  std::vector<int32> constant = {2};

  XlaBuilder builder(TestName());
  ConstantR1<int32>(&builder, constant);

  ComputeAndCompareR1<int32>(&builder, constant, {});
}
68
// Builds a single-element rank-1 uint32 constant and compares the computed
// result against the same vector (no error spec is passed for integers).
TEST_F(ConstantsTest, OneCellU32) {
  std::vector<uint32> constant = {2};

  XlaBuilder builder(TestName());
  ConstantR1<uint32>(&builder, constant);

  ComputeAndCompareR1<uint32>(&builder, constant, {});
}
77
// Builds an eight-element rank-1 float constant (0.0 through 7.0) and
// compares the computed result against the same vector.
TEST_F(ConstantsTest, EightCells) {
  std::vector<float> constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};

  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, constant);

  ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
}
86
// Builds a sixteen-element rank-1 float constant (0.0 through 15.0) and
// compares the computed result against the same vector.
TEST_F(ConstantsTest, SixteenCells) {
  std::vector<float> constant = {0.0, 1.0, 2.0,  3.0,  4.0,  5.0,  6.0,  7.0,
                                 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0};

  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, constant);

  ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
}
96
// Builds a rank-2 float constant from a 0x2 Array2D and compares the computed
// result against an identically-shaped array.
TEST_F(ConstantsTest, Empty_0x2) {
  XlaBuilder builder(TestName());
  ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 2));

  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
}
103
// Builds a rank-2 float constant from a 2x2 array produced by
// MakeLinspaceArray2D(100.0, 200.0, 2, 2) and compares the computed result
// against that same array.
TEST_F(ConstantsTest, Small_2x2) {
  std::unique_ptr<Array2D<float>> constant =
      MakeLinspaceArray2D(100.0, 200.0, 2, 2);

  XlaBuilder builder(TestName());
  ConstantR2FromArray2D<float>(&builder, *constant);

  ComputeAndCompareR2<float>(&builder, *constant, {}, error_spec_);
}
113
// Builds a constant from a literal created out of a 3x0x2 Array3D and
// compares the computed result against an identically-shaped array.
TEST_F(ConstantsTest, Empty_3x0x2) {
  XlaBuilder builder(TestName());
  ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
                                Array3D<float>(3, 0, 2)));

  ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
}
121
// Builds a constant from a literal created out of a 2x2x2 Array3D holding
// the values 1 through 8 and compares the computed result against that array.
TEST_F(ConstantsTest, Small_2x2x2) {
  XlaBuilder builder(TestName());
  Array3D<float> array3d({
      // x0  x1
      {{1.f, 2.f},   // y0
       {3.f, 4.f}},  // y1

      {{5.f, 6.f},   // y0
       {7.f, 8.f}},  // y1
  });
  ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));

  ComputeAndCompareR3<float>(&builder, array3d, {});
}
136
// Fills a 3x2x1x1 Array4D from a 3x2 table via FillWithPZ, then checks that
// the array round-trips through two construction paths: a pre-built literal
// passed to ConstantLiteral, and ConstantR4FromArray4D on the array itself.
TEST_F(ConstantsTest, Small_3x2x1x1) {
  Array4D<float> input_array(3, 2, 1, 1);
  Array2D<float> pz({
      // z0     z1
      {-1.0f, 4.1f},  // p0
      {2.0f, 4.1f},   // p1
      {5.0f, 4.4f},   // p2
  });
  input_array.FillWithPZ(pz);
  Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);

  // Path 1: constant built from the literal.
  {
    XlaBuilder builder(TestName());
    ConstantLiteral(&builder, input_literal);
    ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
  }

  // Path 2: constant built directly from the Array4D.
  {
    XlaBuilder builder(TestName());
    ConstantR4FromArray4D<float>(&builder, input_array);
    ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
  }
}
160
161 // TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest,DISABLED_TupleConstant)162 TEST_F(ConstantsTest, DISABLED_TupleConstant) {
163 XlaBuilder builder(TestName());
164 ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
165 {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
166 LiteralUtil::CreateR1<float>({2.0, 42})}));
167
168 Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
169
170 LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
171 LiteralSlice(result, {0}), error_spec_);
172 LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
173 error_spec_);
174 }
175
// Adds a token literal as a constant, then an empty tuple as the last op, and
// only asserts that execution returns an OK status.
TEST_F(ConstantsTest, Token) {
  XlaBuilder builder(TestName());
  ConstantLiteral(&builder, LiteralUtil::CreateToken());
  // TODO(b/80000000): tokens cannot be returned from computations.
  Tuple(&builder, {});
  TF_ASSERT_OK(Execute(&builder, {}).status());
}
183
// Adds FullLike(iota, 10) to a 3-element F32 iota and expects {10, 11, 12}.
TEST_F(ConstantsTest, FullLike) {
  XlaBuilder b(TestName());
  auto val1 = Iota(&b, F32, 3);
  auto val2 = FullLike(val1, 10);
  // The sum's value is intentionally unused here; the expression is evaluated
  // only so that the addition is the last op added to the builder.
  val1 + val2;
  ComputeAndCompareR1<float>(&b, {10, 11, 12}, {}, error_spec_);
}
191
// Calls FullLike on a tuple-valued op and expects building the computation
// to fail.
TEST_F(ConstantsTest, IllegalFullLikeOnTuple) {
  XlaBuilder b(TestName());
  auto tuple = Tuple(&b, {Iota(&b, F32, 3), Iota(&b, F32, 1)});
  FullLike(tuple, 10);  // Illegal; can't do FullLike on a tuple.
  EXPECT_FALSE(b.Build().ok());
}
198
// Subtracts FullLike(scalar, 2) from an F32 scalar constant 1 and expects -1.
TEST_F(ConstantsTest, FullLikeScalar) {
  XlaBuilder b(TestName());
  auto scalar1 = ConstantR0WithType(&b, F32, 1);
  auto scalar2 = FullLike(scalar1, 2);
  // The difference's value is intentionally unused here; the expression is
  // evaluated only so that the subtraction is the last op added to the builder.
  scalar1 - scalar2;
  ComputeAndCompareR0<float>(&b, -1, {}, error_spec_);
}
206
207 class ConstantsHloTest : public HloTestBase {};
208
209 // TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior.
XLA_TEST_F(ConstantsHloTest,DISABLED_ON_GPU (BitcastOfConstant))210 XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) {
211 const char* testcase = R"(
212 HloModule module, is_scheduled=true
213
214 func {
215 lhs = s32[] parameter(0)
216 rhs = s32[] parameter(1)
217 ROOT mul = s32[] add(lhs, rhs)
218 }
219
220 ENTRY test {
221 constant.0 = s32[1]{0} constant({0})
222 parameter.0 = s32[] parameter(0)
223 constant-as-scalar = s32[] bitcast(constant.0)
224 ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func
225 }
226 )";
227 auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
228 auto param = LiteralUtil::CreateR0<int32>(1);
229 auto result = ExecuteNoHloPasses(std::move(module), {¶m});
230 EXPECT_TRUE(LiteralTestUtil::Equal(param, result));
231 }
232
233 } // namespace
234 } // namespace xla
235