1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <numeric>
#include <vector>
23 
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/threadpool.h"
30 #include "tensorflow/core/lib/gtl/array_slice.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 namespace {
37 using errors::InvalidArgument;
38 
39 class PmfToCdfOp : public OpKernel {
40  public:
PmfToCdfOp(OpKernelConstruction * context)41   explicit PmfToCdfOp(OpKernelConstruction* context) : OpKernel(context) {
42     OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
43     OP_REQUIRES(
44         context, 0 < precision_ && precision_ <= 16,
45         InvalidArgument("`precision` must be in [1, 16]: ", precision_));
46   }
47 
Compute(OpKernelContext * context)48   void Compute(OpKernelContext* context) override {
49     const Tensor& pmf_tensor = context->input(0);
50 
51     TensorShape shape = pmf_tensor.shape();
52     OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(shape),
53                 InvalidArgument("`pmf` should be at least 1-D."));
54     OP_REQUIRES(
55         context, shape.dim_size(shape.dims() - 1) > 1,
56         InvalidArgument("`pmf` size should be at least 2 in the last axis."));
57     shape.set_dim(shape.dims() - 1, shape.dim_size(shape.dims() - 1) + 1);
58 
59     Tensor* cdf_tensor;
60     OP_REQUIRES_OK(context, context->allocate_output(0, shape, &cdf_tensor));
61 
62     auto pmf = pmf_tensor.flat_inner_dims<float, 2>();
63     auto cdf = cdf_tensor->flat_inner_dims<int32, 2>();
64     CHECK_EQ(pmf.dimension(0), cdf.dimension(0));
65     CHECK_EQ(pmf.dimension(1) + 1, cdf.dimension(1));
66 
67     const double n = pmf.dimension(1);
68     const int64 cost_per_unit = static_cast<int64>(50.0 * n * std::log2(n));
69     thread::ThreadPool* thread_pool =
70         context->device()->tensorflow_cpu_worker_threads()->workers;
71     thread_pool->ParallelFor(
72         pmf.dimension(0), cost_per_unit,
73         [this, pmf, &cdf](int64 start, int64 limit) {
74           const gtl::ArraySlice<float>::size_type pmf_size = pmf.dimension(1);
75           for (int64 i = start; i < limit; ++i) {
76             cdf(i, 0) = 0;
77             PerShard({&pmf(i, 0), pmf_size}, {&cdf(i, 1), pmf_size});
78           }
79         });
80   }
81 
82  private:
83   struct PenaltyItem {
PenaltyItemtensorflow::__anon312944970111::PmfToCdfOp::PenaltyItem84     PenaltyItem(int32* p, double mass) : pointer(p), mass(mass) {
85       penalty = ComputeNextPenalty();
86     }
87 
Decreasetensorflow::__anon312944970111::PmfToCdfOp::PenaltyItem88     void Decrease() {
89       CHECK_GT(*pointer, 1);
90       --*pointer;
91       penalty = ComputeNextPenalty();
92     }
93 
operator <(const PenaltyItem & lhs,const PenaltyItem & rhs)94     friend bool operator<(const PenaltyItem& lhs, const PenaltyItem& rhs) {
95       return lhs.penalty < rhs.penalty;
96     }
97 
ComputeNextPenaltytensorflow::__anon312944970111::PmfToCdfOp::PenaltyItem98     double ComputeNextPenalty() {
99       if (*pointer <= 1) {
100         return std::numeric_limits<double>::infinity();
101       }
102       return mass * (std::log2(*pointer) - std::log2(*pointer - 1));
103     }
104 
105     int32* pointer;
106     double mass;
107     double penalty;
108   };
109 
110   struct GainItem {
GainItemtensorflow::__anon312944970111::PmfToCdfOp::GainItem111     GainItem(int32* p, double mass) : pointer(p), mass(mass) {
112       gain = ComputeNextGain();
113     }
114 
Increasetensorflow::__anon312944970111::PmfToCdfOp::GainItem115     void Increase() {
116       CHECK_GT(*pointer, 0);
117       ++*pointer;
118       gain = ComputeNextGain();
119     }
120 
operator >(const GainItem & lhs,const GainItem & rhs)121     friend bool operator>(const GainItem& lhs, const GainItem& rhs) {
122       return lhs.gain > rhs.gain;
123     }
124 
ComputeNextGaintensorflow::__anon312944970111::PmfToCdfOp::GainItem125     double ComputeNextGain() {
126       // Never increment zero value to non-zero value.
127       if (*pointer < 1) {
128         return -std::numeric_limits<double>::infinity();
129       }
130       return mass * (std::log2(*pointer + 1) - std::log2(*pointer));
131     }
132 
133     int32* pointer;
134     double mass;
135     double gain;
136   };
137 
PerShard(gtl::ArraySlice<float> pmf,gtl::MutableArraySlice<int32> cdf) const138   void PerShard(gtl::ArraySlice<float> pmf,
139                 gtl::MutableArraySlice<int32> cdf) const {
140     CHECK_EQ(pmf.size(), cdf.size());
141 
142     const int32 normalizer = 1 << precision_;
143     std::transform(pmf.begin(), pmf.end(), cdf.begin(),
144                    [normalizer](float mass) {
145                      int32 value = std::rint(mass * normalizer);
146                      // NOTE: Consider checking if mass > 0.
147                      value = std::max(value, 1);
148                      return value;
149                    });
150 
151     int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0);
152     if (sum > normalizer) {
153       std::vector<PenaltyItem> queue;
154       queue.reserve(cdf.size());
155       for (int i = 0; i < cdf.size(); ++i) {
156         queue.emplace_back(&cdf[i], pmf[i]);
157       }
158 
159       std::sort(queue.begin(), queue.end());
160       while (sum-- > normalizer) {
161         queue[0].Decrease();
162         // Performs a linear search because this find_if is likely to return
163         // iterator very close to the begin.
164         auto iter = std::find_if(
165             std::next(queue.begin()), queue.end(),
166             [&queue](const PenaltyItem& rhs) { return queue[0] < rhs; });
167         std::rotate(queue.begin(), std::next(queue.begin()), iter);
168       }
169     } else if (sum < normalizer) {
170       std::vector<GainItem> queue;
171       queue.reserve(cdf.size());
172       for (int i = 0; i < cdf.size(); ++i) {
173         queue.emplace_back(&cdf[i], pmf[i]);
174       }
175 
176       std::sort(queue.begin(), queue.end(), std::greater<GainItem>());
177       while (sum++ < normalizer) {
178         queue[0].Increase();
179         // Performs a linear search because this find_if is likely to return
180         // iterator very close to the begin.
181         auto iter = std::find_if(
182             std::next(queue.begin()), queue.end(),
183             [&queue](const GainItem& rhs) { return queue[0] > rhs; });
184         std::rotate(queue.begin(), std::next(queue.begin()), iter);
185       }
186     }
187     std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
188   }
189 
190   int precision_;
191 };
192 
193 REGISTER_KERNEL_BUILDER(Name("PmfToQuantizedCdf").Device(DEVICE_CPU),
194                         PmfToCdfOp);
195 }  // namespace
196 }  // namespace tensorflow
197