1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/kernel_def_builder.h"
17 #include "tensorflow/core/framework/attr_value.pb.h"
18 #include "tensorflow/core/framework/kernel_def.pb.h"
19 
20 namespace tensorflow {
21 
KernelDefBuilder(const char * op_name)22 KernelDefBuilder::KernelDefBuilder(const char* op_name) {
23   kernel_def_ = new KernelDef;
24   kernel_def_->set_op(op_name);
25 }
26 
~KernelDefBuilder()27 KernelDefBuilder::~KernelDefBuilder() {
28   DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
29 }
30 
Device(const char * device_type)31 KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
32   kernel_def_->set_device_type(device_type);
33   return *this;
34 }
35 
36 template <>
AttrConstraint(const char * attr_name,gtl::ArraySlice<int64> allowed)37 KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64>(
38     const char* attr_name, gtl::ArraySlice<int64> allowed) {
39   auto* constraint = kernel_def_->add_constraint();
40   constraint->set_name(attr_name);
41   auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
42   for (const int64 integer : allowed) {
43     LOG(INFO) << integer;
44     allowed_values->add_i(integer);
45   }
46   return *this;
47 }
48 
49 template <>
AttrConstraint(const char * attr_name,int64 allowed)50 KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64>(const char* attr_name,
51                                                           int64 allowed) {
52   return AttrConstraint(
53       attr_name,
54       gtl::ArraySlice<int64>(std::initializer_list<int64>({allowed})));
55 }
56 
57 template <>
AttrConstraint(const char * attr_name,gtl::ArraySlice<string> allowed)58 KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
59     const char* attr_name, gtl::ArraySlice<string> allowed) {
60   auto* constraint = kernel_def_->add_constraint();
61   constraint->set_name(attr_name);
62   auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
63   for (const auto& str : allowed) {
64     allowed_values->add_s(str);
65   }
66   return *this;
67 }
68 
69 template <>
AttrConstraint(const char * attr_name,string allowed)70 KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
71     const char* attr_name, string allowed) {
72   return AttrConstraint(
73       attr_name,
74       gtl::ArraySlice<string>(std::initializer_list<string>({allowed})));
75 }
76 
77 template <>
AttrConstraint(const char * attr_name,gtl::ArraySlice<const char * > allowed)78 KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
79     const char* attr_name, gtl::ArraySlice<const char*> allowed) {
80   auto* constraint = kernel_def_->add_constraint();
81   constraint->set_name(attr_name);
82   auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
83   for (const auto& str : allowed) {
84     allowed_values->add_s(str);
85   }
86   return *this;
87 }
88 
89 template <>
AttrConstraint(const char * attr_name,const char * allowed)90 KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
91     const char* attr_name, const char* allowed) {
92   return AttrConstraint(attr_name,
93                         gtl::ArraySlice<const char*>(
94                             std::initializer_list<const char*>({allowed})));
95 }
96 
97 template <>
AttrConstraint(const char * attr_name,bool allowed)98 KernelDefBuilder& KernelDefBuilder::AttrConstraint<bool>(const char* attr_name,
99                                                          bool allowed) {
100   auto* constraint = kernel_def_->add_constraint();
101   constraint->set_name(attr_name);
102   auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
103   allowed_values->add_b(allowed);
104   return *this;
105 }
106 
TypeConstraint(const char * attr_name,gtl::ArraySlice<DataType> allowed)107 KernelDefBuilder& KernelDefBuilder::TypeConstraint(
108     const char* attr_name, gtl::ArraySlice<DataType> allowed) {
109   auto* constraint = kernel_def_->add_constraint();
110   constraint->set_name(attr_name);
111   auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
112   for (DataType dt : allowed) {
113     allowed_values->add_type(dt);
114   }
115   return *this;
116 }
117 
TypeConstraint(const char * attr_name,DataType allowed)118 KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name,
119                                                    DataType allowed) {
120   auto* constraint = kernel_def_->add_constraint();
121   constraint->set_name(attr_name);
122   constraint->mutable_allowed_values()->mutable_list()->add_type(allowed);
123   return *this;
124 }
125 
HostMemory(const char * arg_name)126 KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) {
127   kernel_def_->add_host_memory_arg(arg_name);
128   return *this;
129 }
130 
Label(const char * label)131 KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
132   CHECK_EQ(kernel_def_->label(), "")
133       << "Trying to set a kernel's label a second time: '" << label
134       << "' in: " << kernel_def_->DebugString();
135   kernel_def_->set_label(label);
136   return *this;
137 }
138 
Priority(int32 priority)139 KernelDefBuilder& KernelDefBuilder::Priority(int32 priority) {
140   kernel_def_->set_priority(priority);
141   return *this;
142 }
143 
Build()144 const KernelDef* KernelDefBuilder::Build() {
145   KernelDef* r = kernel_def_;
146   kernel_def_ = nullptr;
147   return r;
148 }
149 
150 }  // namespace tensorflow
151