1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/constants.h"
17 
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/util.h"
20 
21 namespace xla {
22 
Zero(XlaBuilder * builder,PrimitiveType type)23 XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
24   return ConstantLiteral(builder, LiteralUtil::Zero(type));
25 }
26 
Zeros(XlaBuilder * builder,const Shape & shape)27 XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
28   return Broadcast(Zero(builder, shape.element_type()),
29                    AsInt64Slice(shape.dimensions()));
30 }
31 
ZerosLike(XlaOp prototype)32 XlaOp ZerosLike(XlaOp prototype) {
33   XlaBuilder* builder = prototype.builder();
34   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
35     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
36     return Zeros(builder, shape);
37   });
38 }
39 
One(XlaBuilder * builder,PrimitiveType type)40 XlaOp One(XlaBuilder* builder, PrimitiveType type) {
41   return ConstantLiteral(builder, LiteralUtil::One(type));
42 }
43 
Epsilon(XlaBuilder * builder,PrimitiveType type)44 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
45   switch (type) {
46     case F16:
47       return ConstantR0<Eigen::half>(
48           builder,
49           static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
50     case BF16:
51       return ConstantR0<Eigen::bfloat16>(
52           builder, static_cast<Eigen::bfloat16>(
53                        Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
54     case F32:
55       return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
56     case F64:
57       return ConstantR0<double>(builder,
58                                 std::numeric_limits<double>::epsilon());
59     default:
60       return builder->ReportError(InvalidArgument(
61           "Invalid type for Epsilon (%s).", PrimitiveType_Name(type)));
62   }
63 }
64 
MinValue(XlaBuilder * builder,PrimitiveType type)65 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
66   return ConstantLiteral(builder, LiteralUtil::MinValue(type));
67 }
68 
MinFiniteValue(XlaBuilder * builder,PrimitiveType type)69 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
70   switch (type) {
71     case F16:
72       return ConstantR0<Eigen::half>(builder,
73                                      Eigen::NumTraits<Eigen::half>::lowest());
74     case BF16:
75       return ConstantR0<Eigen::bfloat16>(
76           builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
77     case F32:
78       return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
79     case F64:
80       return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
81     default:
82       return MinValue(builder, type);
83   }
84 }
85 
MinPositiveNormalValue(XlaBuilder * builder,PrimitiveType type)86 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
87   switch (type) {
88     case F16:
89       return ConstantR0<Eigen::half>(builder,
90                                      std::numeric_limits<Eigen::half>::min());
91     case BF16:
92       return ConstantR0<Eigen::bfloat16>(
93           builder, std::numeric_limits<Eigen::bfloat16>::min());
94     case F32:
95       return ConstantR0<float>(builder, std::numeric_limits<float>::min());
96     case F64:
97       return ConstantR0<double>(builder, std::numeric_limits<double>::min());
98     default:
99       return builder->ReportError(
100           InvalidArgument("Invalid type for MinPositiveNormalValue (%s).",
101                           PrimitiveType_Name(type)));
102   }
103 }
104 
MaxValue(XlaBuilder * builder,PrimitiveType type)105 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
106   return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
107 }
108 
MaxFiniteValue(XlaBuilder * builder,PrimitiveType type)109 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
110   switch (type) {
111     case F16:
112       return ConstantR0<Eigen::half>(builder,
113                                      Eigen::NumTraits<Eigen::half>::highest());
114     case BF16:
115       return ConstantR0<Eigen::bfloat16>(
116           builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
117     case F32:
118       return ConstantR0<float>(builder, std::numeric_limits<float>::max());
119     case F64:
120       return ConstantR0<double>(builder, std::numeric_limits<double>::max());
121     default:
122       return MaxValue(builder, type);
123   }
124 }
125 
NanValue(XlaBuilder * builder,PrimitiveType type)126 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
127   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
128     switch (type) {
129       case F16:
130         return ConstantR0<Eigen::half>(
131             builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
132       case BF16:
133         return ConstantR0<Eigen::bfloat16>(
134             builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
135       case F32:
136         return ConstantR0<float>(builder,
137                                  std::numeric_limits<float>::quiet_NaN());
138       case F64:
139         return ConstantR0<double>(builder,
140                                   std::numeric_limits<double>::quiet_NaN());
141       default:
142         return InvalidArgument(
143             "Operand to NanValue was %s, but must be a real-valued "
144             "floating-point type.",
145             PrimitiveType_Name(type));
146     }
147   });
148 }
149 
150 }  // namespace xla
151