/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.
#define _USE_MATH_DEFINES
#include <cmath>

#define EIGEN_USE_THREADS

#include "tensorflow/core/platform/bfloat16.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Non-templated base class shared by the coefficient-wise binary kernels
// below. Keeping the shape/broadcast setup and error reporting here, outside
// the templated BinaryOp, reduces per-instantiation code size.
class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // Caller must check ctx->status() upon return for non-ok status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    explicit BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
    bool result;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};
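
// Typical use of BinaryOpState by a subclass (a sketch only; see
// BinaryOp::Compute below for the real pattern):
//
//   BinaryOpState state(ctx);
//   if (!ctx->status().ok()) return;  // Setup or allocation failed.
//   // state.out is allocated; state.bcast describes the broadcast.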

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_0 = ctx->input(0);
    const Tensor& input_1 = ctx->input(1);
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;

    // NOTE: Handle three simple cases before building the BinaryOpState, which
    // is relatively expensive for small operations.
    if (input_0.shape() == input_1.shape()) {
      // tensor op tensor with no broadcasting.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>()(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_0.shape().dims() == 0) {
      // scalar op tensor.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {1}, 0, input_1.shape(), &out));

      functor::BinaryFunctor<Device, Functor, 1>().Left(
          eigen_device, out->template flat<Tout>(),
          input_0.template scalar<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_1.shape().dims() == 0) {
      // tensor op scalar.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>().Right(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template scalar<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    }

    // 'state': Shared helper not dependent on T to reduce code size
    BinaryOpState state(ctx);
    if (ctx->status().code() == error::RESOURCE_EXHAUSTED) {
      // Stop when BinaryOpState's constructor failed due to OOM.
      return;
    }
    auto& bcast = state.bcast;
    Tensor* out = state.out;
    if (!bcast.IsValid()) {
      if (ctx->status().ok()) {
        if (state.result) {
          functor::SetOneFunctor<Device, bool>()(eigen_device,
                                                 out->flat<bool>());
        } else {
          functor::SetZeroFunctor<Device, bool>()(eigen_device,
                                                  out->flat<bool>());
        }
      }
      return;
    }

    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }

    const int ndims = state.ndims;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
          in0.template shaped<Tin, 2>(bcast.x_reshape()),
          BCast::ToIndexArray<2>(bcast.x_bcast()),
          in1.template shaped<Tin, 2>(bcast.y_reshape()),
          BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
          in0.template shaped<Tin, 3>(bcast.x_reshape()),
          BCast::ToIndexArray<3>(bcast.x_bcast()),
          in1.template shaped<Tin, 3>(bcast.y_reshape()),
          BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
          in0.template shaped<Tin, 4>(bcast.x_reshape()),
          BCast::ToIndexArray<4>(bcast.x_bcast()),
          in1.template shaped<Tin, 4>(bcast.y_reshape()),
          BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
          in0.template shaped<Tin, 5>(bcast.x_reshape()),
          BCast::ToIndexArray<5>(bcast.x_bcast()),
          in1.template shaped<Tin, 5>(bcast.y_reshape()),
          BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};
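
// Concrete binary kernels are obtained by instantiating BinaryOp with a
// functor from cwise_ops.h and registering the result with the REGISTER*
// macros defined at the end of this header. Illustrative only (actual
// registrations live in the cwise_op_*.cc files):
//
//   REGISTER2(BinaryOp, CPU, "Add", functor::add, float, double);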

// Computes z[i] = |x[i] - y[i]| <= tolerance element-wise; x and y must have
// the same shape.
template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};
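
// Illustrative registration only. ApproximateEqualOp is parameterized on T
// directly rather than on a functor, so the REGISTER* macros below do not
// apply to it; a registration would look roughly like:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       ApproximateEqualOp<CPUDevice, float>);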

// Basic coefficient-wise binary operations that are known to not require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};
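
// Illustrative only: a unary-op gradient kernel would typically be registered
// as something like
//
//   REGISTER(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float);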

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};
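
// Illustrative only: a unary kernel would typically be registered as
//
//   REGISTER(UnaryOp, CPU, "Sqrt", functor::sqrt, float);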

// Applies the variant unary op identified by OpEnum to a scalar DT_VARIANT
// input, producing a scalar DT_VARIANT output.
template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = ctx->device()->NumaNode();
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};
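
// Illustrative only, assuming a VariantUnaryOp enum value such as
// ZEROS_LIKE_VARIANT_UNARY_OP from variant_op_registry.h:
//
//   REGISTER_VARIANT(UnaryVariantOp, CPU, "ZerosLike",
//                    ZEROS_LIKE_VARIANT_UNARY_OP);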

namespace functor {

// Assigns rhs to out on device d. Shared by the functor specializations
// below.
template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it is a no-op.
      //
      // Here, we need to handle 6 cases depending on how many "1"s appear
      // in in0's and in1's shapes (4 numbers in total). The two shapes
      // cannot contain more than two 1s between them, because such cases
      // are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_broadcast_optimization, and
      // use_broadcast_optimization<T> are compile-time constants, gcc
      // does a decent job of avoiding code generation when the conditions
      // are not met.
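      //
      // For example, in0 of shape [1, b] with in1 of shape [c, 1] hits the
      // (a == 1) && (d == 1) case below: both operands are reshaped and
      // broadcast to [c, b] as lazy Eigen expressions, without materializing
      // intermediate tensors.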
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
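//
// For example, REGISTER3(OP, D, N, F, T0, T1, T2) expands to one
// REGISTER_KERNEL_BUILDER call per listed type (via REGISTER2 and REGISTER
// below); under __ANDROID_TYPES_SLIM__ only T0 is registered.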

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_