1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/kernel_def_util.h"
17 
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/framework/attr_value_util.h"
20 #include "tensorflow/core/framework/kernel_def.pb_text.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/types.h"
23 
24 namespace tensorflow {
25 
26 namespace {
27 // Helper for KernelAttrsMatch().
InTypeList(DataType dt,const AttrValue & type_list)28 bool InTypeList(DataType dt, const AttrValue& type_list) {
29   for (int in_list : type_list.list().type()) {
30     if (dt == in_list) return true;
31   }
32   return false;
33 }
34 }  // namespace
35 
KernelAttrsMatch(const KernelDef & kernel_def,AttrSlice attrs,bool * match)36 Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
37                         bool* match) {
38   *match = false;
39   for (const auto& constraint : kernel_def.constraint()) {
40     if (constraint.allowed_values().list().type_size() == 0) {
41       return errors::Unimplemented(
42           "KernelDef '", ProtoShortDebugString(kernel_def),
43           " has constraint on attr '", constraint.name(),
44           "' with unsupported type: ",
45           SummarizeAttrValue(constraint.allowed_values()));
46     }
47 
48     const AttrValue* found = attrs.Find(constraint.name());
49     if (found) {
50       if (found->type() != DT_INVALID) {
51         if (!InTypeList(found->type(), constraint.allowed_values())) {
52           return Status::OK();
53         }
54       } else {
55         if (!AttrValueHasType(*found, "list(type)").ok()) {
56           return errors::InvalidArgument(
57               "KernelDef '", ProtoShortDebugString(kernel_def),
58               "' has constraint on attr '", constraint.name(),
59               "' that has value '", SummarizeAttrValue(*found),
60               "' that does not have type 'type' or 'list(type)' in NodeDef "
61               "'",
62               attrs.SummarizeNode(), "'");
63         }
64 
65         for (int t : found->list().type()) {
66           if (!InTypeList(static_cast<DataType>(t),
67                           constraint.allowed_values())) {
68             return Status::OK();
69           }
70         }
71       }
72     } else {
73       return errors::InvalidArgument(
74           "OpKernel '", kernel_def.op(), "' has constraint on attr '",
75           constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
76           "', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
77     }
78   }
79   *match = true;
80   return Status::OK();
81 }
82 
83 }  // namespace tensorflow
84