/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace {

// Gymnastics with nudged zero point is to ensure that the real zero maps to
// an integer, which is required for e.g. zero-padding in convolutional layers.
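//
// For example (an illustrative setting, not taken from the original source):
// with num_bits=8 and narrow_range=false (quant_min=0, quant_max=255),
// min=-1.0f and max=1.0f give scale = 2/255 and zero_point_from_min = 127.5,
// which rounds to 128; the nudged range becomes roughly [-1.00392, 0.99608],
// so the real value 0.0f maps exactly to the integer 128.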
void CpuNudge(const float min, const float max, const float quant_min,
              const float quant_max, float* nudged_min, float* nudged_max,
              float* scale) {
  *scale = (max - min) / (quant_max - quant_min);

  const float zero_point_from_min = quant_min - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min <= quant_min) {
    nudged_zero_point = quant_min;
  } else if (zero_point_from_min >= quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }

  *nudged_min = (quant_min - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max - nudged_zero_point) * (*scale);
}

// An XLA version of CpuNudge().
void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
              const xla::ComputationDataHandle& min,
              const xla::ComputationDataHandle& max,
              const float quant_min_value, const float quant_max_value,
              xla::ComputationDataHandle* nudged_min,
              xla::ComputationDataHandle* nudged_max,
              xla::ComputationDataHandle* scale) {
  *scale = b->Div(b->Sub(max, min),
                  XlaHelpers::FloatLiteral(b, data_type,
                                           quant_max_value - quant_min_value));
  xla::ComputationDataHandle quant_min =
      XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
  xla::ComputationDataHandle zero_point_from_min =
      b->Sub(quant_min, b->Div(min, *scale));
  xla::ComputationDataHandle quant_max =
      XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
  xla::ComputationDataHandle nudged_zero_point =
      b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
                b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
                          b->Round(zero_point_from_min)));
  *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale);
  *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
}

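// Builds the fake-quantization graph for `input`: clamp to the nudged range,
// round to the nearest quantized step, and map back to float. In formula
// form, this computes
//   floor((clamp(x) - nudged_min) / scale + 0.5) * scale + nudged_min.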
xla::ComputationDataHandle Quantize(
    xla::ComputationBuilder* b, const xla::ComputationDataHandle& input,
    const DataType data_type,
    const xla::ComputationDataHandle& nudged_input_min,
    const xla::ComputationDataHandle& nudged_input_max,
    const xla::ComputationDataHandle& input_scale) {
  xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
  xla::ComputationDataHandle inv_scale = b->Div(one, input_scale);
  xla::ComputationDataHandle half =
      XlaHelpers::FloatLiteral(b, data_type, 0.5f);

  xla::ComputationDataHandle clamped =
      b->Clamp(nudged_input_min, input, nudged_input_max);
  xla::ComputationDataHandle clamped_shifted =
      b->Sub(clamped, nudged_input_min);
  xla::ComputationDataHandle rounded =
      b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
  return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
}

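// Kernel for FakeQuantWithMinMaxArgs, whose min/max are compile-time
// attributes; the nudged range can therefore be precomputed on the host with
// CpuNudge() and embedded as literals.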
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    // narrow_range drops the lowest quantized value, so e.g. 8 bits quantize
    // to [1, 255] instead of [0, 255].
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;

    float input_min, input_max;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
             &nudged_input_max_, &input_scale_);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::ComputationDataHandle nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
    xla::ComputationDataHandle input_scale =
        XlaHelpers::FloatLiteral(b, data_type, input_scale_);
    xla::ComputationDataHandle output = Quantize(
        b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  float quant_min_;
  float quant_max_;
  float nudged_input_min_;
  float nudged_input_max_;
  float input_scale_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);

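// Gradient of FakeQuantWithMinMaxArgs: a straight-through estimator that
// passes the incoming gradient wherever the input lies inside the nudged
// range and zeroes it elsewhere.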
class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    const float quant_min = narrow_range ? 1 : 0;
    const float quant_max = (1 << num_bits) - 1;

    float input_min, input_max, scale;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
             &nudged_input_max_, &scale);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::ComputationDataHandle input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::ComputationDataHandle nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);

    xla::ComputationDataHandle between_nudged_min_max =
        b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
    xla::ComputationDataHandle zeroes = b->Broadcast(
        XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes());
    xla::ComputationDataHandle output =
        b->Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output);
  }

 private:
  float nudged_input_min_;
  float nudged_input_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
                FakeQuantWithMinMaxArgsGradOp);

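// Kernel for FakeQuantWithMinMaxVars, whose min/max arrive as runtime
// tensors; the nudging is therefore emitted into the XLA computation with
// XlaNudge().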
class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);
    xla::ComputationDataHandle input_min = ctx->Input(1);
    xla::ComputationDataHandle input_max = ctx->Input(2);

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

    xla::ComputationDataHandle output = Quantize(
        b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);

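// Gradient of FakeQuantWithMinMaxVars. Produces three outputs: the gradient
// w.r.t. the input (pass-through inside the nudged range), and the gradients
// w.r.t. min and max (the incoming gradient summed over the elements clamped
// below and above the range, respectively).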
class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::ComputationDataHandle input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);
    xla::ComputationDataHandle input_min = ctx->Input(2);
    xla::ComputationDataHandle input_max = ctx->Input(3);

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

    // Gradient w.r.t. the input: pass-through inside the nudged range.
    xla::ComputationDataHandle between_nudged_min_max =
        b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
    xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type);
    xla::ComputationDataHandle zeroes =
        b->Broadcast(zero, gradient_shape.dim_sizes());
    xla::ComputationDataHandle output0 =
        b->Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output0);

    // Gradient w.r.t. min: sum of the incoming gradient over the elements
    // that were clamped from below.
    xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min);
    xla::ComputationDataHandle output1 =
        b->ReduceAll(b->Select(below_min, gradient, zeroes), zero,
                     *ctx->GetOrCreateAdd(data_type));
    ctx->SetOutput(1, output1);

    // Gradient w.r.t. max: sum of the incoming gradient over the elements
    // that were clamped from above.
    xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max);
    xla::ComputationDataHandle output2 =
        b->ReduceAll(b->Select(above_max, gradient, zeroes), zero,
                     *ctx->GetOrCreateAdd(data_type));
    ctx->SetOutput(2, output2);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
                FakeQuantWithMinMaxVarsGradOp);

}  // namespace
}  // namespace tensorflow