1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op_kernel.h" 17 #include "tensorflow/core/framework/tensor.h" 18 #include "tensorflow/core/framework/tensor_shape.h" 19 #include "tensorflow/core/framework/types.h" 20 #include "tensorflow/core/lib/core/errors.h" 21 #include "tensorflow/core/lib/strings/numbers.h" 22 #include "tensorflow/core/lib/strings/str_util.h" 23 24 namespace tensorflow { 25 26 template <typename T, typename Tlabel> 27 class DecodeLibsvmOp : public OpKernel { 28 public: DecodeLibsvmOp(OpKernelConstruction * ctx)29 explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 30 OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_)); 31 OP_REQUIRES(ctx, (num_features_ >= 1), 32 errors::InvalidArgument("Invalid number of features \"", 33 num_features_, "\"")); 34 } 35 Compute(OpKernelContext * ctx)36 void Compute(OpKernelContext* ctx) override { 37 const Tensor* input_tensor; 38 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); 39 const auto& input_flat = input_tensor->flat<string>(); 40 41 Tensor* label_tensor; 42 OP_REQUIRES_OK( 43 ctx, ctx->allocate_output(0, input_tensor->shape(), &label_tensor)); 44 auto label = label_tensor->flat<Tlabel>(); 45 46 std::vector<T> out_values; 47 std::vector<std::pair<int64, int64>> out_indices; 48 for (int i = 0; i < input_flat.size(); ++i) { 49 StringPiece line(input_flat(i)); 50 str_util::RemoveWhitespaceContext(&line); 51 52 StringPiece piece; 53 OP_REQUIRES(ctx, str_util::ConsumeNonWhitespace(&line, &piece), 54 errors::InvalidArgument("No label found for input[", i, 55 "]: \"", input_flat(i), "\"")); 56 57 Tlabel label_value; 58 OP_REQUIRES(ctx, 59 strings::SafeStringToNumeric<Tlabel>(piece, &label_value), 60 errors::InvalidArgument("Label format incorrect: ", piece)); 61 62 label(i) = label_value; 63 64 str_util::RemoveLeadingWhitespace(&line); 65 while (str_util::ConsumeNonWhitespace(&line, &piece)) { 66 size_t p = piece.find(':'); 67 OP_REQUIRES(ctx, (p != StringPiece::npos), 68 errors::InvalidArgument("Invalid feature \"", piece, "\"")); 69 70 int64 feature_index; 71 OP_REQUIRES( 72 ctx, strings::safe_strto64(piece.substr(0, p), &feature_index), 73 errors::InvalidArgument("Feature format incorrect: ", piece)); 74 OP_REQUIRES(ctx, (feature_index >= 0), 75 errors::InvalidArgument( 76 "Feature index should be >= 0, got ", feature_index)); 77 78 T feature_value; 79 OP_REQUIRES( 80 81 ctx, 82 strings::SafeStringToNumeric<T>(piece.substr(p + 1), 83 &feature_value), 84 errors::InvalidArgument("Feature format incorrect: ", piece)); 85 86 out_values.emplace_back(feature_value); 87 out_indices.emplace_back(std::pair<int64, int64>(i, feature_index)); 88 89 str_util::RemoveLeadingWhitespace(&line); 90 } 91 } 92 93 Tensor* indices_tensor; 94 OP_REQUIRES_OK(ctx, ctx->allocate_output( 95 1, 96 TensorShape({static_cast<int64>(out_indices.size()), 97 input_tensor->shape().dims() + 1}), 98 &indices_tensor)); 99 auto indices = indices_tensor->matrix<int64>(); 100 // Translate flat index to shaped index like np.unravel_index 101 // Calculate factors for each dimension 102 std::vector<int64> factors(input_tensor->shape().dims()); 103 factors[input_tensor->shape().dims() - 1] = 1; 104 for (int j = input_tensor->shape().dims() - 2; j >= 0; j--) { 105 factors[j] = factors[j + 1] * input_tensor->shape().dim_size(j + 1); 106 } 107 for (int i = 0; i < out_indices.size(); i++) { 108 indices(i, 0) = out_indices[i].first; 109 int64 value = out_indices[i].first; 110 for (int j = 0; j < input_tensor->shape().dims(); j++) { 111 indices(i, j) = value / factors[j]; 112 value = value % factors[j]; 113 } 114 indices(i, input_tensor->shape().dims()) = out_indices[i].second; 115 } 116 117 Tensor* values_tensor; 118 OP_REQUIRES_OK(ctx, 119 ctx->allocate_output( 120 2, TensorShape({static_cast<int64>(out_values.size())}), 121 &values_tensor)); 122 auto values = values_tensor->vec<T>(); 123 std::copy_n(out_values.begin(), out_values.size(), &values(0)); 124 125 Tensor* shape_tensor; 126 OP_REQUIRES_OK(ctx, ctx->allocate_output( 127 3, TensorShape({input_tensor->shape().dims() + 1}), 128 &shape_tensor)); 129 auto shape = shape_tensor->flat<int64>(); 130 for (int i = 0; i < input_tensor->shape().dims(); i++) { 131 shape(i) = input_tensor->shape().dim_size(i); 132 } 133 shape(input_tensor->shape().dims()) = num_features_; 134 } 135 136 private: 137 int64 num_features_; 138 }; 139 140 #define REGISTER_KERNEL(type) \ 141 REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ 142 .Device(DEVICE_CPU) \ 143 .TypeConstraint<type>("dtype") \ 144 .TypeConstraint<int32>("label_dtype"), \ 145 DecodeLibsvmOp<type, int32>); \ 146 REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ 147 .Device(DEVICE_CPU) \ 148 .TypeConstraint<type>("dtype") \ 149 .TypeConstraint<int64>("label_dtype"), \ 150 DecodeLibsvmOp<type, int64>); \ 151 REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ 152 .Device(DEVICE_CPU) \ 153 .TypeConstraint<type>("dtype") \ 154 .TypeConstraint<float>("label_dtype"), \ 155 DecodeLibsvmOp<type, float>); \ 156 REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ 157 .Device(DEVICE_CPU) \ 158 .TypeConstraint<type>("dtype") \ 159 .TypeConstraint<double>("label_dtype"), \ 160 DecodeLibsvmOp<type, double>); 161 162 REGISTER_KERNEL(float); 163 REGISTER_KERNEL(double); 164 REGISTER_KERNEL(int32); 165 REGISTER_KERNEL(int64); 166 #undef REGISTER_KERNEL 167 168 } // namespace tensorflow 169