/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc
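//
// CompareAndBitpack compares each element of `input` to the scalar
// `threshold` and bitpacks the resulting booleans, eight per output byte,
// with the first element of each group of eight in the most significant
// bit. The inner dimension of the output is therefore the input's inner
// dimension divided by 8.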

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/compare_and_bitpack_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class CompareAndBitpackOp : public OpKernel {
 public:
  explicit CompareAndBitpackOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input_t = c->input(0);
    const Tensor& threshold_t = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsScalar(threshold_t.shape()),
        errors::InvalidArgument("Threshold must be a scalar, but saw shape: ",
                                threshold_t.shape().DebugString()));
    const TensorShape& input_shape = input_t.shape();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_shape),
                errors::InvalidArgument(
                    "Input should be at least a vector, but saw a scalar."));
    OP_REQUIRES(c, input_shape.dim_size(input_shape.dims() - 1) % 8 == 0,
                errors::InvalidArgument(
                    "Inner dimension of input should be "
                    "divisible by ",
                    8, ", but saw shape: ", input_shape.DebugString()));

    TensorShape output_shape = input_shape;
    int rank = input_shape.dims();
    output_shape.set_dim(rank - 1, input_shape.dim_size(rank - 1) / 8);

    Tensor* output_t;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output_t));

    auto input = input_t.flat_inner_dims<T>();
    auto threshold = threshold_t.scalar<T>();
    auto output = output_t->flat_inner_dims<uint8>();

    functor::CompareAndBitpack<Device, T> func;
    func(c, input, threshold, output);
  }
};
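
// Example (illustrative): a float input of shape [4, 16] compared against a
// scalar threshold yields a uint8 output of shape [4, 2]; each output byte
// packs the comparison results of eight consecutive elements of the inner
// dimension, first element in the most significant bit.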

#define REGISTER_COMPARE_AND_BITPACK(type)                                    \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("CompareAndBitpack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      CompareAndBitpackOp<CPUDevice, type>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_COMPARE_AND_BITPACK);
TF_CALL_bool(REGISTER_COMPARE_AND_BITPACK);

#undef REGISTER_COMPARE_AND_BITPACK

namespace functor {

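// Generic per-shard packing: output byte i covers input elements
// [8 * i, 8 * i + 8), with element 8 * i in bit 7 and element 8 * i + 7 in
// bit 0. For example, comparison results [0, 1, 1, 0, 1, 0, 1, 0] pack into
// the byte 0b01101010 (0x6a). The trailing `class = void` parameters exist
// so the bool specialization below can be selected via SFINAE.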
template <typename T, class = void, class = void>
struct ComputeShard {
  static EIGEN_STRONG_INLINE void Compute(typename TTypes<T>::ConstMatrix input,
                                          typename TTypes<uint8>::Matrix output,
                                          const T& thresh, int64 start,
                                          int64 limit) {
    for (int64 i = start; i < limit; ++i) {
      uint8* out = output.data() + i;
      const T* block = input.data() + 8 * i;
      *out = ((((block[0] > thresh) << 7)) | (((block[1] > thresh) << 6)) |
              (((block[2] > thresh) << 5)) | (((block[3] > thresh) << 4)) |
              (((block[4] > thresh) << 3)) | (((block[5] > thresh) << 2)) |
              (((block[6] > thresh) << 1)) | (((block[7] > thresh))));
    }
  }
};

// Specialization for bool on systems where sizeof(bool) == 1: the eight
// bools of a block are read as a single int64, and each byte's 0/1 value is
// masked out and shifted into its target bit (memory byte j of the block
// lands in bit 7 - j of the output byte).
template <typename T>
struct ComputeShard<T,
                    typename std::enable_if<std::is_same<T, bool>::value>::type,
                    typename std::enable_if<sizeof(T) == 1>::type> {
  static EIGEN_STRONG_INLINE void Compute(
      typename TTypes<bool>::ConstMatrix input,
      typename TTypes<uint8>::Matrix output, bool /*thresh*/, int64 start,
      int64 limit) {
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    for (int64 i = start; i < limit; ++i) {
      uint8* out = output.data() + i;
      const int64 block = *reinterpret_cast<const int64*>(input.data() + 8 * i);
      *out = ((((block & (1LL << (7 * 8))) >> (7 * 8 - 7))) |
              (((block & (1LL << (6 * 8))) >> (6 * 8 - 6))) |
              (((block & (1LL << (5 * 8))) >> (5 * 8 - 5))) |
              (((block & (1LL << (4 * 8))) >> (4 * 8 - 4))) |
              (((block & (1LL << (3 * 8))) >> (3 * 8 - 3))) |
              (((block & (1LL << (2 * 8))) >> (2 * 8 - 2))) |
              (((block & (1LL << 8)) >> (1 * 8 - 1))) | (((block & (1LL)))));
    }
#else
    for (int64 i = start; i < limit; ++i) {
      uint8* out = output.data() + i;
      const int64 block = *reinterpret_cast<const int64*>(input.data() + 8 * i);
      *out =
          ((((block & (1LL << (7 * 8))) >> (7 * 8 - 0))) |
           (((block & (1LL << (6 * 8))) >> (6 * 8 - 1))) |
           (((block & (1LL << (5 * 8))) >> (5 * 8 - 2))) |
           (((block & (1LL << (4 * 8))) >> (4 * 8 - 3))) |
           (((block & (1LL << (3 * 8))) >> (3 * 8 - 4))) |
           (((block & (1LL << (2 * 8))) >> (2 * 8 - 5))) |
           (((block & (1LL << 8)) >> (1 * 8 - 6))) | (((block & (1LL)) << 7)));
    }
#endif
  }
};
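
// Sanity check for the little-endian branch above: for the block of bools
// {1, 0, 0, 0, 0, 0, 0, 1} the int64 read is (1LL << 56) | 1LL; bit 0 is
// shifted left by 7 into bit 7, bit 56 is shifted right by 56 into bit 0,
// and the output byte is 0b10000001 (0x81).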

template <typename T>
struct CompareAndBitpack<CPUDevice, T> {
  void operator()(OpKernelContext* c, typename TTypes<T>::ConstMatrix input,
                  typename TTypes<T>::ConstScalar threshold,
                  TTypes<uint8>::Matrix output) {
    const T thresh = threshold();
    auto shard = [&, thresh](int64 start, int64 limit) {
      ComputeShard<T>::Compute(input, output, thresh, start, limit);
    };
    // One shard unit covers one output byte (eight input elements).
    int64 total_shards = output.size();
    // Approximate each comparison as an add and each bitwise-or + shift as
    // an add, giving eight of each per output byte.
    const double total_cost = 8 * (Eigen::TensorOpCost::AddCost<T>() +
                                   Eigen::TensorOpCost::AddCost<uint8>());
    const int64 shard_cost = (total_cost >= static_cast<double>(kint64max))
                                 ? kint64max
                                 : static_cast<int64>(total_cost);

    auto worker_threads = *(c->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
          shard_cost, shard);
  }
};

}  // namespace functor

#if GOOGLE_CUDA

#define REGISTER_COMPARE_AND_BITPACK(type)                                    \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("CompareAndBitpack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      CompareAndBitpackOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_COMPARE_AND_BITPACK);
TF_CALL_bool(REGISTER_COMPARE_AND_BITPACK);

#undef REGISTER_COMPARE_AND_BITPACK

namespace functor {

#define DECLARE_GPU_SPEC(T)                                      \
  template <>                                                    \
  void CompareAndBitpack<GPUDevice, T>::operator()(              \
      OpKernelContext* c, typename TTypes<T>::ConstMatrix input, \
      typename TTypes<T>::ConstScalar threshold,                 \
      TTypes<uint8>::Matrix output);                             \
  extern template struct CompareAndBitpack<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC)
TF_CALL_bool(DECLARE_GPU_SPEC)

#undef DECLARE_GPU_SPEC

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow