1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests that constants in program memory round trip as expected.
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/array3d.h"
23 #include "tensorflow/compiler/xla/array4d.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 namespace {
38 
39 class ConstantsTest : public ClientLibraryTestBase {
40  protected:
41   const ErrorSpec error_spec_{1e-3, 1e-5};
42 };
43 
// An empty rank-1 F32 constant must come back as an empty vector.
TEST_F(ConstantsTest, ZeroCellF32) {
  const std::vector<float> no_cells;

  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, no_cells);
  ComputeAndCompareR1<float>(&builder, no_cells, /*arguments=*/{},
                             error_spec_);
}
50 
// A single-element F32 constant round-trips unchanged.
TEST_F(ConstantsTest, OneCellF32) {
  const std::vector<float> expected{2.0f};
  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, expected);
  ComputeAndCompareR1<float>(&builder, expected, /*arguments=*/{},
                             error_spec_);
}
59 
// A single-element S32 constant round-trips unchanged (exact comparison).
TEST_F(ConstantsTest, OneCellS32) {
  const std::vector<int32> expected = {2};
  XlaBuilder builder(TestName());
  ConstantR1<int32>(&builder, expected);
  ComputeAndCompareR1<int32>(&builder, expected, /*arguments=*/{});
}
68 
// A single-element U32 constant round-trips unchanged (exact comparison).
TEST_F(ConstantsTest, OneCellU32) {
  const std::vector<uint32> expected = {2};
  XlaBuilder builder(TestName());
  ConstantR1<uint32>(&builder, expected);
  ComputeAndCompareR1<uint32>(&builder, expected, /*arguments=*/{});
}
77 
// An eight-element F32 ramp {0, 1, ..., 7} round-trips unchanged.
TEST_F(ConstantsTest, EightCells) {
  std::vector<float> ramp;
  for (int i = 0; i < 8; ++i) {
    ramp.push_back(static_cast<float>(i));
  }

  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, ramp);
  ComputeAndCompareR1<float>(&builder, ramp, /*arguments=*/{}, error_spec_);
}
86 
// A sixteen-element F32 ramp {0, 1, ..., 15} round-trips unchanged.
TEST_F(ConstantsTest, SixteenCells) {
  std::vector<float> ramp;
  for (int i = 0; i < 16; ++i) {
    ramp.push_back(static_cast<float>(i));
  }

  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, ramp);
  ComputeAndCompareR1<float>(&builder, ramp, /*arguments=*/{}, error_spec_);
}
96 
// A 0x2 (zero rows) F32 matrix constant round-trips unchanged.
TEST_F(ConstantsTest, Empty_0x2) {
  const Array2D<float> empty(0, 2);
  XlaBuilder builder(TestName());
  ConstantR2FromArray2D<float>(&builder, empty);
  ComputeAndCompareR2<float>(&builder, empty, /*arguments=*/{}, error_spec_);
}
103 
// A 2x2 F32 matrix of evenly spaced values in [100, 200] round-trips unchanged.
TEST_F(ConstantsTest, Small_2x2) {
  XlaBuilder builder(TestName());

  const auto linspace = MakeLinspaceArray2D(100.0, 200.0, 2, 2);
  const Array2D<float>& values = *linspace;

  ConstantR2FromArray2D<float>(&builder, values);
  ComputeAndCompareR2<float>(&builder, values, /*arguments=*/{}, error_spec_);
}
113 
// A 3x0x2 (empty middle dimension) F32 constant, built via a Literal,
// round-trips unchanged.
TEST_F(ConstantsTest, Empty_3x0x2) {
  const Array3D<float> empty(3, 0, 2);
  XlaBuilder builder(TestName());
  ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(empty));
  ComputeAndCompareR3<float>(&builder, empty, /*arguments=*/{});
}
121 
// A 2x2x2 F32 constant built from a Literal round-trips exactly.
TEST_F(ConstantsTest, Small_2x2x2) {
  const Array3D<float> cube({
      // x0  x1
      {{1.f, 2.f},   // y0
       {3.f, 4.f}},  // y1

      {{5.f, 6.f},   // y0
       {7.f, 8.f}},  // y1
  });
  const Literal cube_literal = LiteralUtil::CreateR3FromArray3D<float>(cube);

  XlaBuilder builder(TestName());
  ConstantLiteral(&builder, cube_literal);
  ComputeAndCompareR3<float>(&builder, cube, /*arguments=*/{});
}
136 
// A 3x2x1x1 F32 constant (filled from a 3x2 p/z plane) round-trips whether it
// is built from a Literal or directly from the Array4D.
TEST_F(ConstantsTest, Small_3x2x1x1) {
  Array2D<float> pz({
      // z0 z1
      {-1.0f, 4.1f},  // p0
      {2.0f, 4.1f},   // p1
      {5.0f, 4.4f},   // p2
  });
  Array4D<float> input_array(3, 2, 1, 1);
  input_array.FillWithPZ(pz);
  const Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);

  // Exercise both construction paths, each in its own fresh builder.
  for (const bool from_literal : {true, false}) {
    XlaBuilder builder(TestName());
    if (from_literal) {
      ConstantLiteral(&builder, input_literal);
    } else {
      ConstantR4FromArray4D<float>(&builder, input_array);
    }
    ComputeAndCompareR4<float>(&builder, input_array, /*arguments=*/{},
                               error_spec_);
  }
}
160 
161 // TODO(b/29263943): Support tuple constants.
// A tuple constant of (f32[2,1], f32[2]) should round-trip element-wise.
// Disabled until tuple constants are supported (see TODO above).
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
  const Literal tuple_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0f}, {2.0f}}),
       LiteralUtil::CreateR1<float>({2.0f, 42.0f})});

  XlaBuilder builder(TestName());
  ConstantLiteral(&builder, tuple_literal);
  const Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();

  // Tuple element 0 is the R2 operand, element 1 the R1 operand.
  LiteralTestUtil::ExpectR2Near<float>({{1.0f}, {2.0f}},
                                       LiteralSlice(result, {0}), error_spec_);
  LiteralTestUtil::ExpectR1Near<float>({2.0f, 42.0f}, LiteralSlice(result, {1}),
                                       error_spec_);
}
175 
// A token constant can be added to a computation that then executes
// successfully. Only the execution status is checked.
TEST_F(ConstantsTest, Token) {
  XlaBuilder builder(TestName());
  ConstantLiteral(&builder, LiteralUtil::CreateToken());
  // TODO(b/80000000): tokens cannot be returned from computations.
  // An empty tuple is added last, presumably so that it, not the token, is
  // the value the computation produces.
  Tuple(&builder, {});

  const auto execution = Execute(&builder, /*arguments=*/{});
  TF_ASSERT_OK(execution.status());
}
183 
// FullLike produces an array of the same shape as its operand filled with the
// given value: iota(3) + FullLike(iota(3), 10) == {10, 11, 12}.
TEST_F(ConstantsTest, FullLike) {
  XlaBuilder b(TestName());
  auto iota = Iota(&b, F32, 3);
  auto tens = FullLike(iota, 10);
  // Built only for its effect on the builder; the expected output below is
  // this sum.
  iota + tens;
  ComputeAndCompareR1<float>(&b, {10.0f, 11.0f, 12.0f}, /*arguments=*/{},
                             error_spec_);
}
191 
// FullLike on a tuple-shaped operand is illegal and must make Build() fail.
TEST_F(ConstantsTest, IllegalFullLikeOnTuple) {
  XlaBuilder b(TestName());
  auto first = Iota(&b, F32, 3);
  auto second = Iota(&b, F32, 1);
  auto tuple = Tuple(&b, {first, second});

  // Illegal; can't do FullLike on a tuple. The error surfaces at Build().
  FullLike(tuple, 10);
  EXPECT_FALSE(b.Build().ok());
}
198 
// FullLike also works on scalars: 1 - FullLike(1, 2) == -1.
TEST_F(ConstantsTest, FullLikeScalar) {
  XlaBuilder b(TestName());
  auto one = ConstantR0WithType(&b, F32, 1);
  auto two = FullLike(one, 2);
  one - two;  // Expected result of the computation.
  ComputeAndCompareR0<float>(&b, -1.0f, /*arguments=*/{}, error_spec_);
}
206 
207 class ConstantsHloTest : public HloTestBase {};
208 
209 // TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior.
XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) {
  // Bitcasts the one-element constant s32[1] {0} to a scalar and passes it,
  // together with the scalar parameter, to a called computation that adds its
  // two operands. Because the constant is 0, the result must equal the
  // parameter. (The callee's root was previously misnamed `mul` although it
  // performs `add`; a multiply would yield 0, not the parameter.)
  const char* const testcase = R"(
    HloModule module, is_scheduled=true

    func {
      lhs = s32[] parameter(0)
      rhs = s32[] parameter(1)
      ROOT add = s32[] add(lhs, rhs)
    }

    ENTRY test {
      constant.0 = s32[1]{0} constant({0})
      parameter.0 = s32[] parameter(0)
      constant-as-scalar = s32[] bitcast(constant.0)
      ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func
    }
  )";
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param = LiteralUtil::CreateR0<int32>(1);
  auto result = ExecuteNoHloPasses(std::move(module), {&param});
  // 1 + 0 == 1, so the output must match the input parameter.
  EXPECT_TRUE(LiteralTestUtil::Equal(param, result));
}
232 
233 }  // namespace
234 }  // namespace xla
235