1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/attr_value_util.h"
17 
#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
21 
22 #include "absl/strings/escaping.h"
23 #include "tensorflow/core/framework/attr_value.pb_text.h"
24 #include "tensorflow/core/framework/tensor.pb_text.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb_text.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/stringpiece.h"
30 #include "tensorflow/core/lib/hash/hash.h"
31 #include "tensorflow/core/lib/strings/proto_serialization.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 
35 namespace tensorflow {
36 namespace {
37 
// Upper bound, in bytes, on the tensors that will be materialized from their
// TensorProto in order to hash them or compare them for equality. Larger
// tensors are handled through their proto representation instead.
constexpr int kMaxAttrValueTensorByteSize = 32 << 20;  // 32 MiB
40 
41 // Return the size of the tensor represented by this TensorProto. If shape is
42 // not fully defined return -1.
TensorByteSize(const TensorProto & t)43 int64 TensorByteSize(const TensorProto& t) {
44   // num_elements returns -1 if shape is not fully defined.
45   int64 num_elems = TensorShape(t.tensor_shape()).num_elements();
46   return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype());
47 }
48 
49 // Compute TensorProto hash by creating a Tensor, serializing it as tensor
50 // content, and computing a hash of it's string representation. This is unsafe
51 // operation, because large tensors can be represented as TensorProto, but can't
52 // be serialized to tensor content.
TensorProtoHash(const TensorProto & tp)53 uint64 TensorProtoHash(const TensorProto& tp) {
54   Tensor tensor(tp.dtype());
55   bool success = tensor.FromProto(tp);
56   DCHECK(success);
57   TensorProto p;
58   tensor.AsProtoTensorContent(&p);
59   return DeterministicProtoHash64(p);
60 }
61 
62 // Do not create large tensors in memory, compute hash based on TensorProto
63 // string representation. Tensors with identical content potentially can have a
64 // different hash code if they are defined with different TensorProto
65 // representations.
FastTensorProtoHash(const TensorProto & tp)66 uint64 FastTensorProtoHash(const TensorProto& tp) {
67   if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
68     return DeterministicProtoHash64(tp);
69   } else {
70     return TensorProtoHash(tp);
71   }
72 }
73 
74 // There are multiple equivalent representations of attr values containing
75 // TensorProtos. Compare them by constructing Tensors and serializing them
76 // back. Comparing Tensor objects is pretty tricky. This is unsafe operation,
77 // because large tensors can be represented as TensorProto, but can't be
78 // serialized to tensor content.
AreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs)79 bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
80   Tensor lhs_t(lhs.dtype());
81   bool success = lhs_t.FromProto(lhs);
82   DCHECK(success);
83 
84   Tensor rhs_t(rhs.dtype());
85   success = rhs_t.FromProto(rhs);
86   DCHECK(success);
87 
88   TensorProto lhs_tp;
89   lhs_t.AsProtoTensorContent(&lhs_tp);
90 
91   TensorProto rhs_tp;
92   rhs_t.AsProtoTensorContent(&rhs_tp);
93 
94   return AreSerializedProtosEqual(lhs_tp, rhs_tp);
95 }
96 
97 // Do not construct large tensors in memory, compare equality using TensorProto
98 // string representation. Tensors with identical content potentially can have
99 // different tensor proto representation.
FastAreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs)100 bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
101   // A small TensorProto can expand into a giant Tensor.  So we avoid
102   // conversion to an actual Tensor if we can quickly rule out equality
103   // by comparing the Tensor size since different sized Tensors are definitely
104   // different.
105   const int64 lhs_tensor_bytes = TensorByteSize(lhs);
106   const int64 rhs_tensor_bytes = TensorByteSize(rhs);
107   if (lhs_tensor_bytes != rhs_tensor_bytes) {
108     return false;
109   }
110 
111   // If the tensor is very large, we'll only compare the proto representation
112   // (even though this may miss some equivalent tensors whose actual tensor
113   // values are the same but which are described by different TensorProtos).
114   if (lhs_tensor_bytes > kMaxAttrValueTensorByteSize) {
115     return AreSerializedProtosEqual(lhs, rhs);
116   }
117 
118   // If the TensorProto representation expands into a much bigger Tensor,
119   // we have a fast-path that first compares the protos.
120   const int64 lhs_proto_bytes = lhs.ByteSizeLong();
121   const bool large_expansion =
122       (lhs_proto_bytes < 512 && lhs_tensor_bytes > 4096);
123   if (large_expansion && AreSerializedProtosEqual(lhs, rhs)) {
124     return true;
125   }
126 
127   // Fall back to the general code in AreTensorProtosEqual.
128   return AreTensorProtosEqual(lhs, rhs);
129 }
130 
131 using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
132 
AttrValueHash(const AttrValue & a,const TensorProtoHasher & tensor_hash)133 uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
134   if (a.has_tensor()) return tensor_hash(a.tensor());
135 
136   if (a.has_func()) {
137     const NameAttrList& func = a.func();
138     uint64 h = Hash64(func.name());
139     std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
140     for (const auto& pair : map) {
141       h = Hash64(pair.first.data(), pair.first.size(), h);
142       h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h);
143     }
144     return h;
145   }
146 
147   // If `a` is not a tensor or func, get a hash of serialized string.
148   return DeterministicProtoHash64(a);
149 }
150 
151 template <typename TensorProtosEquality>
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b,TensorProtosEquality tensor_equality)152 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
153                         TensorProtosEquality tensor_equality) {
154   if (a.type() != b.type()) {
155     return false;
156   } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
157     return a.type() == b.type();
158   }
159 
160   if (a.has_tensor() != b.has_tensor()) {
161     return false;
162   } else if (a.has_tensor() && b.has_tensor()) {
163     return tensor_equality(a.tensor(), b.tensor());
164   }
165 
166   // `func` field contains a nested AttrValue. Compare such AttrValues
167   // recursively.
168   if (a.has_func() != b.has_func()) {
169     return false;
170   } else if (a.has_func() && b.has_func()) {
171     const NameAttrList& af = a.func();
172     const NameAttrList& bf = b.func();
173     if (af.name() != bf.name()) return false;
174     std::unordered_map<string, AttrValue> am(af.attr().begin(),
175                                              af.attr().end());
176     for (const auto& bm_pair : bf.attr()) {
177       const auto& iter = am.find(bm_pair.first);
178       if (iter == am.end()) return false;
179       if (!AreAttrValuesEqual(iter->second, bm_pair.second, tensor_equality))
180         return false;
181       am.erase(iter);
182     }
183     if (!am.empty()) return false;
184     return true;
185   }
186 
187   // All other fields in AttrValue have deterministic representations.
188   // It is safe to compare their serialized strings.
189   return AreSerializedProtosEqual(a, b);
190 }
191 
// Returns `str` C-escaped and wrapped in double quotes. Escaped strings of 80
// or more characters are abbreviated to their first 10 and last 10 characters
// joined by "...".
string SummarizeString(const string& str) {
  const string escaped = absl::CEscape(str);

  constexpr int kMaxStringSummarySize = 80;
  if (escaped.size() < kMaxStringSummarySize) {
    return strings::StrCat("\"", escaped, "\"");
  }

  // Long string: keep only the head and the tail.
  constexpr size_t kKeptChars = 10;
  const StringPiece whole(escaped);
  const StringPiece head = whole.substr(0, kKeptChars);
  const StringPiece tail = whole.substr(whole.size() - kKeptChars);
  return strings::StrCat("\"", head, "...", tail, "\"");
}
207 
SummarizeTensor(const TensorProto & tensor_proto)208 string SummarizeTensor(const TensorProto& tensor_proto) {
209   Tensor t;
210   if (!t.FromProto(tensor_proto)) {
211     return strings::StrCat(
212         "<Invalid TensorProto: ", tensor_proto.ShortDebugString(), ">");
213   }
214   return t.DebugString();
215 }
216 
SummarizeFunc(const NameAttrList & func)217 string SummarizeFunc(const NameAttrList& func) {
218   std::vector<string> entries;
219   for (const auto& p : func.attr()) {
220     entries.push_back(
221         strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
222   }
223   std::sort(entries.begin(), entries.end());
224   return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
225 }
226 
227 }  // namespace
228 
SummarizeAttrValue(const AttrValue & attr_value)229 string SummarizeAttrValue(const AttrValue& attr_value) {
230   switch (attr_value.value_case()) {
231     case AttrValue::kS:
232       return SummarizeString(attr_value.s());
233     case AttrValue::kI:
234       return strings::StrCat(attr_value.i());
235     case AttrValue::kF:
236       return strings::StrCat(attr_value.f());
237     case AttrValue::kB:
238       return attr_value.b() ? "true" : "false";
239     case AttrValue::kType:
240       return EnumName_DataType(attr_value.type());
241     case AttrValue::kShape:
242       return PartialTensorShape::DebugString(attr_value.shape());
243     case AttrValue::kTensor:
244       return SummarizeTensor(attr_value.tensor());
245     case AttrValue::kList: {
246       std::vector<string> pieces;
247       if (attr_value.list().s_size() > 0) {
248         for (int i = 0; i < attr_value.list().s_size(); ++i) {
249           pieces.push_back(SummarizeString(attr_value.list().s(i)));
250         }
251       } else if (attr_value.list().i_size() > 0) {
252         for (int i = 0; i < attr_value.list().i_size(); ++i) {
253           pieces.push_back(strings::StrCat(attr_value.list().i(i)));
254         }
255       } else if (attr_value.list().f_size() > 0) {
256         for (int i = 0; i < attr_value.list().f_size(); ++i) {
257           pieces.push_back(strings::StrCat(attr_value.list().f(i)));
258         }
259       } else if (attr_value.list().b_size() > 0) {
260         for (int i = 0; i < attr_value.list().b_size(); ++i) {
261           pieces.push_back(attr_value.list().b(i) ? "true" : "false");
262         }
263       } else if (attr_value.list().type_size() > 0) {
264         for (int i = 0; i < attr_value.list().type_size(); ++i) {
265           pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
266         }
267       } else if (attr_value.list().shape_size() > 0) {
268         for (int i = 0; i < attr_value.list().shape_size(); ++i) {
269           pieces.push_back(
270               TensorShape::DebugString(attr_value.list().shape(i)));
271         }
272       } else if (attr_value.list().tensor_size() > 0) {
273         for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
274           pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
275         }
276       } else if (attr_value.list().func_size() > 0) {
277         for (int i = 0; i < attr_value.list().func_size(); ++i) {
278           pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
279         }
280       }
281       constexpr int kMaxListSummarySize = 50;
282       if (pieces.size() >= kMaxListSummarySize) {
283         pieces.erase(pieces.begin() + 5, pieces.begin() + (pieces.size() - 6));
284         pieces[5] = "...";
285       }
286       return strings::StrCat("[", absl::StrJoin(pieces, ", "), "]");
287     }
288     case AttrValue::kFunc: {
289       return SummarizeFunc(attr_value.func());
290     }
291     case AttrValue::kPlaceholder:
292       return strings::StrCat("$", attr_value.placeholder());
293     case AttrValue::VALUE_NOT_SET:
294       return "<Unknown AttrValue type>";
295   }
296   return "<Unknown AttrValue type>";  // Prevent missing return warning
297 }
298 
// Checks that `attr_value` holds a value whose kind matches the attr type
// string `type` (e.g. "int", "shape", "list(string)").
//
// Returns Status::OK() on a match. Returns an InvalidArgument error when:
//   * a list field or scalar field of a kind other than `type` is set;
//   * the value is a placeholder;
//   * `type` is not a "list(...)" type and no matching field is set;
//   * `type` is "type" or "list(type)" and a DataType is not a valid enum
//     value, is a reference type, or is DT_INVALID.
// For "list(...)" types, an empty list is accepted. An entirely unset value
// is also accepted, for compatibility with GraphDef versions <= 4.
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
  // Number of set fields found whose kind agrees with `type`.
  int num_set = 0;

  // VALIDATE_FIELD handles one (list field, scalar oneof case) pair:
  // - If `attr_value` has a list and its `name` field is non-empty, `type`
  //   must be "list(<type_string>)".
  // - Otherwise, if the scalar oneof case is set, `type` must be
  //   "<type_string>".
  // A match increments num_set. A mismatch `return`s an error from
  // AttrValueHasType itself.
