1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // See docs in ../ops/sdca_ops.cc.
17
18 #define EIGEN_USE_THREADS
19
20 #include <stdint.h>
21 #include <atomic>
22 #include <limits>
23 #include <memory>
24 #include <new>
25 #include <string>
26 #include <vector>
27
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
29 #include "tensorflow/core/framework/device_base.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/op_def_builder.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_types.h"
37 #include "tensorflow/core/framework/types.h"
38 #include "tensorflow/core/kernels/hinge-loss.h"
39 #include "tensorflow/core/kernels/logistic-loss.h"
40 #include "tensorflow/core/kernels/loss.h"
41 #include "tensorflow/core/kernels/sdca_internal.h"
42 #include "tensorflow/core/kernels/smooth-hinge-loss.h"
43 #include "tensorflow/core/kernels/squared-loss.h"
44 #include "tensorflow/core/lib/core/coding.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow/core/lib/gtl/inlined_vector.h"
49 #include "tensorflow/core/lib/strings/stringprintf.h"
50 #include "tensorflow/core/platform/fingerprint.h"
51 #include "tensorflow/core/platform/macros.h"
52 #include "tensorflow/core/platform/mutex.h"
53 #include "tensorflow/core/platform/types.h"
54 #include "tensorflow/core/util/work_sharder.h"
55
56 namespace tensorflow {
57
58 namespace {
59
60 using sdca::Example;
61 using sdca::Examples;
62 using sdca::ExampleStatistics;
63 using sdca::ModelWeights;
64 using sdca::Regularizations;
65
66 struct ComputeOptions {
ComputeOptionstensorflow::__anon409ac3500111::ComputeOptions67 explicit ComputeOptions(OpKernelConstruction* const context) {
68 string loss_type;
69 OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
70 if (loss_type == "logistic_loss") {
71 loss_updater.reset(new LogisticLossUpdater);
72 } else if (loss_type == "squared_loss") {
73 loss_updater.reset(new SquaredLossUpdater);
74 } else if (loss_type == "hinge_loss") {
75 loss_updater.reset(new HingeLossUpdater);
76 } else if (loss_type == "smooth_hinge_loss") {
77 loss_updater.reset(new SmoothHingeLossUpdater);
78 } else {
79 OP_REQUIRES(
80 context, false,
81 errors::InvalidArgument("Unsupported loss type: ", loss_type));
82 }
83 OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptative));
84 OP_REQUIRES_OK(
85 context, context->GetAttr("num_sparse_features", &num_sparse_features));
86 OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
87 &num_sparse_features_with_values));
88 OP_REQUIRES_OK(context,
89 context->GetAttr("num_dense_features", &num_dense_features));
90 OP_REQUIRES(
91 context, num_sparse_features + num_dense_features > 0,
92 errors::InvalidArgument("Requires at least one feature to train."));
93
94 OP_REQUIRES(context,
95 static_cast<int64>(num_sparse_features) +
96 static_cast<int64>(num_dense_features) <=
97 std::numeric_limits<int>::max(),
98 errors::InvalidArgument(
99 strings::Printf("Too many feature groups: %lld > %d",
100 static_cast<int64>(num_sparse_features) +
101 static_cast<int64>(num_dense_features),
102 std::numeric_limits<int>::max())));
103 OP_REQUIRES_OK(
104 context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
105 OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
106 &num_inner_iterations));
107 OP_REQUIRES_OK(context, regularizations.Initialize(context));
108 }
109
110 std::unique_ptr<DualLossUpdater> loss_updater;
111 int num_sparse_features = 0;
112 int num_sparse_features_with_values = 0;
113 int num_dense_features = 0;
114 int num_inner_iterations = 0;
115 int num_loss_partitions = 0;
116 bool adaptative = false;
117 Regularizations regularizations;
118 };
119
120 // TODO(shengx): The helper classes/methods are changed to support multiclass
121 // SDCA, which lead to changes within this function. Need to revisit the
122 // convergence once the multiclass SDCA is in.
DoCompute(const ComputeOptions & options,OpKernelContext * const context)123 void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
124 ModelWeights model_weights;
125 OP_REQUIRES_OK(context, model_weights.Initialize(context));
126
127 Examples examples;
128 OP_REQUIRES_OK(
129 context,
130 examples.Initialize(context, model_weights, options.num_sparse_features,
131 options.num_sparse_features_with_values,
132 options.num_dense_features));
133
134 const Tensor* example_state_data_t;
135 OP_REQUIRES_OK(context,
136 context->input("example_state_data", &example_state_data_t));
137 TensorShape expected_example_state_shape({examples.num_examples(), 4});
138 OP_REQUIRES(context,
139 example_state_data_t->shape() == expected_example_state_shape,
140 errors::InvalidArgument(
141 "Expected shape ", expected_example_state_shape.DebugString(),
142 " for example_state_data, got ",
143 example_state_data_t->shape().DebugString()));
144
145 Tensor mutable_example_state_data_t(*example_state_data_t);
146 auto example_state_data = mutable_example_state_data_t.matrix<float>();
147 OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
148 mutable_example_state_data_t));
149
150 if (options.adaptative) {
151 OP_REQUIRES_OK(context,
152 examples.SampleAdaptativeProbabilities(
153 options.num_loss_partitions, options.regularizations,
154 model_weights, example_state_data, options.loss_updater,
155 /*num_weight_vectors =*/1));
156 }
157
158 mutex mu;
159 Status train_step_status GUARDED_BY(mu);
160 std::atomic<std::int64_t> atomic_index(-1);
161 auto train_step = [&](const int64 begin, const int64 end) {
162 // The static_cast here is safe since begin and end can be at most
163 // num_examples which is an int.
164 for (int id = static_cast<int>(begin); id < end; ++id) {
165 const int64 example_index =
166 examples.sampled_index(++atomic_index, options.adaptative);
167 const Example& example = examples.example(example_index);
168 const float dual = example_state_data(example_index, 0);
169 const float example_weight = example.example_weight();
170 float example_label = example.example_label();
171 const Status conversion_status =
172 options.loss_updater->ConvertLabel(&example_label);
173 if (!conversion_status.ok()) {
174 mutex_lock l(mu);
175 train_step_status = conversion_status;
176 // Return from this worker thread - the calling thread is
177 // responsible for checking context status and returning on error.
178 return;
179 }
180
181 // Compute wx, example norm weighted by regularization, dual loss,
182 // primal loss.
183 // For binary SDCA, num_weight_vectors should be one.
184 const ExampleStatistics example_statistics =
185 example.ComputeWxAndWeightedExampleNorm(
186 options.num_loss_partitions, model_weights,
187 options.regularizations, 1 /* num_weight_vectors */);
188
189 const double new_dual = options.loss_updater->ComputeUpdatedDual(
190 options.num_loss_partitions, example_label, example_weight, dual,
191 example_statistics.wx[0], example_statistics.normalized_squared_norm);
192
193 // Compute new weights.
194 const double normalized_bounded_dual_delta =
195 (new_dual - dual) * example_weight /
196 options.regularizations.symmetric_l2();
197 model_weights.UpdateDeltaWeights(
198 context->eigen_cpu_device(), example,
199 std::vector<double>{normalized_bounded_dual_delta});
200
201 // Update example data.
202 example_state_data(example_index, 0) = new_dual;
203 example_state_data(example_index, 1) =
204 options.loss_updater->ComputePrimalLoss(
205 example_statistics.prev_wx[0], example_label, example_weight);
206 example_state_data(example_index, 2) =
207 options.loss_updater->ComputeDualLoss(dual, example_label,
208 example_weight);
209 example_state_data(example_index, 3) = example_weight;
210 }
211 };
212 // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
213 // number of cpus, and cost per example.
214 const int64 kCostPerUnit = examples.num_features();
215 const DeviceBase::CpuWorkerThreads& worker_threads =
216 *context->device()->tensorflow_cpu_worker_threads();
217
218 Shard(worker_threads.num_threads, worker_threads.workers,
219 examples.num_examples(), kCostPerUnit, train_step);
220 OP_REQUIRES_OK(context, train_step_status);
221 }
222
223 } // namespace
224
225 class SdcaOptimizer : public OpKernel {
226 public:
SdcaOptimizer(OpKernelConstruction * const context)227 explicit SdcaOptimizer(OpKernelConstruction* const context)
228 : OpKernel(context), options_(context) {}
229
Compute(OpKernelContext * context)230 void Compute(OpKernelContext* context) override {
231 DoCompute(options_, context);
232 }
233
234 private:
235 // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and
236 // template the entire class to avoid the virtual table lookup penalty in
237 // the inner loop.
238 ComputeOptions options_;
239 };
240 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU),
241 SdcaOptimizer);
242
243 class SdcaShrinkL1 : public OpKernel {
244 public:
SdcaShrinkL1(OpKernelConstruction * const context)245 explicit SdcaShrinkL1(OpKernelConstruction* const context)
246 : OpKernel(context) {
247 OP_REQUIRES_OK(context, regularizations_.Initialize(context));
248 }
249
Compute(OpKernelContext * context)250 void Compute(OpKernelContext* context) override {
251 OpMutableInputList weights_inputs;
252 OP_REQUIRES_OK(context,
253 context->mutable_input_list("weights", &weights_inputs));
254
255 auto do_work = [&](const int64 begin, const int64 end) {
256 for (int i = begin; i < end; ++i) {
257 auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>();
258 prox_w.device(context->eigen_cpu_device()) =
259 regularizations_.EigenShrinkVector(prox_w);
260 }
261 };
262
263 if (weights_inputs.size() > 0) {
264 int64 num_weights = 0;
265 for (int i = 0; i < weights_inputs.size(); ++i) {
266 num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements();
267 }
268 // TODO(sibyl-Aix6ihai): Tune this value.
269 const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size();
270 const DeviceBase::CpuWorkerThreads& worker_threads =
271 *context->device()->tensorflow_cpu_worker_threads();
272 Shard(worker_threads.num_threads, worker_threads.workers,
273 weights_inputs.size(), kCostPerUnit, do_work);
274 }
275 }
276
277 private:
278 Regularizations regularizations_;
279 };
280 REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
281
282 // Computes platform independent, compact and unique (with very high
283 // probability) representation of an example id. It shouldn't be put in
284 // persistent storage, as its implementation may change in the future.
285 //
286 // The current probability of at least one collision for 1B example_ids is
287 // approximately 10^-21 (ie 2^60 / 2^129).
288 class SdcaFprint : public OpKernel {
289 public:
SdcaFprint(OpKernelConstruction * const context)290 explicit SdcaFprint(OpKernelConstruction* const context)
291 : OpKernel(context) {}
292
Compute(OpKernelContext * context)293 void Compute(OpKernelContext* context) override {
294 const Tensor& input = context->input(0);
295 OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
296 errors::InvalidArgument("Input must be a vector, got shape ",
297 input.shape().DebugString()));
298 Tensor* out;
299 const int64 num_elements = input.NumElements();
300 OP_REQUIRES_OK(context, context->allocate_output(
301 0, TensorShape({num_elements, 2}), &out));
302
303 const auto in_values = input.flat<string>();
304 auto out_values = out->matrix<int64>();
305
306 for (int64 i = 0; i < num_elements; ++i) {
307 const Fprint128 fprint = Fingerprint128(in_values(i));
308 // Never return 0 or 1 as the first value of the hash to allow these to
309 // safely be used as sentinel values (e.g. dense hash table empty key).
310 out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
311 ? fprint.low64
312 : fprint.low64 + ~static_cast<uint64>(1);
313 out_values(i, 1) = fprint.high64;
314 }
315 }
316 };
317 REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
318
319 } // namespace tensorflow
320