1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 
16 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
17 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
18 
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
22 
23 namespace tensorflow {
24 namespace boosted_trees {
25 namespace utils {
26 
27 class TensorUtils {
28  public:
29   // Read an input list into a vector of tensors.
30   static std::vector<Tensor> OpInputListToTensorVec(
31       const OpInputList& input_list);
32 
33   // Reads the dense float features input list.
34   static Status ReadDenseFloatFeatures(OpKernelContext* const context,
35                                        OpInputList* features_list);
36 
37   // Reads the sparse float features input list.
38   static Status ReadSparseFloatFeatures(OpKernelContext* const context,
39                                         OpInputList* features_indices_list,
40                                         OpInputList* feature_values_list,
41                                         OpInputList* feature_shapes_list);
42 
43   // Reads the sparse int features input list.
44   static Status ReadSparseIntFeatures(OpKernelContext* const context,
45                                       OpInputList* features_indices_list,
46                                       OpInputList* feature_values_list,
47                                       OpInputList* feature_shapes_list);
48 
49   // Infers the batch size by looking at the op input features.
50   static int64 InferBatchSize(
51       const OpInputList& dense_float_features_list,
52       const OpInputList& sparse_float_feature_shapes_list,
53       const OpInputList& sparse_int_feature_shapes_list);
54 };
55 
56 }  // namespace utils
57 }  // namespace boosted_trees
58 }  // namespace tensorflow
59 
60 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
61