#define VALIDATE_FIELD(name, type_string, oneof_case)                         \
  do {                                                                        \
    if (attr_value.has_list()) {                                              \
      if (attr_value.list().name##_size() > 0) {                              \
        if (type != "list(" type_string ")") {                                \
          return errors::InvalidArgument(                                     \
              "AttrValue had value with type 'list(" type_string ")' when '", \
              type, "' expected");                                            \
        }                                                                     \
        ++num_set;                                                            \
      }                                                                       \
    } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
      if (type != type_string) {                                              \
        return errors::InvalidArgument(                                       \
            "AttrValue had value with type '" type_string "' when '", type,   \
            "' expected");                                                    \
      }                                                                       \
      ++num_set;                                                              \
    }                                                                         \
  } while (false)

  VALIDATE_FIELD(s, "string", kS);
  VALIDATE_FIELD(i, "int", kI);
  VALIDATE_FIELD(f, "float", kF);
  VALIDATE_FIELD(b, "bool", kB);
  VALIDATE_FIELD(type, "type", kType);
  VALIDATE_FIELD(shape, "shape", kShape);
  VALIDATE_FIELD(tensor, "tensor", kTensor);
  VALIDATE_FIELD(func, "func", kFunc);

#undef VALIDATE_FIELD

  // Placeholders are rejected regardless of `type`.
  if (attr_value.value_case() == AttrValue::kPlaceholder) {
    return errors::InvalidArgument(
        "AttrValue had value with unexpected type 'placeholder'");
  }

  // If the attr type is 'list', we expect attr_value.has_list() to be
  // true.  However, proto3's attr_value.has_list() can be false when
  // set to an empty list for GraphDef versions <= 4. So we simply
  // check if has_list is false and some other field in attr_value is
  // set to flag the error.  This test can be made more strict once
  // support for GraphDef versions <= 4 is dropped.
  //
  // NOTE(review): when has_list() is false, num_set is only incremented
  // after a scalar `type_string` matched `type` exactly. A "list(" type
  // cannot match, so a set scalar would already have returned an error
  // above, and the `num_set` branch below appears unreachable. Confirm.
  if (absl::StartsWith(type, "list(") && !attr_value.has_list()) {
    if (num_set) {
      return errors::InvalidArgument(
          "AttrValue missing value with expected type '", type, "'");
    } else {
      // Indicate that we have a list, but an empty one.
      ++num_set;
    }
  }

  // Okay to have an empty list, but not to be missing a non-list value.
  if (num_set == 0 && !absl::StartsWith(type, "list(")) {
    return errors::InvalidArgument(
        "AttrValue missing value with expected type '", type, "'");
  }

  // Ref types and DT_INVALID are illegal, and DataTypes must
  // be a valid enum type.
  if (type == "type") {
    if (!DataType_IsValid(attr_value.type())) {
      return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                     attr_value.type());
    }
    if (IsRefType(attr_value.type())) {
      return errors::InvalidArgument(
          "AttrValue must not have reference type value of ",
          DataTypeString(attr_value.type()));
    }
    if (attr_value.type() == DT_INVALID) {
      return errors::InvalidArgument("AttrValue has invalid DataType");
    }
  } else if (type == "list(type)") {
    // Same three checks, applied to each element of the list.
    for (auto as_int : attr_value.list().type()) {
      const DataType dtype = static_cast<DataType>(as_int);
      if (!DataType_IsValid(dtype)) {
        return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                       as_int);
      }
      if (IsRefType(dtype)) {
        return errors::InvalidArgument(
            "AttrValue must not have reference type value of ",
            DataTypeString(dtype));
      }
      if (dtype == DT_INVALID) {
        return errors::InvalidArgument("AttrValue contains invalid DataType");
      }
    }
  }

  return Status::OK();
}
396 
ParseAttrValue(StringPiece type,StringPiece text,AttrValue * out)397 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
398   // Parse type.
399   string field_name;
400   bool is_list = absl::ConsumePrefix(&type, "list(");
401   if (absl::ConsumePrefix(&type, "string")) {
402     field_name = "s";
403   } else if (absl::ConsumePrefix(&type, "int")) {
404     field_name = "i";
405   } else if (absl::ConsumePrefix(&type, "float")) {
406     field_name = "f";
407   } else if (absl::ConsumePrefix(&type, "bool")) {
408     field_name = "b";
409   } else if (absl::ConsumePrefix(&type, "type")) {
410     field_name = "type";
411   } else if (absl::ConsumePrefix(&type, "shape")) {
412     field_name = "shape";
413   } else if (absl::ConsumePrefix(&type, "tensor")) {
414     field_name = "tensor";
415   } else if (absl::ConsumePrefix(&type, "func")) {
416     field_name = "func";
417   } else if (absl::ConsumePrefix(&type, "placeholder")) {
418     field_name = "placeholder";
419   } else {
420     return false;
421   }
422   if (is_list && !absl::ConsumePrefix(&type, ")")) {
423     return false;
424   }
425 
426   // Construct a valid text proto message to parse.
427   string to_parse;
428   if (is_list) {
429     // TextFormat parser considers "i: 7" to be the same as "i: [7]",
430     // but we only want to allow list values with [].
431     StringPiece cleaned = text;
432     str_util::RemoveLeadingWhitespace(&cleaned);
433     str_util::RemoveTrailingWhitespace(&cleaned);
434     if (cleaned.size() < 2 || cleaned[0] != '[' ||
435         cleaned[cleaned.size() - 1] != ']') {
436       return false;
437     }
438     cleaned.remove_prefix(1);
439     str_util::RemoveLeadingWhitespace(&cleaned);
440     if (cleaned.size() == 1) {
441       // User wrote "[]", so return empty list without invoking the TextFormat
442       // parse which returns an error for "i: []".
443       out->Clear();
444       out->mutable_list();
445       return true;
446     }
447     to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
448   } else {
449     to_parse = strings::StrCat(field_name, ": ", text);
450   }
451 
452   return ProtoParseFromString(to_parse, out);
453 }
454 
SetAttrValue(const AttrValue & value,AttrValue * out)455 void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
456 
457 #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
458   void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
459 
460 #define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD)                       \
461   void SetAttrValue(ARG_TYPE value, AttrValue* out) {                     \
462     out->mutable_list()->Clear(); /* create list() even if value empty */ \
463     for (const auto& v : value) {                                         \
464       out->mutable_list()->add_##FIELD(v);                                \
465     }                                                                     \
466   }
467 
468 #define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
469   DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD)        \
470   DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
471 
DEFINE_SET_ATTR_VALUE_ONE(const string &,s)472 DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
473 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
474 DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
475 DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
476 DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
477 DEFINE_SET_ATTR_VALUE_BOTH(float, f)
478 DEFINE_SET_ATTR_VALUE_BOTH(double, f)
479 DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
480 DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
481 DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
482 DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
483 
484 void SetAttrValue(const tstring& value, AttrValue* out) {
485   out->set_s(value.data(), value.size());
486 }
487 
SetAttrValue(gtl::ArraySlice<tstring> value,AttrValue * out)488 void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) {
489   out->mutable_list()->Clear();
490   for (const auto& v : value) {
491     out->mutable_list()->add_s(v.data(), v.size());
492   }
493 }
494 
SetAttrValue(StringPiece value,AttrValue * out)495 void SetAttrValue(StringPiece value, AttrValue* out) {
496   out->set_s(value.data(), value.size());
497 }
498 
SetAttrValue(const gtl::ArraySlice<StringPiece> value,AttrValue * out)499 void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
500   out->mutable_list()->Clear();  // Create list() even if value empty.
501   for (const auto& v : value) {
502     out->mutable_list()->add_s(v.data(), v.size());
503   }
504 }
505 
MoveAttrValue(std::vector<string> && value,AttrValue * out)506 void MoveAttrValue(std::vector<string>&& value, AttrValue* out) {
507   out->mutable_list()->Clear();  // Create list() even if value empty.
508   for (auto& v : value) {
509     out->mutable_list()->add_s(std::move(v));
510   }
511 }
512 
SetAttrValue(const TensorShape & value,AttrValue * out)513 void SetAttrValue(const TensorShape& value, AttrValue* out) {
514   value.AsProto(out->mutable_shape());
515 }
516 
SetAttrValue(const TensorShapeProto & value,AttrValue * out)517 void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
518   *out->mutable_shape() = value;
519 }
520 
SetAttrValue(const PartialTensorShape & value,AttrValue * out)521 void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
522   value.AsProto(out->mutable_shape());
523 }
524 
SetAttrValue(const gtl::ArraySlice<TensorShape> value,AttrValue * out)525 void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
526   out->mutable_list()->Clear();  // Create list() even if value empty.
527   for (const auto& v : value) {
528     v.AsProto(out->mutable_list()->add_shape());
529   }
530 }
531 
SetAttrValue(gtl::ArraySlice<TensorShapeProto> value,AttrValue * out)532 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
533   out->mutable_list()->Clear();  // Create list() even if value empty.
534   for (const auto& v : value) {
535     *out->mutable_list()->add_shape() = v;
536   }
537 }
538 
SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,AttrValue * out)539 void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
540                   AttrValue* out) {
541   out->mutable_list()->Clear();  // Create list() even if value empty.
542   for (const auto& v : value) {
543     v.AsProto(out->mutable_list()->add_shape());
544   }
545 }
546 
SetAttrValue(const Tensor & value,AttrValue * out)547 void SetAttrValue(const Tensor& value, AttrValue* out) {
548   if (value.NumElements() > 1) {
549     value.AsProtoTensorContent(out->mutable_tensor());
550   } else {
551     value.AsProtoField(out->mutable_tensor());
552   }
553 }
554 
SetAttrValue(const gtl::ArraySlice<Tensor> value,AttrValue * out)555 void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
556   out->mutable_list()->Clear();  // Create list() even if value empty.
557   for (const auto& v : value) {
558     if (v.NumElements() > 1) {
559       v.AsProtoTensorContent(out->mutable_list()->add_tensor());
560     } else {
561       v.AsProtoField(out->mutable_list()->add_tensor());
562     }
563   }
564 }
565 
SetAttrValue(const TensorProto & value,AttrValue * out)566 void SetAttrValue(const TensorProto& value, AttrValue* out) {
567   *out->mutable_tensor() = value;
568 }
569 
SetAttrValue(const gtl::ArraySlice<TensorProto> value,AttrValue * out)570 void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
571   out->mutable_list()->Clear();  // Create list() even if value empty.
572   for (const auto& v : value) {
573     *out->mutable_list()->add_tensor() = v;
574   }
575 }
576 
SetAttrValue(const NameAttrList & value,AttrValue * out)577 void SetAttrValue(const NameAttrList& value, AttrValue* out) {
578   *out->mutable_func() = value;
579 }
580 
SetAttrValue(gtl::ArraySlice<NameAttrList> value,AttrValue * out)581 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
582   out->mutable_list()->Clear();  // Create list() even if value empty.
583   for (const auto& v : value) {
584     *out->mutable_list()->add_func() = v;
585   }
586 }
587 
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b)588 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
589   return AreAttrValuesEqual(a, b, AreTensorProtosEqual);
590 }
591 
AttrValueHash(const AttrValue & a)592 uint64 AttrValueHash(const AttrValue& a) {
593   return AttrValueHash(a, TensorProtoHash);
594 }
595 
FastAreAttrValuesEqual(const AttrValue & a,const AttrValue & b)596 bool FastAreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
597   return AreAttrValuesEqual(a, b, FastAreTensorProtosEqual);
598 }
599 
FastAttrValueHash(const AttrValue & a)600 uint64 FastAttrValueHash(const AttrValue& a) {
601   return AttrValueHash(a, FastTensorProtoHash);
602 }
603 
HasPlaceHolder(const AttrValue & val)604 bool HasPlaceHolder(const AttrValue& val) {
605   switch (val.value_case()) {
606     case AttrValue::kList: {
607       for (const NameAttrList& func : val.list().func()) {
608         for (const auto& p : func.attr()) {
609           if (HasPlaceHolder(p.second)) {
610             return true;
611           }
612         }
613       }
614       break;
615     }
616     case AttrValue::kFunc:
617       for (const auto& p : val.func().attr()) {
618         if (HasPlaceHolder(p.second)) {
619           return true;
620         }
621       }
622       break;
623     case AttrValue::kPlaceholder:
624       return true;
625     default:
626       break;
627   }
628   return false;
629 }
630 
SubstitutePlaceholders(const SubstituteFunc & substitute,AttrValue * value)631 bool SubstitutePlaceholders(const SubstituteFunc& substitute,
632                             AttrValue* value) {
633   switch (value->value_case()) {
634     case AttrValue::kList: {
635       for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
636         for (auto& p : *func.mutable_attr()) {
637           if (!SubstitutePlaceholders(substitute, &p.second)) {
638             return false;
639           }
640         }
641       }
642       break;
643     }
644     case AttrValue::kFunc:
645       for (auto& p : *(value->mutable_func()->mutable_attr())) {
646         if (!SubstitutePlaceholders(substitute, &p.second)) {
647           return false;
648         }
649       }
650       break;
651     case AttrValue::kPlaceholder:
652       return substitute(value->placeholder(), value);
653     case AttrValue::VALUE_NOT_SET:
654       return false;
655     default:
656       break;
657   }
658   return true;
659 }
660 
661 }  // namespace tensorflow
662