1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/attr_value_util.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "tensorflow/core/framework/attr_value.pb_text.h"
22 #include "tensorflow/core/framework/tensor.pb_text.h"
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb_text.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/stringpiece.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/protobuf.h"
31
32 namespace tensorflow {
33 namespace {
34
// Returns `str` C-escaped and wrapped in double quotes. When the escaped text
// is 80 characters or longer, only its first 10 and last 10 characters are
// kept, joined by "...".
string SummarizeString(const string& str) {
  constexpr int kMaxStringSummarySize = 80;
  constexpr int kKeptChars = 10;

  const string escaped = str_util::CEscape(str);
  if (escaped.size() < kMaxStringSummarySize) {
    return strings::StrCat("\"", escaped, "\"");
  }

  // Long string: show only the head and the tail of the escaped text.
  const StringPiece head(escaped.data(), kKeptChars);
  const StringPiece tail(escaped.data() + escaped.size() - kKeptChars,
                         kKeptChars);
  return strings::StrCat("\"", head, "...", tail, "\"");
}
50
SummarizeTensor(const TensorProto & tensor_proto)51 string SummarizeTensor(const TensorProto& tensor_proto) {
52 Tensor t;
53 if (!t.FromProto(tensor_proto)) {
54 return strings::StrCat(
55 "<Invalid TensorProto: ", ProtoShortDebugString(tensor_proto), ">");
56 }
57 return t.DebugString();
58 }
59
SummarizeFunc(const NameAttrList & func)60 string SummarizeFunc(const NameAttrList& func) {
61 std::vector<string> entries;
62 for (auto p : func.attr()) {
63 entries.push_back(
64 strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
65 }
66 std::sort(entries.begin(), entries.end());
67 return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]");
68 }
69
70 } // namespace
71
SummarizeAttrValue(const AttrValue & attr_value)72 string SummarizeAttrValue(const AttrValue& attr_value) {
73 switch (attr_value.value_case()) {
74 case AttrValue::kS:
75 return SummarizeString(attr_value.s());
76 case AttrValue::kI:
77 return strings::StrCat(attr_value.i());
78 case AttrValue::kF:
79 return strings::StrCat(attr_value.f());
80 case AttrValue::kB:
81 return attr_value.b() ? "true" : "false";
82 case AttrValue::kType:
83 return EnumName_DataType(attr_value.type());
84 case AttrValue::kShape:
85 return PartialTensorShape::DebugString(attr_value.shape());
86 case AttrValue::kTensor:
87 return SummarizeTensor(attr_value.tensor());
88 case AttrValue::kList: {
89 std::vector<string> pieces;
90 if (attr_value.list().s_size() > 0) {
91 for (int i = 0; i < attr_value.list().s_size(); ++i) {
92 pieces.push_back(SummarizeString(attr_value.list().s(i)));
93 }
94 } else if (attr_value.list().i_size() > 0) {
95 for (int i = 0; i < attr_value.list().i_size(); ++i) {
96 pieces.push_back(strings::StrCat(attr_value.list().i(i)));
97 }
98 } else if (attr_value.list().f_size() > 0) {
99 for (int i = 0; i < attr_value.list().f_size(); ++i) {
100 pieces.push_back(strings::StrCat(attr_value.list().f(i)));
101 }
102 } else if (attr_value.list().b_size() > 0) {
103 for (int i = 0; i < attr_value.list().b_size(); ++i) {
104 pieces.push_back(attr_value.list().b(i) ? "true" : "false");
105 }
106 } else if (attr_value.list().type_size() > 0) {
107 for (int i = 0; i < attr_value.list().type_size(); ++i) {
108 pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
109 }
110 } else if (attr_value.list().shape_size() > 0) {
111 for (int i = 0; i < attr_value.list().shape_size(); ++i) {
112 pieces.push_back(
113 TensorShape::DebugString(attr_value.list().shape(i)));
114 }
115 } else if (attr_value.list().tensor_size() > 0) {
116 for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
117 pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
118 }
119 } else if (attr_value.list().func_size() > 0) {
120 for (int i = 0; i < attr_value.list().func_size(); ++i) {
121 pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
122 }
123 }
124 constexpr int kMaxListSummarySize = 15;
125 if (pieces.size() >= kMaxListSummarySize) {
126 pieces.erase(pieces.begin() + 5, pieces.begin() + (pieces.size() - 6));
127 pieces[5] = "...";
128 }
129 return strings::StrCat("[", str_util::Join(pieces, ", "), "]");
130 }
131 case AttrValue::kFunc: {
132 return SummarizeFunc(attr_value.func());
133 }
134 case AttrValue::kPlaceholder:
135 return strings::StrCat("$", attr_value.placeholder());
136 case AttrValue::VALUE_NOT_SET:
137 return "<Unknown AttrValue type>";
138 }
139 return "<Unknown AttrValue type>"; // Prevent missing return warning
140 }
141
AttrValueHasType(const AttrValue & attr_value,StringPiece type)142 Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
143 int num_set = 0;
144
145 #define VALIDATE_FIELD(name, type_string, oneof_case) \
146 do { \
147 if (attr_value.has_list()) { \
148 if (attr_value.list().name##_size() > 0) { \
149 if (type != "list(" type_string ")") { \
150 return errors::InvalidArgument( \
151 "AttrValue had value with type 'list(" type_string ")' when '", \
152 type, "' expected"); \
153 } \
154 ++num_set; \
155 } \
156 } else if (attr_value.value_case() == AttrValue::oneof_case) { \
157 if (type != type_string) { \
158 return errors::InvalidArgument( \
159 "AttrValue had value with type '" type_string "' when '", type, \
160 "' expected"); \
161 } \
162 ++num_set; \
163 } \
164 } while (false)
165
166 VALIDATE_FIELD(s, "string", kS);
167 VALIDATE_FIELD(i, "int", kI);
168 VALIDATE_FIELD(f, "float", kF);
169 VALIDATE_FIELD(b, "bool", kB);
170 VALIDATE_FIELD(type, "type", kType);
171 VALIDATE_FIELD(shape, "shape", kShape);
172 VALIDATE_FIELD(tensor, "tensor", kTensor);
173 VALIDATE_FIELD(func, "func", kFunc);
174
175 #undef VALIDATE_FIELD
176
177 if (attr_value.value_case() == AttrValue::kPlaceholder) {
178 return errors::InvalidArgument(
179 "AttrValue had value with unexpected type 'placeholder'");
180 }
181
182 // If the attr type is 'list', we expect attr_value.has_list() to be
183 // true. However, proto3's attr_value.has_list() can be false when
184 // set to an empty list for GraphDef versions <= 4. So we simply
185 // check if has_list is false and some other field in attr_value is
186 // set to flag the error. This test can be made more strict once
187 // support for GraphDef versions <= 4 is dropped.
188 if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) {
189 if (num_set) {
190 return errors::InvalidArgument(
191 "AttrValue missing value with expected type '", type, "'");
192 } else {
193 // Indicate that we have a list, but an empty one.
194 ++num_set;
195 }
196 }
197
198 // Okay to have an empty list, but not to be missing a non-list value.
199 if (num_set == 0 && !StringPiece(type).starts_with("list(")) {
200 return errors::InvalidArgument(
201 "AttrValue missing value with expected type '", type, "'");
202 }
203
204 // Ref types and DT_INVALID are illegal, and DataTypes must
205 // be a valid enum type.
206 if (type == "type") {
207 if (!DataType_IsValid(attr_value.type())) {
208 return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
209 attr_value.type());
210 }
211 if (IsRefType(attr_value.type())) {
212 return errors::InvalidArgument(
213 "AttrValue must not have reference type value of ",
214 DataTypeString(attr_value.type()));
215 }
216 if (attr_value.type() == DT_INVALID) {
217 return errors::InvalidArgument("AttrValue has invalid DataType");
218 }
219 } else if (type == "list(type)") {
220 for (auto as_int : attr_value.list().type()) {
221 const DataType dtype = static_cast<DataType>(as_int);
222 if (!DataType_IsValid(dtype)) {
223 return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
224 as_int);
225 }
226 if (IsRefType(dtype)) {
227 return errors::InvalidArgument(
228 "AttrValue must not have reference type value of ",
229 DataTypeString(dtype));
230 }
231 if (dtype == DT_INVALID) {
232 return errors::InvalidArgument("AttrValue contains invalid DataType");
233 }
234 }
235 }
236
237 return Status::OK();
238 }
239
ParseAttrValue(StringPiece type,StringPiece text,AttrValue * out)240 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
241 // Parse type.
242 string field_name;
243 bool is_list = type.Consume("list(");
244 if (type.Consume("string")) {
245 field_name = "s";
246 } else if (type.Consume("int")) {
247 field_name = "i";
248 } else if (type.Consume("float")) {
249 field_name = "f";
250 } else if (type.Consume("bool")) {
251 field_name = "b";
252 } else if (type.Consume("type")) {
253 field_name = "type";
254 } else if (type.Consume("shape")) {
255 field_name = "shape";
256 } else if (type.Consume("tensor")) {
257 field_name = "tensor";
258 } else if (type.Consume("func")) {
259 field_name = "func";
260 } else if (type.Consume("placeholder")) {
261 field_name = "placeholder";
262 } else {
263 return false;
264 }
265 if (is_list && !type.Consume(")")) {
266 return false;
267 }
268
269 // Construct a valid text proto message to parse.
270 string to_parse;
271 if (is_list) {
272 // TextFormat parser considers "i: 7" to be the same as "i: [7]",
273 // but we only want to allow list values with [].
274 StringPiece cleaned = text;
275 str_util::RemoveLeadingWhitespace(&cleaned);
276 str_util::RemoveTrailingWhitespace(&cleaned);
277 if (cleaned.size() < 2 || cleaned[0] != '[' ||
278 cleaned[cleaned.size() - 1] != ']') {
279 return false;
280 }
281 cleaned.remove_prefix(1);
282 str_util::RemoveLeadingWhitespace(&cleaned);
283 if (cleaned.size() == 1) {
284 // User wrote "[]", so return empty list without invoking the TextFormat
285 // parse which returns an error for "i: []".
286 out->Clear();
287 out->mutable_list();
288 return true;
289 }
290 to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
291 } else {
292 to_parse = strings::StrCat(field_name, ": ", text);
293 }
294
295 return ProtoParseFromString(to_parse, out);
296 }
297
SetAttrValue(const AttrValue & value,AttrValue * out)298 void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
299
300 #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
301 void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
302
303 #define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \
304 void SetAttrValue(ARG_TYPE value, AttrValue* out) { \
305 out->mutable_list()->Clear(); /* create list() even if value empty */ \
306 for (const auto& v : value) { \
307 out->mutable_list()->add_##FIELD(v); \
308 } \
309 }
310
311 #define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
312 DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
313 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
314
DEFINE_SET_ATTR_VALUE_ONE(const string &,s)315 DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
316 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
317 DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
318 DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
319 DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
320 DEFINE_SET_ATTR_VALUE_BOTH(float, f)
321 DEFINE_SET_ATTR_VALUE_BOTH(double, f)
322 DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
323 DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
324 DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
325 DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
326
327 void SetAttrValue(StringPiece value, AttrValue* out) {
328 out->set_s(value.data(), value.size());
329 }
330
SetAttrValue(const gtl::ArraySlice<StringPiece> value,AttrValue * out)331 void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
332 out->mutable_list()->Clear(); // Create list() even if value empty.
333 for (const auto& v : value) {
334 out->mutable_list()->add_s(v.data(), v.size());
335 }
336 }
337
SetAttrValue(const TensorShape & value,AttrValue * out)338 void SetAttrValue(const TensorShape& value, AttrValue* out) {
339 value.AsProto(out->mutable_shape());
340 }
341
SetAttrValue(const TensorShapeProto & value,AttrValue * out)342 void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
343 *out->mutable_shape() = value;
344 }
345
SetAttrValue(const PartialTensorShape & value,AttrValue * out)346 void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
347 value.AsProto(out->mutable_shape());
348 }
349
SetAttrValue(const gtl::ArraySlice<TensorShape> value,AttrValue * out)350 void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
351 out->mutable_list()->Clear(); // Create list() even if value empty.
352 for (const auto& v : value) {
353 v.AsProto(out->mutable_list()->add_shape());
354 }
355 }
356
SetAttrValue(gtl::ArraySlice<TensorShapeProto> value,AttrValue * out)357 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
358 out->mutable_list()->Clear(); // Create list() even if value empty.
359 for (const auto& v : value) {
360 *out->mutable_list()->add_shape() = v;
361 }
362 }
363
SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,AttrValue * out)364 void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
365 AttrValue* out) {
366 out->mutable_list()->Clear(); // Create list() even if value empty.
367 for (const auto& v : value) {
368 v.AsProto(out->mutable_list()->add_shape());
369 }
370 }
371
SetAttrValue(const Tensor & value,AttrValue * out)372 void SetAttrValue(const Tensor& value, AttrValue* out) {
373 if (value.NumElements() > 1) {
374 value.AsProtoTensorContent(out->mutable_tensor());
375 } else {
376 value.AsProtoField(out->mutable_tensor());
377 }
378 }
379
SetAttrValue(const gtl::ArraySlice<Tensor> value,AttrValue * out)380 void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
381 out->mutable_list()->Clear(); // Create list() even if value empty.
382 for (const auto& v : value) {
383 if (v.NumElements() > 1) {
384 v.AsProtoTensorContent(out->mutable_list()->add_tensor());
385 } else {
386 v.AsProtoField(out->mutable_list()->add_tensor());
387 }
388 }
389 }
390
SetAttrValue(const TensorProto & value,AttrValue * out)391 void SetAttrValue(const TensorProto& value, AttrValue* out) {
392 *out->mutable_tensor() = value;
393 }
394
SetAttrValue(const gtl::ArraySlice<TensorProto> value,AttrValue * out)395 void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
396 out->mutable_list()->Clear(); // Create list() even if value empty.
397 for (const auto& v : value) {
398 *out->mutable_list()->add_tensor() = v;
399 }
400 }
401
SetAttrValue(const NameAttrList & value,AttrValue * out)402 void SetAttrValue(const NameAttrList& value, AttrValue* out) {
403 *out->mutable_func() = value;
404 }
405
SetAttrValue(gtl::ArraySlice<NameAttrList> value,AttrValue * out)406 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
407 out->mutable_list()->Clear(); // Create list() even if value empty.
408 for (const auto& v : value) {
409 *out->mutable_list()->add_func() = v;
410 }
411 }
412
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b)413 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
414 // There are multiple equivalent representations of attr values containing
415 // TensorProtos. Compare them by constructing Tensors and serializing them
416 // back. Comparing Tensor objects is pretty tricky.
417 if (a.has_tensor() != b.has_tensor()) {
418 return false;
419 } else if (a.has_tensor() && b.has_tensor()) {
420 Tensor at(a.tensor().dtype());
421 bool success = at.FromProto(a.tensor());
422 DCHECK(success);
423
424 Tensor bt(b.tensor().dtype());
425 success = bt.FromProto(b.tensor());
426 DCHECK(success);
427
428 TensorProto ap;
429 at.AsProtoTensorContent(&ap);
430
431 TensorProto bp;
432 bt.AsProtoTensorContent(&bp);
433
434 string a_str, b_str;
435 SerializeToStringDeterministic(ap, &a_str);
436 SerializeToStringDeterministic(bp, &b_str);
437 return a_str == b_str;
438 }
439
440 // `func` field contains a nested AttrValue. Compare such AttrValues
441 // recursively.
442 if (a.has_func() != b.has_func()) {
443 return false;
444 } else if (a.has_func() && b.has_func()) {
445 const NameAttrList& af = a.func();
446 const NameAttrList& bf = b.func();
447 if (af.name() != bf.name()) return false;
448 std::unordered_map<string, AttrValue> am(af.attr().begin(),
449 af.attr().end());
450 for (const auto& bm_pair : bf.attr()) {
451 const auto& iter = am.find(bm_pair.first);
452 if (iter == am.end()) return false;
453 if (!AreAttrValuesEqual(iter->second, bm_pair.second)) return false;
454 am.erase(iter);
455 }
456 if (!am.empty()) return false;
457 return true;
458 }
459
460 // All other fields in AttrValue have deterministic representations.
461 // It is safe to compare their serialized strings.
462 string a_str, b_str;
463 SerializeToStringDeterministic(a, &a_str);
464 SerializeToStringDeterministic(b, &b_str);
465 return a_str == b_str;
466 }
467
AttrValueHash(const AttrValue & a)468 uint64 AttrValueHash(const AttrValue& a) {
469 if (a.has_tensor()) {
470 // Deal with multiple representations by parsing TensorProto to
471 // Tensor and serializing it back. This is slow, but current use case
472 // don't need high efficiency.
473 Tensor tensor(a.tensor().dtype());
474 bool success = tensor.FromProto(a.tensor());
475 DCHECK(success);
476 TensorProto p;
477 tensor.AsProtoTensorContent(&p);
478 string s;
479 SerializeToStringDeterministic(p, &s);
480 return Hash64(s);
481 }
482 if (a.has_func()) {
483 const NameAttrList& func = a.func();
484 uint64 h = Hash64(func.name());
485 std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
486 for (const auto& pair : map) {
487 h = Hash64(pair.first.data(), pair.first.size(), h);
488 h = Hash64Combine(AttrValueHash(pair.second), h);
489 }
490 return h;
491 }
492
493 // If `a` is not a tensor or func, get a hash of serialized string.
494 string s;
495 SerializeToStringDeterministic(a, &s);
496 return Hash64(s);
497 }
498
HasPlaceHolder(const AttrValue & val)499 bool HasPlaceHolder(const AttrValue& val) {
500 switch (val.value_case()) {
501 case AttrValue::kList: {
502 for (const NameAttrList& func : val.list().func()) {
503 for (const auto& p : func.attr()) {
504 if (HasPlaceHolder(p.second)) {
505 return true;
506 }
507 }
508 }
509 break;
510 }
511 case AttrValue::kFunc:
512 for (const auto& p : val.func().attr()) {
513 if (HasPlaceHolder(p.second)) {
514 return true;
515 }
516 }
517 break;
518 case AttrValue::kPlaceholder:
519 return true;
520 default:
521 break;
522 }
523 return false;
524 }
525
SubstitutePlaceholders(const SubstituteFunc & substitute,AttrValue * value)526 bool SubstitutePlaceholders(const SubstituteFunc& substitute,
527 AttrValue* value) {
528 switch (value->value_case()) {
529 case AttrValue::kList: {
530 for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
531 for (auto& p : *func.mutable_attr()) {
532 if (!SubstitutePlaceholders(substitute, &p.second)) {
533 return false;
534 }
535 }
536 }
537 break;
538 }
539 case AttrValue::kFunc:
540 for (auto& p : *(value->mutable_func()->mutable_attr())) {
541 if (!SubstitutePlaceholders(substitute, &p.second)) {
542 return false;
543 }
544 }
545 break;
546 case AttrValue::kPlaceholder:
547 return substitute(value->placeholder(), value);
548 case AttrValue::VALUE_NOT_SET:
549 return false;
550 default:
551 break;
552 }
553 return true;
554 }
555
556 } // namespace tensorflow
557