1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/constants.h"
17 #include "tensorflow/compiler/xla/client/xla_builder.h"
18 #include "tensorflow/compiler/xla/test.h"
19 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
20 #include "tensorflow/compiler/xla/tests/test_macros.h"
21 #include "tensorflow/compiler/xla/types.h"
22 #include "tensorflow/compiler/xla/xla_data.pb.h"
23 
24 namespace xla {
25 namespace {
26 
27 using ConstantsTest = ClientLibraryTestBase;
28 
29 using ::testing::HasSubstr;
30 
// ConstantR0WithType with element type S32 and the integer 4 must evaluate to
// the int32 scalar 4.
XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32) {
  XlaBuilder b(TestName());
  ConstantR0WithType(&b, xla::S32, 4);
  ComputeAndCompareR0<int32>(&b, /*expected=*/4, /*arguments=*/{});
}
36 
// Requesting an S32 constant from a fractional value (4.5) must be rejected.
// The failure is reported by Build(), and its message mentions "Invalid cast".
XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32DoesNotAcceptFloats) {
  XlaBuilder b(TestName());
  ConstantR0WithType(&b, xla::S32, 4.5);

  const auto computation_or = b.Build();
  ASSERT_FALSE(computation_or.ok());
  EXPECT_THAT(computation_or.status().error_message(),
              HasSubstr("Invalid cast"));
}
44 
// ConstantR0WithType with element type F32 must accept both an integral input
// (-7) and a fractional input (0.5) and produce the matching float scalar.
//
// Each case gets its own XlaBuilder. The original version reused one builder
// across two ComputeAndCompareR0 calls, so the second case silently depended on
// the first Build() leaving the builder in a reusable state. One builder per
// computation also matches every other test in this file.
XLA_TEST_F(ConstantsTest, ConstantR0WithTypeF32) {
  {
    XlaBuilder builder(TestName());
    ConstantR0WithType(&builder, xla::F32, -7);
    ComputeAndCompareR0<float>(&builder, -7, {});
  }
  {
    XlaBuilder builder(TestName());
    ConstantR0WithType(&builder, xla::F32, 0.5);
    ComputeAndCompareR0<float>(&builder, 0.5, {});
  }
}
52 
// ScalarLike on an int32 prototype (value 42) must yield an int32 scalar that
// holds the requested value (-3), not the prototype's value.
XLA_TEST_F(ConstantsTest, ScalarLikeS32) {
  XlaBuilder b(TestName());
  XlaOp prototype = ConstantR0<int32>(&b, 42);
  ScalarLike(prototype, -3);
  ComputeAndCompareR0<int32>(&b, /*expected=*/-3, /*arguments=*/{});
}
58 
// ScalarLike on a float prototype (value 42.75) must yield a float scalar that
// holds the requested value (-3.2), not the prototype's value.
XLA_TEST_F(ConstantsTest, ScalarLikeF32) {
  XlaBuilder b(TestName());
  XlaOp prototype = ConstantR0<float>(&b, 42.75);
  ScalarLike(prototype, -3.2);
  ComputeAndCompareR0<float>(&b, /*expected=*/-3.2, /*arguments=*/{});
}
64 
// Zero(S32) must evaluate to the int32 scalar 0.
XLA_TEST_F(ConstantsTest, ZeroS32) {
  XlaBuilder b(TestName());
  Zero(&b, S32);
  ComputeAndCompareR0<int32>(&b, /*expected=*/0, /*arguments=*/{});
}
70 
// Zero(F32) must evaluate to the float scalar 0.0.
XLA_TEST_F(ConstantsTest, ZeroF32) {
  XlaBuilder b(TestName());
  Zero(&b, F32);
  ComputeAndCompareR0<float>(&b, /*expected=*/0.0, /*arguments=*/{});
}
76 
// Zeros for a 2x2 S32 shape must produce a 2x2 int32 array filled with 0.
XLA_TEST_F(ConstantsTest, ZerosS32) {
  XlaBuilder b(TestName());
  const Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
  Zeros(&b, shape);
  ComputeAndCompareR2<int32>(&b, {{0, 0}, {0, 0}}, /*arguments=*/{});
}
82 
// ZerosLike on a 3-element float vector must produce a 3-element float vector
// of zeros, regardless of the input's values.
XLA_TEST_F(ConstantsTest, ZerosLikeF32) {
  XlaBuilder b(TestName());
  XlaOp input = ConstantR1<float>(&b, {1., 2., 3.});
  ZerosLike(input);
  ComputeAndCompareR1<float>(&b, {0., 0., 0.}, /*arguments=*/{});
}
88 
// One(S32) must evaluate to the int32 scalar 1.
XLA_TEST_F(ConstantsTest, OneS32) {
  XlaBuilder b(TestName());
  One(&b, S32);
  ComputeAndCompareR0<int32>(&b, /*expected=*/1, /*arguments=*/{});
}
94 
// One(F32) must evaluate to the float scalar 1.0.
XLA_TEST_F(ConstantsTest, OneF32) {
  XlaBuilder b(TestName());
  One(&b, F32);
  ComputeAndCompareR0<float>(&b, /*expected=*/1., /*arguments=*/{});
}
100 
// Epsilon(F32) must match the C++ machine epsilon for float.
XLA_TEST_F(ConstantsTest, EpsilonF32) {
  XlaBuilder b(TestName());
  Epsilon(&b, F32);
  const float expected = std::numeric_limits<float>::epsilon();
  ComputeAndCompareR0<float>(&b, expected, /*arguments=*/{});
}
107 
// MinFiniteValue(S32) must be the most negative representable int32.
XLA_TEST_F(ConstantsTest, MinFiniteValueS32) {
  XlaBuilder b(TestName());
  MinFiniteValue(&b, S32);
  // For integral types lowest() and min() are the same value.
  const int32 expected = std::numeric_limits<int32>::lowest();
  ComputeAndCompareR0<int32>(&b, expected, /*arguments=*/{});
}
113 
// MaxFiniteValue(S32) must be the largest representable int32.
XLA_TEST_F(ConstantsTest, MaxFiniteValueS32) {
  XlaBuilder b(TestName());
  MaxFiniteValue(&b, S32);
  const int32 expected = std::numeric_limits<int32>::max();
  ComputeAndCompareR0<int32>(&b, expected, /*arguments=*/{});
}
119 
// MinFiniteValue(F32) must be the most negative *finite* float (not -inf).
XLA_TEST_F(ConstantsTest, MinFiniteValueF32) {
  XlaBuilder b(TestName());
  MinFiniteValue(&b, F32);
  // For IEEE float, lowest() is exactly -max().
  const float expected = std::numeric_limits<float>::lowest();
  ComputeAndCompareR0<float>(&b, expected, /*arguments=*/{});
}
125 
// MaxFiniteValue(F32) must be the largest *finite* float (not +inf).
XLA_TEST_F(ConstantsTest, MaxFiniteValueF32) {
  XlaBuilder b(TestName());
  MaxFiniteValue(&b, F32);
  const float expected = std::numeric_limits<float>::max();
  ComputeAndCompareR0<float>(&b, expected, /*arguments=*/{});
}
131 
// MinValue(S32) must be the most negative int32. Integers have no infinity, so
// this coincides with the minimum finite value.
XLA_TEST_F(ConstantsTest, MinValueS32) {
  XlaBuilder b(TestName());
  MinValue(&b, S32);
  const int32 expected = std::numeric_limits<int32>::lowest();
  ComputeAndCompareR0<int32>(&b, expected, /*arguments=*/{});
}
137 
// MaxValue(S32) must be the largest int32. Integers have no infinity, so this
// coincides with the maximum finite value.
XLA_TEST_F(ConstantsTest, MaxValueS32) {
  XlaBuilder b(TestName());
  MaxValue(&b, S32);
  const int32 expected = std::numeric_limits<int32>::max();
  ComputeAndCompareR0<int32>(&b, expected, /*arguments=*/{});
}
143 
// MinValue(F32) must be negative infinity, unlike MinFiniteValue(F32).
XLA_TEST_F(ConstantsTest, MinValueF32) {
  XlaBuilder b(TestName());
  MinValue(&b, F32);
  const float expected = -std::numeric_limits<float>::infinity();
  ComputeAndCompareR0<float>(&b, expected, /*arguments=*/{});
}
150 
// MaxValue(F32) must be positive infinity, unlike MaxFiniteValue(F32).
XLA_TEST_F(ConstantsTest, MaxValueF32) {
  XlaBuilder b(TestName());
  MaxValue(&b, F32);
  const float expected = std::numeric_limits<float>::infinity();
  ComputeAndCompareR0<float>(&b, expected, /*arguments=*/{});
}
157 
// NanValue(F32) must compare equal to the C++ quiet NaN for float under the
// test base's exact literal comparison.
XLA_TEST_F(ConstantsTest, NanValueF32) {
  XlaBuilder b(TestName());
  NanValue(&b, F32);
  const float expected = std::numeric_limits<float>::quiet_NaN();
  ComputeAndCompareR0<float>(&b, expected, /*arguments=*/{});
}
164 
165 }  // namespace
166 }  // namespace xla
167