1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/kernel_def_util.h"
17 
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/framework/attr_value_util.h"
20 #include "tensorflow/core/framework/kernel_def.pb.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/types.h"
23 
24 namespace tensorflow {
25 
26 namespace {
27 // Helper for KernelAttrsMatch().
InTypeList(DataType dt,const AttrValue & type_list)28 bool InTypeList(DataType dt, const AttrValue& type_list) {
29   for (int in_list : type_list.list().type()) {
30     if (dt == in_list) return true;
31   }
32   return false;
33 }
34 }  // namespace
35 
KernelAttrsMatch(const KernelDef & kernel_def,AttrSlice attrs,bool * match)36 Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
37                         bool* match) {
38   *match = false;
39   for (const auto& constraint : kernel_def.constraint()) {
40     auto constraint_value_case = AttrValue::VALUE_NOT_SET;
41     int value_type_num = 0;
42     if (constraint.allowed_values().list().type_size() > 0) {
43       constraint_value_case = AttrValue::kType;
44       value_type_num++;
45     }
46     if (constraint.allowed_values().list().s_size() > 0) {
47       constraint_value_case = AttrValue::kS;
48       value_type_num++;
49     }
50     if (constraint.allowed_values().list().i_size() > 0) {
51       constraint_value_case = AttrValue::kI;
52       value_type_num++;
53     }
54     if (constraint.allowed_values().list().b_size() > 0) {
55       constraint_value_case = AttrValue::kB;
56       value_type_num++;
57     }
58 
59     if (value_type_num == 0) {
60       return errors::Unimplemented(
61           "KernelDef '", kernel_def.ShortDebugString(),
62           " has constraint on attr '", constraint.name(),
63           "' with unsupported type: ",
64           SummarizeAttrValue(constraint.allowed_values()));
65     }
66     if (value_type_num > 1) {
67       return errors::InvalidArgument(
68           "KernelDef '", kernel_def.ShortDebugString(),
69           " has constraint on attr '", constraint.name(),
70           "' with more than one value type: ",
71           SummarizeAttrValue(constraint.allowed_values()));
72     }
73 
74     const AttrValue* attr_value = attrs.Find(constraint.name());
75     if (attr_value == nullptr) {
76       return errors::InvalidArgument(
77           "OpKernel '", kernel_def.op(), "' has constraint on attr '",
78           constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
79           "', KernelDef: '", kernel_def.ShortDebugString(), "'");
80     }
81 
82 #define RETURN_IF_ATTR_NOT_FOUND(n, oneof_case, type_str)          \
83   do {                                                             \
84     if (constraint_value_case == AttrValue::oneof_case) {          \
85       Status s = AttrValueHasType(*attr_value, type_str);          \
86       if (!s.ok()) {                                               \
87         return errors::InvalidArgument(                            \
88             "KernelDef '", kernel_def.ShortDebugString(),          \
89             "' has constraint on attr '", constraint.name(),       \
90             "' that has value '", SummarizeAttrValue(*attr_value), \
91             "' that does not have the same type in NodeDef "       \
92             "'",                                                   \
93             attrs.SummarizeNode(), "'");                           \
94       }                                                            \
95       bool found = false;                                          \
96       for (auto& value : constraint.allowed_values().list().n()) { \
97         if (value == attr_value->n()) {                            \
98           found = true;                                            \
99           break;                                                   \
100         }                                                          \
101       }                                                            \
102       if (!found) {                                                \
103         return Status::OK();                                       \
104       }                                                            \
105     }                                                              \
106   } while (false)
107 
108     RETURN_IF_ATTR_NOT_FOUND(s, kS, "string");
109     RETURN_IF_ATTR_NOT_FOUND(i, kI, "int");
110     RETURN_IF_ATTR_NOT_FOUND(b, kB, "bool");
111 
112 #undef RETURN_IF_ATTR_NOT_FOUND
113 
114     if (constraint_value_case != AttrValue::kType) {
115       continue;
116     }
117 
118     if (attr_value->type() != DT_INVALID) {
119       if (!InTypeList(attr_value->type(), constraint.allowed_values())) {
120         return Status::OK();
121       }
122     } else {
123       if (!AttrValueHasType(*attr_value, "list(type)").ok()) {
124         return errors::InvalidArgument(
125             "KernelDef '", kernel_def.ShortDebugString(),
126             "' has constraint on attr '", constraint.name(),
127             "' that has value '", SummarizeAttrValue(*attr_value),
128             "' that does not have type 'type' or 'list(type)' in NodeDef "
129             "'",
130             attrs.SummarizeNode(), "'");
131       }
132 
133       for (int t : attr_value->list().type()) {
134         if (!InTypeList(static_cast<DataType>(t),
135                         constraint.allowed_values())) {
136           return Status::OK();
137         }
138       }
139     }
140   }
141   *match = true;
142   return Status::OK();
143 }
144 
145 }  // namespace tensorflow
146