/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/sdca_ops.cc.

#define EIGEN_USE_THREADS

#include <stdint.h>
#include <atomic>
#include <limits>
#include <memory>
#include <new>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
#include "tensorflow/core/kernels/loss.h"
#include "tensorflow/core/kernels/sdca_internal.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

namespace {

using sdca::Example;
using sdca::Examples;
using sdca::ExampleStatistics;
using sdca::ModelWeights;
using sdca::Regularizations;

struct ComputeOptions {
  explicit ComputeOptions(OpKernelConstruction* const context) {
    string loss_type;
    OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
    if (loss_type == "logistic_loss") {
      loss_updater.reset(new LogisticLossUpdater);
    } else if (loss_type == "squared_loss") {
      loss_updater.reset(new SquaredLossUpdater);
    } else if (loss_type == "hinge_loss") {
      loss_updater.reset(new HingeLossUpdater);
    } else if (loss_type == "smooth_hinge_loss") {
      loss_updater.reset(new SmoothHingeLossUpdater);
    } else {
      OP_REQUIRES(
          context, false,
          errors::InvalidArgument("Unsupported loss type: ", loss_type));
    }
    OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptative));
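    // Note: "adaptative" (sic) appears to be the attr name registered for
    // this op, so the misspelling must be preserved for compatibility.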
    OP_REQUIRES_OK(
        context, context->GetAttr("num_sparse_features", &num_sparse_features));
    OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
                                             &num_sparse_features_with_values));
    OP_REQUIRES_OK(context,
                   context->GetAttr("num_dense_features", &num_dense_features));
    OP_REQUIRES(
        context, num_sparse_features + num_dense_features > 0,
        errors::InvalidArgument("Requires at least one feature to train."));

    OP_REQUIRES(context,
                static_cast<int64>(num_sparse_features) +
                        static_cast<int64>(num_dense_features) <=
                    std::numeric_limits<int>::max(),
                errors::InvalidArgument(
                    strings::Printf("Too many feature groups: %lld > %d",
                                    static_cast<int64>(num_sparse_features) +
                                        static_cast<int64>(num_dense_features),
                                    std::numeric_limits<int>::max())));
    OP_REQUIRES_OK(
        context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
    OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
                                             &num_inner_iterations));
    OP_REQUIRES_OK(context, regularizations.Initialize(context));
  }

  std::unique_ptr<DualLossUpdater> loss_updater;
  int num_sparse_features = 0;
  int num_sparse_features_with_values = 0;
  int num_dense_features = 0;
  int num_inner_iterations = 0;
  int num_loss_partitions = 0;
  bool adaptative = false;
  Regularizations regularizations;
};

// TODO(shengx): The helper classes/methods were changed to support multiclass
// SDCA, which led to changes within this function. Convergence needs to be
// revisited once multiclass SDCA is in.
void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
  ModelWeights model_weights;
  OP_REQUIRES_OK(context, model_weights.Initialize(context));

  Examples examples;
  OP_REQUIRES_OK(
      context,
      examples.Initialize(context, model_weights, options.num_sparse_features,
                          options.num_sparse_features_with_values,
                          options.num_dense_features));

  const Tensor* example_state_data_t;
  OP_REQUIRES_OK(context,
                 context->input("example_state_data", &example_state_data_t));
  TensorShape expected_example_state_shape({examples.num_examples(), 4});
  OP_REQUIRES(context,
              example_state_data_t->shape() == expected_example_state_shape,
              errors::InvalidArgument(
                  "Expected shape ", expected_example_state_shape.DebugString(),
                  " for example_state_data, got ",
                  example_state_data_t->shape().DebugString()));

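  // Tensor's copy constructor shares the underlying buffer, so this shallow
  // copy lets the kernel update the example state in place and return the
  // same buffer as the "out_example_state_data" output.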
  Tensor mutable_example_state_data_t(*example_state_data_t);
  auto example_state_data = mutable_example_state_data_t.matrix<float>();
  OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
                                              mutable_example_state_data_t));

  if (options.adaptative) {
    OP_REQUIRES_OK(context,
                   examples.SampleAdaptativeProbabilities(
                       options.num_loss_partitions, options.regularizations,
                       model_weights, example_state_data, options.loss_updater,
                       /*num_weight_vectors =*/1));
  }
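  // When adaptive sampling is enabled, the probabilities computed above
  // drive sampled_index() in the training loop below, so examples are
  // visited non-uniformly rather than in the default order.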

  mutex mu;
  Status train_step_status GUARDED_BY(mu);
  std::atomic<std::int64_t> atomic_index(-1);
  auto train_step = [&](const int64 begin, const int64 end) {
    // The static_cast here is safe since begin and end can be at most
    // num_examples, which is an int.
    for (int id = static_cast<int>(begin); id < end; ++id) {
      const int64 example_index =
          examples.sampled_index(++atomic_index, options.adaptative);
      const Example& example = examples.example(example_index);
      const float dual = example_state_data(example_index, 0);
      const float example_weight = example.example_weight();
      float example_label = example.example_label();
      const Status conversion_status =
          options.loss_updater->ConvertLabel(&example_label);
      if (!conversion_status.ok()) {
        mutex_lock l(mu);
        train_step_status = conversion_status;
        // Return from this worker thread - the calling thread is
        // responsible for checking context status and returning on error.
        return;
      }

      // Compute wx, example norm weighted by regularization, dual loss,
      // primal loss.
      // For binary SDCA, num_weight_vectors should be one.
      const ExampleStatistics example_statistics =
          example.ComputeWxAndWeightedExampleNorm(
              options.num_loss_partitions, model_weights,
              options.regularizations, 1 /* num_weight_vectors */);

      const double new_dual = options.loss_updater->ComputeUpdatedDual(
          options.num_loss_partitions, example_label, example_weight, dual,
          example_statistics.wx[0], example_statistics.normalized_squared_norm);

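      // A sketch of the primal-dual link used here: SDCA maintains the
      // primal weights as a scaled sum of dual variables times features, so
      // the change in this example's dual, scaled by its example weight and
      // 1/lambda, gives the delta applied to each active feature's weight
      // (any remaining normalization is handled inside UpdateDeltaWeights).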
      // Compute new weights.
      const double normalized_bounded_dual_delta =
          (new_dual - dual) * example_weight /
          options.regularizations.symmetric_l2();
      model_weights.UpdateDeltaWeights(
          context->eigen_cpu_device(), example,
          std::vector<double>{normalized_bounded_dual_delta});

      // Update example data.
      example_state_data(example_index, 0) = new_dual;
      example_state_data(example_index, 1) =
          options.loss_updater->ComputePrimalLoss(
              example_statistics.prev_wx[0], example_label, example_weight);
      example_state_data(example_index, 2) =
          options.loss_updater->ComputeDualLoss(dual, example_label,
                                                example_weight);
      example_state_data(example_index, 3) = example_weight;
    }
  };
  // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
  // number of cpus, and cost per example.
  const int64 kCostPerUnit = examples.num_features();
  const DeviceBase::CpuWorkerThreads& worker_threads =
      *context->device()->tensorflow_cpu_worker_threads();

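  // Shard partitions [0, num_examples) across the worker threads and invokes
  // train_step on disjoint [begin, end) ranges in parallel, using
  // kCostPerUnit to pick the granularity.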
  Shard(worker_threads.num_threads, worker_threads.workers,
        examples.num_examples(), kCostPerUnit, train_step);
  OP_REQUIRES_OK(context, train_step_status);
}

}  // namespace

class SdcaOptimizer : public OpKernel {
 public:
  explicit SdcaOptimizer(OpKernelConstruction* const context)
      : OpKernel(context), options_(context) {}

  void Compute(OpKernelContext* context) override {
    DoCompute(options_, context);
  }

 private:
  // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and
  // template the entire class to avoid the virtual table lookup penalty in
  // the inner loop.
  ComputeOptions options_;
};
REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU),
                        SdcaOptimizer);

class SdcaShrinkL1 : public OpKernel {
 public:
  explicit SdcaShrinkL1(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, regularizations_.Initialize(context));
  }

  void Compute(OpKernelContext* context) override {
    OpMutableInputList weights_inputs;
    OP_REQUIRES_OK(context,
                   context->mutable_input_list("weights", &weights_inputs));

    auto do_work = [&](const int64 begin, const int64 end) {
      for (int i = begin; i < end; ++i) {
        auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>();
        prox_w.device(context->eigen_cpu_device()) =
            regularizations_.EigenShrinkVector(prox_w);
      }
    };

    if (weights_inputs.size() > 0) {
      int64 num_weights = 0;
      for (int i = 0; i < weights_inputs.size(); ++i) {
        num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements();
      }
      // TODO(sibyl-Aix6ihai): Tune this value.
      const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size();
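      // The estimate above charges roughly 50 cost units per weight,
      // averaged per weight vector, since the unit of work handed to Shard
      // here is one whole weight vector.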
      const DeviceBase::CpuWorkerThreads& worker_threads =
          *context->device()->tensorflow_cpu_worker_threads();
      Shard(worker_threads.num_threads, worker_threads.workers,
            weights_inputs.size(), kCostPerUnit, do_work);
    }
  }

 private:
  Regularizations regularizations_;
};
REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);

// Computes a platform-independent, compact and unique (with very high
// probability) representation of an example id. It shouldn't be put in
// persistent storage, as its implementation may change in the future.
//
// The current probability of at least one collision for 1B example_ids is
// approximately 10^-21 (i.e. 2^60 / 2^129).
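// Derivation (birthday bound): with n = 10^9 ~ 2^30 ids hashed to 128 bits,
// Pr[any collision] ~ n^2 / 2^129 = 2^60 / 2^129 = 2^-69 ~ 1.7 * 10^-21.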
class SdcaFprint : public OpKernel {
 public:
  explicit SdcaFprint(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
                errors::InvalidArgument("Input must be a vector, got shape ",
                                        input.shape().DebugString()));
    Tensor* out;
    const int64 num_elements = input.NumElements();
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({num_elements, 2}), &out));

    const auto in_values = input.flat<string>();
    auto out_values = out->matrix<int64>();

    for (int64 i = 0; i < num_elements; ++i) {
      const Fprint128 fprint = Fingerprint128(in_values(i));
      // Never return 0 or 1 as the first value of the hash to allow these to
      // safely be used as sentinel values (e.g. dense hash table empty key).
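      // ~static_cast<uint64>(1) is 0xFFFF...FFFE, so the fallback branch
      // below computes fprint.low64 - 2 (mod 2^64), remapping 0 and 1 to the
      // top two values of the uint64 range.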
      out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
                             ? fprint.low64
                             : fprint.low64 + ~static_cast<uint64>(1);
      out_values(i, 1) = fprint.high64;
    }
  }
};
REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);

}  // namespace tensorflow