/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {
using errors::InvalidArgument;

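// Kernel for the "PmfToQuantizedCdf" op: converts a probability mass function
// given along the last axis of the input into a quantized CDF usable for
// range coding. For a PMF of length n, the corresponding output row has
// length n + 1 with cdf[0] == 0 and cdf[n] == 2^precision, and every step
// cdf[k + 1] - cdf[k] is at least 1, so no symbol is assigned zero mass.
// When the rounded masses do not sum to 2^precision exactly, the slack is
// absorbed greedily: entries are decremented where that costs the least
// expected code length, or incremented where that saves the most (see
// PerShard below).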
class PmfToCdfOp : public OpKernel {
 public:
  explicit PmfToCdfOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
    OP_REQUIRES(
        context, 0 < precision_ && precision_ <= 16,
        InvalidArgument("`precision` must be in [1, 16]: ", precision_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& pmf_tensor = context->input(0);

    TensorShape shape = pmf_tensor.shape();
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(shape),
                InvalidArgument("`pmf` should be at least 1-D."));
    OP_REQUIRES(
        context, shape.dim_size(shape.dims() - 1) > 1,
        InvalidArgument("`pmf` size should be at least 2 in the last axis."));
    shape.set_dim(shape.dims() - 1, shape.dim_size(shape.dims() - 1) + 1);

    Tensor* cdf_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &cdf_tensor));

    auto pmf = pmf_tensor.flat_inner_dims<float, 2>();
    auto cdf = cdf_tensor->flat_inner_dims<int32, 2>();
    CHECK_EQ(pmf.dimension(0), cdf.dimension(0));
    CHECK_EQ(pmf.dimension(1) + 1, cdf.dimension(1));

    const double n = pmf.dimension(1);
    const int64 cost_per_unit = static_cast<int64>(50.0 * n * std::log2(n));
    thread::ThreadPool* thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    thread_pool->ParallelFor(
        pmf.dimension(0), cost_per_unit,
        [this, pmf, &cdf](int64 start, int64 limit) {
          const gtl::ArraySlice<float>::size_type pmf_size = pmf.dimension(1);
          for (int64 i = start; i < limit; ++i) {
            cdf(i, 0) = 0;
            PerShard({&pmf(i, 0), pmf_size}, {&cdf(i, 1), pmf_size});
          }
        });
  }

 private:
  struct PenaltyItem {
    PenaltyItem(int32* p, double mass) : pointer(p), mass(mass) {
      penalty = ComputeNextPenalty();
    }

    void Decrease() {
      CHECK_GT(*pointer, 1);
      --*pointer;
      penalty = ComputeNextPenalty();
    }

    friend bool operator<(const PenaltyItem& lhs, const PenaltyItem& rhs) {
      return lhs.penalty < rhs.penalty;
    }

    double ComputeNextPenalty() {
      if (*pointer <= 1) {
        return std::numeric_limits<double>::infinity();
      }
      return mass * (std::log2(*pointer) - std::log2(*pointer - 1));
    }

    int32* pointer;
    double mass;
    double penalty;
  };

  struct GainItem {
    GainItem(int32* p, double mass) : pointer(p), mass(mass) {
      gain = ComputeNextGain();
    }

    void Increase() {
      CHECK_GT(*pointer, 0);
      ++*pointer;
      gain = ComputeNextGain();
    }

    friend bool operator>(const GainItem& lhs, const GainItem& rhs) {
      return lhs.gain > rhs.gain;
    }

    double ComputeNextGain() {
      // Never increment a zero value to a non-zero value.
      if (*pointer < 1) {
        return -std::numeric_limits<double>::infinity();
      }
      return mass * (std::log2(*pointer + 1) - std::log2(*pointer));
    }

    int32* pointer;
    double mass;
    double gain;
  };

  void PerShard(gtl::ArraySlice<float> pmf,
                gtl::MutableArraySlice<int32> cdf) const {
    CHECK_EQ(pmf.size(), cdf.size());

    const int32 normalizer = 1 << precision_;
    std::transform(pmf.begin(), pmf.end(), cdf.begin(),
                   [normalizer](float mass) {
                     int32 value = std::rint(mass * normalizer);
                     // NOTE: Consider checking if mass > 0.
                     value = std::max(value, 1);
                     return value;
                   });

    int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0);
    if (sum > normalizer) {
      std::vector<PenaltyItem> queue;
      queue.reserve(cdf.size());
      for (int i = 0; i < cdf.size(); ++i) {
        queue.emplace_back(&cdf[i], pmf[i]);
      }

      std::sort(queue.begin(), queue.end());
      while (sum-- > normalizer) {
        queue[0].Decrease();
        // Performs a linear search because this find_if is likely to return
        // an iterator very close to the beginning.
        auto iter = std::find_if(
            std::next(queue.begin()), queue.end(),
            [&queue](const PenaltyItem& rhs) { return queue[0] < rhs; });
        std::rotate(queue.begin(), std::next(queue.begin()), iter);
      }
    } else if (sum < normalizer) {
      std::vector<GainItem> queue;
      queue.reserve(cdf.size());
      for (int i = 0; i < cdf.size(); ++i) {
        queue.emplace_back(&cdf[i], pmf[i]);
      }

      std::sort(queue.begin(), queue.end(), std::greater<GainItem>());
      while (sum++ < normalizer) {
        queue[0].Increase();
        // Performs a linear search because this find_if is likely to return
        // an iterator very close to the beginning.
        auto iter = std::find_if(
            std::next(queue.begin()), queue.end(),
            [&queue](const GainItem& rhs) { return queue[0] > rhs; });
        std::rotate(queue.begin(), std::next(queue.begin()), iter);
      }
    }
    std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
  }

  int precision_;
};

REGISTER_KERNEL_BUILDER(Name("PmfToQuantizedCdf").Device(DEVICE_CPU),
                        PmfToCdfOp);
}  // namespace
}  // namespace tensorflow