/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.
#define _USE_MATH_DEFINES
#include <cmath>

#define EIGEN_USE_THREADS

#include "tensorflow/core/platform/bfloat16.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // The caller must check ctx->status() upon return for a non-OK status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    explicit BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
    bool result;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
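//
// For illustration only (a hypothetical example; the actual registrations
// live in the sharded cwise_op_*.cc files and may cover different types),
// such a kernel is typically registered along the lines of:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Add").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       BinaryOp<CPUDevice, functor::add<float>>);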
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_0 = ctx->input(0);
    const Tensor& input_1 = ctx->input(1);
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;

    // NOTE: Handle three simple cases before building the BinaryOpState, which
    // is relatively expensive for small operations.
    if (input_0.shape() == input_1.shape()) {
      // tensor op tensor with no broadcasting.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>()(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_0.shape().dims() == 0) {
      // scalar op tensor.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {1}, 0, input_1.shape(), &out));

      functor::BinaryFunctor<Device, Functor, 1>().Left(
          eigen_device, out->template flat<Tout>(),
          input_0.template scalar<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_1.shape().dims() == 0) {
      // tensor op scalar.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>().Right(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template scalar<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    }

    // 'state': Shared helper not dependent on T to reduce code size.
    BinaryOpState state(ctx);
    if (ctx->status().code() == error::RESOURCE_EXHAUSTED) {
      // Stop when BinaryOpState's constructor failed due to OOM.
      return;
    }
    auto& bcast = state.bcast;
    Tensor* out = state.out;
    if (!bcast.IsValid()) {
      if (ctx->status().ok()) {
        if (state.result) {
          functor::SetOneFunctor<Device, bool>()(eigen_device,
                                                 out->flat<bool>());
        } else {
          functor::SetZeroFunctor<Device, bool>()(eigen_device,
                                                  out->flat<bool>());
        }
      }
      return;
    }

    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }

    const int ndims = state.ndims;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
          in0.template shaped<Tin, 2>(bcast.x_reshape()),
          BCast::ToIndexArray<2>(bcast.x_bcast()),
          in1.template shaped<Tin, 2>(bcast.y_reshape()),
          BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
          in0.template shaped<Tin, 3>(bcast.x_reshape()),
          BCast::ToIndexArray<3>(bcast.x_bcast()),
          in1.template shaped<Tin, 3>(bcast.y_reshape()),
          BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
          in0.template shaped<Tin, 4>(bcast.x_reshape()),
          BCast::ToIndexArray<4>(bcast.x_bcast()),
          in1.template shaped<Tin, 4>(bcast.y_reshape()),
          BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
          in0.template shaped<Tin, 5>(bcast.x_reshape()),
          BCast::ToIndexArray<5>(bcast.x_bcast()),
          in1.template shaped<Tin, 5>(bcast.y_reshape()),
          BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known not to require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
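//
// For illustration only (hypothetical; the real registration lives in the
// corresponding cwise_op_*.cc file), the "TanhGrad" op would typically be
// backed by a kernel of the form
// SimpleBinaryOp<CPUDevice, functor::tanh_grad<float>>.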
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
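//
// For illustration only (the REGISTER* macros used below are defined near the
// bottom of this header; actual registrations live in the cwise_op_*.cc files
// and may cover different types):
//
//   REGISTER(UnaryOp, CPU, "Sqrt", functor::sqrt, float);
//
// expands to a REGISTER_KERNEL_BUILDER call that registers
// UnaryOp<CPUDevice, functor::sqrt<float>> for the "Sqrt" op on CPU.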
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = ctx->device()->NumaNode();
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it's a no-op.
      //
      // Here, we need to handle 6 cases depending on how many 1s appear in
      // in0's and in1's shapes (4 dimension sizes in total). It's not
      // possible for the two shapes to have more than two 1s in total,
      // because those cases are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination, we
      // only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization and
      // use_bcast_optimization<T> are compile-time constants, gcc does a
      // decent job of avoiding generating code when the conditions are not
      // met.
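      //
      // For example, in0 of shape [1, 5] and in1 of shape [3, 1] fall into
      // the (a == 1 && d == 1) case below: both operands are reshaped via
      // NByOne/OneByM and broadcast to the common shape [3, 5].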
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
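//
// For example (illustrative only; the real registrations and type lists live
// in the sharded cwise_op_*.cc files):
//
//   REGISTER3(BinaryOp, CPU, "Sub", functor::sub, float, double, int32);
//
// expands into one REGISTER(...) invocation, and hence one
// REGISTER_KERNEL_BUILDER call, per listed type (only T0 under
// __ANDROID_TYPES_SLIM__).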

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_