1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
16 
17 #include "tensorflow/core/framework/tensor_testutil.h"
18 #include "tensorflow/core/lib/core/status_test_util.h"
19 
20 namespace tensorflow {
21 namespace boosted_trees {
22 namespace testutil {
23 
24 using tensorflow::Tensor;
25 
RandomlyInitializeBatchFeatures(tensorflow::random::SimplePhilox * rng,uint32 num_dense_float_features,uint32 num_sparse_float_features,double sparsity_lo,double sparsity_hi,boosted_trees::utils::BatchFeatures * batch_features)26 void RandomlyInitializeBatchFeatures(
27     tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
28     uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
29     boosted_trees::utils::BatchFeatures* batch_features) {
30   const int64 batch_size = static_cast<int64>(batch_features->batch_size());
31 
32   // Populate dense features.
33   std::vector<tensorflow::Tensor> dense_float_features_list;
34   for (int i = 0; i < num_dense_float_features; ++i) {
35     std::vector<float> values;
36     for (int64 j = 0; j < batch_size; ++j) {
37       values.push_back(rng->RandFloat());
38     }
39     auto dense_tensor = Tensor(tensorflow::DT_FLOAT, {batch_size, 1});
40     tensorflow::test::FillValues<float>(&dense_tensor, values);
41     dense_float_features_list.push_back(dense_tensor);
42   }
43 
44   // Populate sparse features.
45   std::vector<tensorflow::Tensor> sparse_float_feature_indices_list;
46   std::vector<tensorflow::Tensor> sparse_float_feature_values_list;
47   std::vector<tensorflow::Tensor> sparse_float_feature_shapes_list;
48   for (int i = 0; i < num_sparse_float_features; ++i) {
49     std::set<uint64> indices;
50     const double sparsity =
51         sparsity_lo + rng->RandDouble() * (sparsity_hi - sparsity_lo);
52     const double density = 1 - sparsity;
53     for (int64 k = 0; k < static_cast<int64>(density * batch_size) + 1; ++k) {
54       indices.insert(rng->Uniform64(batch_size));
55     }
56     const int64 sparse_values_size = indices.size();
57     std::vector<int64> indices_vector;
58     for (auto idx : indices) {
59       indices_vector.push_back(idx);
60       indices_vector.push_back(0);
61     }
62     auto indices_tensor = Tensor(tensorflow::DT_INT64, {sparse_values_size, 2});
63     tensorflow::test::FillValues<int64>(&indices_tensor, indices_vector);
64     sparse_float_feature_indices_list.push_back(indices_tensor);
65 
66     std::vector<float> values;
67     for (int64 j = 0; j < sparse_values_size; ++j) {
68       values.push_back(rng->RandFloat());
69     }
70     auto values_tensor = Tensor(tensorflow::DT_FLOAT, {sparse_values_size});
71     tensorflow::test::FillValues<float>(&values_tensor, values);
72     sparse_float_feature_values_list.push_back(values_tensor);
73 
74     auto shape_tensor = Tensor(tensorflow::DT_INT64, {2});
75     tensorflow::test::FillValues<int64>(&shape_tensor, {batch_size, 1});
76     sparse_float_feature_shapes_list.push_back(shape_tensor);
77   }
78 
79   // TODO(salehay): Add categorical feature generation support.
80   TF_EXPECT_OK(batch_features->Initialize(
81       dense_float_features_list, sparse_float_feature_indices_list,
82       sparse_float_feature_values_list, sparse_float_feature_shapes_list, {},
83       {}, {}));
84 }
85 
86 }  // namespace testutil
87 }  // namespace boosted_trees
88 }  // namespace tensorflow
89