1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/constants.h"
17 
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/util.h"
20 
21 namespace xla {
22 
Zero(XlaBuilder * builder,PrimitiveType type)23 XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
24   return ConstantLiteral(builder, LiteralUtil::Zero(type));
25 }
26 
Zeros(XlaBuilder * builder,const Shape & shape)27 XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
28   return Broadcast(Zero(builder, shape.element_type()),
29                    AsInt64Slice(shape.dimensions()));
30 }
31 
ZerosLike(XlaOp prototype)32 XlaOp ZerosLike(XlaOp prototype) {
33   XlaBuilder* builder = prototype.builder();
34   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
35     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
36     return Zeros(builder, shape);
37   });
38 }
39 
One(XlaBuilder * builder,PrimitiveType type)40 XlaOp One(XlaBuilder* builder, PrimitiveType type) {
41   return ConstantLiteral(builder, LiteralUtil::One(type));
42 }
43 
Epsilon(XlaBuilder * builder,PrimitiveType type)44 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
45   switch (type) {
46     case F16:
47       return ConstantR0<Eigen::half>(
48           builder,
49           static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
50     case BF16:
51       return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
52     case F32:
53       return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
54     case F64:
55       return ConstantR0<double>(builder,
56                                 std::numeric_limits<double>::epsilon());
57     default:
58       return builder->ReportError(InvalidArgument(
59           "Invalid type for Epsilon (%s).", PrimitiveType_Name(type)));
60   }
61 }
62 
MinValue(XlaBuilder * builder,PrimitiveType type)63 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
64   return ConstantLiteral(builder, LiteralUtil::MinValue(type));
65 }
66 
MinFiniteValue(XlaBuilder * builder,PrimitiveType type)67 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
68   switch (type) {
69     case F16:
70       return ConstantR0<Eigen::half>(builder,
71                                      Eigen::NumTraits<Eigen::half>::lowest());
72     case BF16:
73       return ConstantR0<bfloat16>(builder, bfloat16::lowest());
74     case F32:
75       return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
76     case F64:
77       return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
78     default:
79       return MinValue(builder, type);
80   }
81 }
82 
MinPositiveNormalValue(XlaBuilder * builder,PrimitiveType type)83 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
84   switch (type) {
85     case F16:
86       return ConstantR0<Eigen::half>(builder,
87                                      std::numeric_limits<Eigen::half>::min());
88     case BF16:
89       return ConstantR0<bfloat16>(builder, bfloat16::min_positive_normal());
90     case F32:
91       return ConstantR0<float>(builder, std::numeric_limits<float>::min());
92     case F64:
93       return ConstantR0<double>(builder, std::numeric_limits<double>::min());
94     default:
95       return builder->ReportError(
96           InvalidArgument("Invalid type for MinPositiveNormalValue (%s).",
97                           PrimitiveType_Name(type)));
98   }
99 }
100 
MaxValue(XlaBuilder * builder,PrimitiveType type)101 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
102   return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
103 }
104 
MaxFiniteValue(XlaBuilder * builder,PrimitiveType type)105 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
106   switch (type) {
107     case F16:
108       return ConstantR0<Eigen::half>(builder,
109                                      Eigen::NumTraits<Eigen::half>::highest());
110     case BF16:
111       return ConstantR0<bfloat16>(builder, bfloat16::highest());
112     case F32:
113       return ConstantR0<float>(builder, std::numeric_limits<float>::max());
114     case F64:
115       return ConstantR0<double>(builder, std::numeric_limits<double>::max());
116     default:
117       return MaxValue(builder, type);
118   }
119 }
120 
NanValue(XlaBuilder * builder,PrimitiveType type)121 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
122   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
123     switch (type) {
124       case F16:
125         return ConstantR0<Eigen::half>(
126             builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
127       case BF16:
128         return ConstantR0<bfloat16>(
129             builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
130       case F32:
131         return ConstantR0<float>(builder,
132                                  std::numeric_limits<float>::quiet_NaN());
133       case F64:
134         return ConstantR0<double>(builder,
135                                   std::numeric_limits<double>::quiet_NaN());
136       default:
137         return InvalidArgument(
138             "Operand to NanValue was %s, but must be a real-valued "
139             "floating-point type.",
140             PrimitiveType_Name(type));
141     }
142   });
143 }
144 
145 }  // namespace xla
146