1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/util.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
28 
29 namespace tensorflow {
30 namespace swig {
31 
namespace {
// Forward declaration; the definition appears further down in this file.
// Declared here because RegisterPyObject() uses it to build its error message.
string PyObjectToString(PyObject* o);
}  // namespace
35 
RegisteredPyObjectMap()36 std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
37   static auto* m = new std::unordered_map<string, PyObject*>();
38   return m;
39 }
40 
GetRegisteredPyObject(const string & name)41 PyObject* GetRegisteredPyObject(const string& name) {
42   const auto* m = RegisteredPyObjectMap();
43   auto it = m->find(name);
44   if (it == m->end()) {
45     PyErr_SetString(PyExc_TypeError,
46                     tensorflow::strings::StrCat("No object with name ", name,
47                                                 " has been registered.")
48                         .c_str());
49     return nullptr;
50   }
51   return it->second;
52 }
53 
RegisterType(PyObject * type_name,PyObject * type)54 PyObject* RegisterType(PyObject* type_name, PyObject* type) {
55   if (!PyType_Check(type)) {
56     PyErr_SetString(PyExc_TypeError,
57                     tensorflow::strings::StrCat("Expecting a type, got ",
58                                                 Py_TYPE(type)->tp_name)
59                         .c_str());
60     return nullptr;
61   }
62   return RegisterPyObject(type_name, type);
63 }
64 
RegisterPyObject(PyObject * name,PyObject * value)65 PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
66   string key;
67   if (PyBytes_Check(name)) {
68     key = PyBytes_AsString(name);
69 #if PY_MAJOR_VERSION >= 3
70   } else if (PyUnicode_Check(name)) {
71     key = PyUnicode_AsUTF8(name);
72 #endif
73   } else {
74     PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
75                                          "Expected name to be a str, got",
76                                          PyObjectToString(name))
77                                          .c_str());
78     return nullptr;
79   }
80 
81   auto* m = RegisteredPyObjectMap();
82   if (m->find(key) != m->end()) {
83     PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
84                                          "Value already registered for ", key)
85                                          .c_str());
86     return nullptr;
87   }
88 
89   Py_INCREF(value);
90   m->emplace(key, value);
91 
92   Py_RETURN_NONE;
93 }
94 
95 namespace {
// Upper bound on the number of entries a CachedTypeCheck keeps in its
// per-type cache; once reached, further results are returned uncached.
const int kMaxItemsInCache = 1024;

// Set to true once IsSequenceHelper() has logged its warning about sets not
// being treated as sequences, so that the warning is not repeated.
bool WarnedThatSetIsNotSequence = false;
99 
IsString(PyObject * o)100 bool IsString(PyObject* o) {
101   return PyBytes_Check(o) ||
102 #if PY_MAJOR_VERSION < 3
103          PyString_Check(o) ||
104 #endif
105          PyUnicode_Check(o);
106 }
107 
108 // Equivalent to Python's 'o.__class__.__name__'
109 // Note that '__class__' attribute is set only in new-style classes.
110 // A lot of tensorflow code uses __class__ without checks, so it seems like
111 // we only support new-style classes.
GetClassName(PyObject * o)112 StringPiece GetClassName(PyObject* o) {
113   // __class__ is equivalent to type() for new style classes.
114   // type() is equivalent to PyObject_Type()
115   // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
116   // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
117   // we don't need here.
118   PyTypeObject* type = o->ob_type;
119 
120   // __name__ is the value of `tp_name` after the last '.'
121   // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
122   StringPiece name(type->tp_name);
123   size_t pos = name.rfind('.');
124   if (pos != StringPiece::npos) {
125     name.remove_prefix(pos + 1);
126   }
127   return name;
128 }
129 
PyObjectToString(PyObject * o)130 string PyObjectToString(PyObject* o) {
131   if (o == nullptr) {
132     return "<null object>";
133   }
134   PyObject* str = PyObject_Str(o);
135   if (str) {
136 #if PY_MAJOR_VERSION < 3
137     string s(PyString_AS_STRING(str));
138 #else
139     string s(PyUnicode_AsUTF8(str));
140 #endif
141     Py_DECREF(str);
142     return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
143   } else {
144     return "<failed to execute str() on object>";
145   }
146 }
147 
148 class CachedTypeCheck {
149  public:
CachedTypeCheck(std::function<int (PyObject *)> ternary_predicate)150   explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
151       : ternary_predicate_(std::move(ternary_predicate)) {}
152 
~CachedTypeCheck()153   ~CachedTypeCheck() {
154     mutex_lock l(type_to_sequence_map_mu_);
155     for (const auto& pair : type_to_sequence_map_) {
156       Py_DECREF(pair.first);
157     }
158   }
159 
160   // Caches successful executions of the one-argument (PyObject*) callable
161   // "ternary_predicate" based on the type of "o". -1 from the callable
162   // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
163   // does not match the predicate, and 1 indicates that it does. Used to avoid
164   // calling back into Python for expensive isinstance checks.
CachedLookup(PyObject * o)165   int CachedLookup(PyObject* o) {
166     // Try not to return to Python - see if the type has already been seen
167     // before.
168 
169     auto* type = Py_TYPE(o);
170 
171     {
172       tf_shared_lock l(type_to_sequence_map_mu_);
173       auto it = type_to_sequence_map_.find(type);
174       if (it != type_to_sequence_map_.end()) {
175         return it->second;
176       }
177     }
178 
179     int check_result = ternary_predicate_(o);
180 
181     if (check_result == -1) {
182       return -1;  // Type check error, not cached.
183     }
184 
185     // NOTE: This is never decref'd as long as the object lives, which is likely
186     // forever, but we don't want the type to get deleted as long as it is in
187     // the map. This should not be too much of a leak, as there should only be a
188     // relatively small number of types in the map, and an even smaller number
189     // that are eligible for decref. As a precaution, we limit the size of the
190     // map to 1024.
191     {
192       mutex_lock l(type_to_sequence_map_mu_);
193       if (type_to_sequence_map_.size() < kMaxItemsInCache) {
194         Py_INCREF(type);
195         auto insert_result = type_to_sequence_map_.insert({type, check_result});
196         if (!insert_result.second) {
197           // The type was added to the cache by a concurrent thread after we
198           // looked it up above.
199           Py_DECREF(type);
200         }
201       }
202     }
203 
204     return check_result;
205   }
206 
207  private:
208   std::function<int(PyObject*)> ternary_predicate_;
209   mutex type_to_sequence_map_mu_;
210   std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
211       TF_GUARDED_BY(type_to_sequence_map_mu_);
212 };
213 
214 // Returns 1 if 'obj' is an instance of 'type_name'
215 // Returns 0 otherwise.
216 // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
IsInstanceOfRegisteredType(PyObject * obj,const char * type_name)217 int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
218   PyObject* type_obj = GetRegisteredPyObject(type_name);
219   if (TF_PREDICT_FALSE(type_obj == nullptr)) {
220     PyErr_SetString(PyExc_RuntimeError,
221                     tensorflow::strings::StrCat(
222                         type_name,
223                         " type has not been set. "
224                         "Please register the type with the identifier \"",
225                         type_name, "\" using RegisterType.")
226                         .c_str());
227     return -1;
228   }
229   return PyObject_IsInstance(obj, type_obj);
230 }
231 
232 // Returns 1 if `o` is considered a mapping for the purposes of Flatten().
233 // Returns 0 otherwise.
234 // Returns -1 if an error occurred.
IsMappingHelper(PyObject * o)235 int IsMappingHelper(PyObject* o) {
236   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
237     return IsInstanceOfRegisteredType(to_check, "Mapping");
238   });
239   if (PyDict_Check(o)) return true;
240   return check_cache->CachedLookup(o);
241 }
242 
243 // Returns 1 if `o` is considered a mutable mapping for the purposes of
244 // Flatten(). Returns 0 otherwise. Returns -1 if an error occurred.
IsMutableMappingHelper(PyObject * o)245 int IsMutableMappingHelper(PyObject* o) {
246   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
247     return IsInstanceOfRegisteredType(to_check, "MutableMapping");
248   });
249   if (PyDict_Check(o)) return true;
250   return check_cache->CachedLookup(o);
251 }
252 
253 // Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
254 // Returns 0 otherwise.
255 // Returns -1 if an error occurred.
IsMappingViewHelper(PyObject * o)256 int IsMappingViewHelper(PyObject* o) {
257   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
258     return IsInstanceOfRegisteredType(to_check, "MappingView");
259   });
260   return check_cache->CachedLookup(o);
261 }
262 
263 // Returns 1 if `o` is considered an object proxy
264 // Returns 0 otherwise.
265 // Returns -1 if an error occurred.
IsObjectProxy(PyObject * o)266 int IsObjectProxy(PyObject* o) {
267   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
268     return IsInstanceOfRegisteredType(to_check, "ObjectProxy");
269   });
270   return check_cache->CachedLookup(o);
271 }
272 
273 // Returns 1 if `o` is an instance of attrs-decorated class.
274 // Returns 0 otherwise.
IsAttrsHelper(PyObject * o)275 int IsAttrsHelper(PyObject* o) {
276   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
277     Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
278     if (cls) {
279       return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
280     }
281 
282     // PyObject_GetAttrString returns null on error
283     PyErr_Clear();
284     return 0;
285   });
286   return check_cache->CachedLookup(o);
287 }
288 
289 // Returns 1 if `o` is an object of type IndexedSlices.
290 // Returns 0 otherwise.
291 // Returns -1 if an error occurred.
IsIndexedSlicesHelper(PyObject * o)292 int IsIndexedSlicesHelper(PyObject* o) {
293   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
294     return IsInstanceOfRegisteredType(to_check, "IndexedSlices");
295   });
296   return check_cache->CachedLookup(o);
297 }
298 
299 // Returns 1 if `o` is a Tensor.
300 // Returns 0 otherwise.
301 // Returns -1 if an error occurred.
IsTensorHelper(PyObject * o)302 int IsTensorHelper(PyObject* o) {
303   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
304     return IsInstanceOfRegisteredType(to_check, "Tensor");
305   });
306   return check_cache->CachedLookup(o);
307 }
308 
309 // Returns 1 if `o` is an EagerTensor.
310 // Returns 0 otherwise.
311 // Returns -1 if an error occurred.
IsEagerTensorHelper(PyObject * o)312 int IsEagerTensorHelper(PyObject* o) {
313   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
314     return IsInstanceOfRegisteredType(to_check, "EagerTensor");
315   });
316   return check_cache->CachedLookup(o);
317 }
318 
319 // Returns 1 if `o` is a ResourceVariable.
320 // Returns 0 otherwise.
321 // Returns -1 if an error occurred.
IsResourceVariableHelper(PyObject * o)322 int IsResourceVariableHelper(PyObject* o) {
323   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
324     return IsInstanceOfRegisteredType(to_check, "ResourceVariable");
325   });
326   return check_cache->CachedLookup(o);
327 }
328 
329 // Returns 1 if `o` is a ResourceVariable.
330 // Returns 0 otherwise.
331 // Returns -1 if an error occurred.
IsVariableHelper(PyObject * o)332 int IsVariableHelper(PyObject* o) {
333   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
334     return IsInstanceOfRegisteredType(to_check, "Variable");
335   });
336   return check_cache->CachedLookup(o);
337 }
338 
339 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
340 // Returns 0 otherwise.
341 // Returns -1 if an error occurred.
IsSequenceHelper(PyObject * o)342 int IsSequenceHelper(PyObject* o) {
343   // We treat dicts and other mappings as special cases of sequences.
344   if (IsMappingHelper(o)) return true;
345   if (IsMappingViewHelper(o)) return true;
346   if (IsAttrsHelper(o)) return true;
347   if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
348     LOG(WARNING) << "Sets are not currently considered sequences, "
349                     "but this may change in the future, "
350                     "so consider avoiding using them.";
351     WarnedThatSetIsNotSequence = true;
352   }
353   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
354     int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence");
355 
356     // Don't cache a failed is_instance check.
357     if (is_instance == -1) return -1;
358 
359     return static_cast<int>(is_instance != 0 && !IsString(to_check));
360   });
361   return check_cache->CachedLookup(o);
362 }
363 
364 // Returns 1 if `o`'s class has a `__tf_dispatch__` attribute.
365 // Returns 0 otherwise.
IsDispatchableHelper(PyObject * o)366 int IsDispatchableHelper(PyObject* o) {
367   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
368     return PyObject_HasAttrString(
369         reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__");
370   });
371   return check_cache->CachedLookup(o);
372 }
373 
374 // ValueIterator interface
375 class ValueIterator {
376  public:
~ValueIterator()377   virtual ~ValueIterator() {}
378   virtual Safe_PyObjectPtr next() = 0;
379 
valid() const380   bool valid() const { return is_valid_; }
381 
382  protected:
invalidate()383   void invalidate() { is_valid_ = false; }
384 
385  private:
386   bool is_valid_ = true;
387 };
388 
389 using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
390 
391 // Iterate through dictionaries in a deterministic order by sorting the
392 // keys. Notice this means that we ignore the original order of
393 // `OrderedDict` instances. This is intentional, to avoid potential
394 // bugs caused by mixing ordered and plain dicts (e.g., flattening
395 // a dict but using a corresponding `OrderedDict` to pack it back).
396 class DictValueIterator : public ValueIterator {
397  public:
DictValueIterator(PyObject * dict)398   explicit DictValueIterator(PyObject* dict)
399       : dict_(dict), keys_(PyDict_Keys(dict)) {
400     if (PyList_Sort(keys_.get()) == -1) {
401       invalidate();
402     } else {
403       iter_.reset(PyObject_GetIter(keys_.get()));
404     }
405   }
406 
next()407   Safe_PyObjectPtr next() override {
408     Safe_PyObjectPtr result;
409     Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
410     if (key) {
411       // PyDict_GetItem returns a borrowed reference.
412       PyObject* elem = PyDict_GetItem(dict_, key.get());
413       if (elem) {
414         Py_INCREF(elem);
415         result.reset(elem);
416       } else {
417         PyErr_SetString(PyExc_RuntimeError,
418                         "Dictionary was modified during iteration over it");
419       }
420     }
421     return result;
422   }
423 
424  private:
425   PyObject* dict_;
426   Safe_PyObjectPtr keys_;
427   Safe_PyObjectPtr iter_;
428 };
429 
430 // Iterate over mapping objects by sorting the keys first
431 class MappingValueIterator : public ValueIterator {
432  public:
MappingValueIterator(PyObject * mapping)433   explicit MappingValueIterator(PyObject* mapping)
434       : mapping_(mapping), keys_(MappingKeys(mapping)) {
435     if (!keys_ || PyList_Sort(keys_.get()) == -1) {
436       invalidate();
437     } else {
438       iter_.reset(PyObject_GetIter(keys_.get()));
439     }
440   }
441 
next()442   Safe_PyObjectPtr next() override {
443     Safe_PyObjectPtr result;
444     Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
445     if (key) {
446       // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
447       PyObject* elem = PyObject_GetItem(mapping_, key.get());
448       if (elem) {
449         result.reset(elem);
450       } else {
451         PyErr_SetString(PyExc_RuntimeError,
452                         "Mapping was modified during iteration over it");
453       }
454     }
455     return result;
456   }
457 
458  private:
459   PyObject* mapping_;
460   Safe_PyObjectPtr keys_;
461   Safe_PyObjectPtr iter_;
462 };
463 
464 // Iterate over a sequence, by index.
465 class SequenceValueIterator : public ValueIterator {
466  public:
SequenceValueIterator(PyObject * iterable)467   explicit SequenceValueIterator(PyObject* iterable)
468       : seq_(PySequence_Fast(iterable, "")),
469         size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
470         index_(0) {}
471 
next()472   Safe_PyObjectPtr next() override {
473     Safe_PyObjectPtr result;
474     if (index_ < size_) {
475       // PySequence_Fast_GET_ITEM returns a borrowed reference.
476       PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
477       ++index_;
478       if (elem) {
479         Py_INCREF(elem);
480         result.reset(elem);
481       }
482     }
483 
484     return result;
485   }
486 
487  private:
488   Safe_PyObjectPtr seq_;
489   const Py_ssize_t size_;
490   Py_ssize_t index_;
491 };
492 
493 // Iterator that just returns a single python object.
494 class SingleValueIterator : public ValueIterator {
495  public:
SingleValueIterator(PyObject * x)496   explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); }
497 
next()498   Safe_PyObjectPtr next() override { return std::move(x_); }
499 
500  private:
501   Safe_PyObjectPtr x_;
502 };
503 
504 // Returns nullptr (to raise an exception) when next() is called.  Caller
505 // should have already called PyErr_SetString.
506 class ErrorValueIterator : public ValueIterator {
507  public:
ErrorValueIterator()508   ErrorValueIterator() {}
next()509   Safe_PyObjectPtr next() override { return nullptr; }
510 };
511 
512 class AttrsValueIterator : public ValueIterator {
513  public:
AttrsValueIterator(PyObject * nested)514   explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
515     Py_INCREF(nested);
516     cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
517     if (cls_) {
518       attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
519       if (attrs_) {
520         iter_.reset(PyObject_GetIter(attrs_.get()));
521       }
522     }
523     if (!iter_ || PyErr_Occurred()) invalidate();
524   }
525 
next()526   Safe_PyObjectPtr next() override {
527     Safe_PyObjectPtr result;
528     Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
529     if (item) {
530       Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
531       result.reset(PyObject_GetAttr(nested_.get(), name.get()));
532     }
533 
534     return result;
535   }
536 
537  private:
538   Safe_PyObjectPtr nested_;
539   Safe_PyObjectPtr cls_;
540   Safe_PyObjectPtr attrs_;
541   Safe_PyObjectPtr iter_;
542 };
543 
IsSparseTensorValueType(PyObject * o)544 bool IsSparseTensorValueType(PyObject* o) {
545   PyObject* sparse_tensor_value_type =
546       GetRegisteredPyObject("SparseTensorValue");
547   if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
548     return false;
549   }
550 
551   return PyObject_TypeCheck(
552              o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
553 }
554 
555 // Returns 1 if `o` is an instance of CompositeTensor.
556 // Returns 0 otherwise.
557 // Returns -1 if an error occurred.
IsCompositeTensorHelper(PyObject * o)558 bool IsCompositeTensorHelper(PyObject* o) {
559   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
560     return IsInstanceOfRegisteredType(to_check, "CompositeTensor");
561   });
562   return check_cache->CachedLookup(o);
563 }
564 
565 // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
566 // VariableSpec.
567 // Returns 0 otherwise.
568 // Returns -1 if an error occurred.
IsTypeSpecHelper(PyObject * o)569 bool IsTypeSpecHelper(PyObject* o) {
570   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
571     int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
572     int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") ||
573                          IsInstanceOfRegisteredType(to_check, "VariableSpec"));
574     if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1;
575     return static_cast<int>(is_type_spec && !is_dense_spec);
576   });
577   return check_cache->CachedLookup(o);
578 }
579 
580 // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
581 // (non-TensorSpec and non-VariableSpec) TypeSpec.
582 // Returns 0 otherwise.
583 // Returns -1 if an error occurred.
IsSequenceOrCompositeHelper(PyObject * o)584 int IsSequenceOrCompositeHelper(PyObject* o) {
585   int is_sequence = IsSequenceHelper(o);
586   int is_composite = IsCompositeTensorHelper(o);
587   int is_type_spec = IsTypeSpecHelper(o);
588   if ((is_sequence == -1) || (is_composite == -1) || (is_type_spec == -1)) {
589     return -1;
590   }
591   return is_sequence || is_composite || is_type_spec;
592 }
593 
IsSequenceForDataHelper(PyObject * o)594 int IsSequenceForDataHelper(PyObject* o) {
595   return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
596          !IsSparseTensorValueType(o);
597 }
598 
GetValueIterator(PyObject * nested)599 ValueIteratorPtr GetValueIterator(PyObject* nested) {
600   if (PyDict_Check(nested)) {
601     return absl::make_unique<DictValueIterator>(nested);
602   } else if (IsMappingHelper(nested)) {
603     return absl::make_unique<MappingValueIterator>(nested);
604   } else if (IsAttrsHelper(nested)) {
605     return absl::make_unique<AttrsValueIterator>(nested);
606   } else {
607     return absl::make_unique<SequenceValueIterator>(nested);
608   }
609 }
610 
611 // Similar to above, just specialized for the functions in the data package.
GetValueIteratorForData(PyObject * nested)612 ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
613   if (PyDict_Check(nested)) {
614     return absl::make_unique<DictValueIterator>(nested);
615   } else if (IsMappingHelper(nested)) {
616     return absl::make_unique<MappingValueIterator>(nested);
617   } else if (IsAttrsHelper(nested)) {
618     return absl::make_unique<AttrsValueIterator>(nested);
619   } else if (IsSparseTensorValueType(nested)) {
620     return absl::make_unique<SingleValueIterator>(nested);
621   } else {
622     return absl::make_unique<SequenceValueIterator>(nested);
623   }
624 }
625 
626 // Similar to GetValueIterator above, but expands CompositeTensor and TypeSpec.
GetValueIteratorForComposite(PyObject * nested)627 ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
628   if (IsCompositeTensor(nested)) {
629     Safe_PyObjectPtr spec(PyObject_GetAttrString(nested, "_type_spec"));
630     if (PyErr_Occurred() || !spec) {
631       return absl::make_unique<ErrorValueIterator>();
632     }
633 
634     static char to_components[] = "_to_components";
635     static char argspec[] = "(O)";
636     Safe_PyObjectPtr components(
637         PyObject_CallMethod(spec.get(), to_components, argspec, nested));
638     if (PyErr_Occurred() || components == nullptr) {
639       return absl::make_unique<ErrorValueIterator>();
640     }
641     return absl::make_unique<SingleValueIterator>(components.get());
642   }
643 
644   if (IsTypeSpec(nested)) {
645     Safe_PyObjectPtr specs(PyObject_GetAttrString(nested, "_component_specs"));
646     if (PyErr_Occurred() || specs == nullptr) {
647       return absl::make_unique<ErrorValueIterator>();
648     }
649     return absl::make_unique<SingleValueIterator>(specs.get());
650   }
651 
652   return GetValueIterator(nested);
653 }
654 
FlattenHelper(PyObject * nested,PyObject * list,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)655 bool FlattenHelper(
656     PyObject* nested, PyObject* list,
657     const std::function<int(PyObject*)>& is_sequence_helper,
658     const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
659   // if nested is not a sequence, append itself and exit
660   int is_seq = is_sequence_helper(nested);
661   if (is_seq == -1) return false;
662   if (!is_seq) {
663     return PyList_Append(list, nested) != -1;
664   }
665 
666   ValueIteratorPtr iter = value_iterator_getter(nested);
667   if (!iter->valid()) return false;
668 
669   for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
670     if (Py_EnterRecursiveCall(" in flatten")) {
671       return false;
672     }
673     const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
674                                        value_iterator_getter);
675     Py_LeaveRecursiveCall();
676     if (!success) {
677       return false;
678     }
679   }
680   return true;
681 }
682 
683 // Sets error using keys of 'dict1' and 'dict2'.
684 // 'dict1' and 'dict2' are assumed to be Python dictionaries.
SetDifferentKeysError(PyObject * dict1,PyObject * dict2,string * error_msg,bool * is_type_error)685 void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
686                            bool* is_type_error) {
687   Safe_PyObjectPtr k1(MappingKeys(dict1));
688   if (PyErr_Occurred() || k1.get() == nullptr) {
689     *error_msg =
690         ("The two dictionaries don't have the same set of keys. Failed to "
691          "fetch keys.");
692     return;
693   }
694   Safe_PyObjectPtr k2(MappingKeys(dict2));
695   if (PyErr_Occurred() || k2.get() == nullptr) {
696     *error_msg =
697         ("The two dictionaries don't have the same set of keys. Failed to "
698          "fetch keys.");
699     return;
700   }
701   *is_type_error = false;
702   *error_msg = tensorflow::strings::StrCat(
703       "The two dictionaries don't have the same set of keys. "
704       "First structure has keys ",
705       PyObjectToString(k1.get()), ", while second structure has keys ",
706       PyObjectToString(k2.get()));
707 }
708 
// Returns true iff there were no "internal" errors, i.e. errors that have
// nothing to do with structure checking.
// If an "internal" error occurred, the appropriate Python error will be
// set and the caller can propagate it directly to the user.
713 //
714 // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
715 // be empty.
716 // Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
717 // with appropriate error and sets `is_type_error` to true iff
718 // the error to be raised should be TypeError.
AssertSameStructureHelper(PyObject * o1,PyObject * o2,bool check_types,string * error_msg,bool * is_type_error,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter,bool check_composite_tensor_type_spec)719 bool AssertSameStructureHelper(
720     PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
721     bool* is_type_error,
722     const std::function<int(PyObject*)>& is_sequence_helper,
723     const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
724     bool check_composite_tensor_type_spec) {
725   DCHECK(error_msg);
726   DCHECK(is_type_error);
727   const bool is_seq1 = is_sequence_helper(o1);
728   const bool is_seq2 = is_sequence_helper(o2);
729   if (PyErr_Occurred()) return false;
730   if (is_seq1 != is_seq2) {
731     string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
732     string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
733     *is_type_error = false;
734     *error_msg = tensorflow::strings::StrCat(
735         "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
736         non_seq_str, "\" is not");
737     return true;
738   }
739 
740   // Got to objects that are considered non-sequences. Note that in tf.data
741   // use case lists and sparse_tensors are not considered sequences. So finished
742   // checking, structures are the same.
743   if (!is_seq1) return true;
744 
745   if (check_types) {
746     // Treat wrapped tuples as tuples.
747     tensorflow::Safe_PyObjectPtr o1_wrapped;
748     if (IsObjectProxy(o1)) {
749       o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__"));
750       o1 = o1_wrapped.get();
751     }
752     tensorflow::Safe_PyObjectPtr o2_wrapped;
753     if (IsObjectProxy(o2)) {
754       o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__"));
755       o2 = o2_wrapped.get();
756     }
757 
758     const PyTypeObject* type1 = o1->ob_type;
759     const PyTypeObject* type2 = o2->ob_type;
760 
761     // We treat two different namedtuples with identical name and fields
762     // as having the same type.
763     const PyObject* o1_tuple = IsNamedtuple(o1, false);
764     if (o1_tuple == nullptr) return false;
765     const PyObject* o2_tuple = IsNamedtuple(o2, false);
766     if (o2_tuple == nullptr) {
767       Py_DECREF(o1_tuple);
768       return false;
769     }
770     bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
771     Py_DECREF(o1_tuple);
772     Py_DECREF(o2_tuple);
773 
774     if (both_tuples) {
775       const PyObject* same_tuples = SameNamedtuples(o1, o2);
776       if (same_tuples == nullptr) return false;
777       bool not_same_tuples = same_tuples != Py_True;
778       Py_DECREF(same_tuples);
779       if (not_same_tuples) {
780         *is_type_error = true;
781         *error_msg = tensorflow::strings::StrCat(
782             "The two namedtuples don't have the same sequence type. "
783             "First structure ",
784             PyObjectToString(o1), " has type ", type1->tp_name,
785             ", while second structure ", PyObjectToString(o2), " has type ",
786             type2->tp_name);
787         return true;
788       }
789     } else if (type1 != type2
790                /* If both sequences are list types, don't complain. This allows
791                   one to be a list subclass (e.g. _ListWrapper used for
792                   automatic dependency tracking.) */
793                && !(PyList_Check(o1) && PyList_Check(o2))
794                /* Two mapping types will also compare equal, making _DictWrapper
795                   and dict compare equal. */
796                && !(IsMappingHelper(o1) && IsMappingHelper(o2))
797                /* For CompositeTensor & TypeSpec, we check below. */
798                && !(check_composite_tensor_type_spec &&
799                     (IsCompositeTensor(o1) || IsCompositeTensor(o2)) &&
800                     (IsTypeSpec(o1) || IsTypeSpec(o2)))) {
801       *is_type_error = true;
802       *error_msg = tensorflow::strings::StrCat(
803           "The two namedtuples don't have the same sequence type. "
804           "First structure ",
805           PyObjectToString(o1), " has type ", type1->tp_name,
806           ", while second structure ", PyObjectToString(o2), " has type ",
807           type2->tp_name);
808       return true;
809     }
810 
811     if (PyDict_Check(o1) && PyDict_Check(o2)) {
812       if (PyDict_Size(o1) != PyDict_Size(o2)) {
813         SetDifferentKeysError(o1, o2, error_msg, is_type_error);
814         return true;
815       }
816 
817       PyObject* key;
818       Py_ssize_t pos = 0;
819       while (PyDict_Next(o1, &pos, &key, nullptr)) {
820         if (PyDict_GetItem(o2, key) == nullptr) {
821           SetDifferentKeysError(o1, o2, error_msg, is_type_error);
822           return true;
823         }
824       }
825     } else if (IsMappingHelper(o1)) {
826       // Fallback for custom mapping types. Instead of using PyDict methods
827       // which stay in C, we call iter(o1).
828       if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
829         SetDifferentKeysError(o1, o2, error_msg, is_type_error);
830         return true;
831       }
832 
833       Safe_PyObjectPtr iter(PyObject_GetIter(o1));
834       PyObject* key;
835       while ((key = PyIter_Next(iter.get())) != nullptr) {
836         if (!PyMapping_HasKey(o2, key)) {
837           SetDifferentKeysError(o1, o2, error_msg, is_type_error);
838           Py_DECREF(key);
839           return true;
840         }
841         Py_DECREF(key);
842       }
843     }
844   }
845 
846   if (check_composite_tensor_type_spec &&
847       (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
848     Safe_PyObjectPtr owned_type_spec_1;
849     PyObject* type_spec_1 = o1;
850     if (IsCompositeTensor(o1)) {
851       owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
852       type_spec_1 = owned_type_spec_1.get();
853     }
854 
855     Safe_PyObjectPtr owned_type_spec_2;
856     PyObject* type_spec_2 = o2;
857     if (IsCompositeTensor(o2)) {
858       owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
859       type_spec_2 = owned_type_spec_2.get();
860     }
861 
862     // Two composite tensors are considered to have the same structure if
863     // there is some type spec that is compatible with both of them.  Thus,
864     // we use most_specific_compatible_type(), and check if it raises an
865     // exception.  We do *not* use is_compatible_with, since that would
866     // prevent us from e.g. using a cond statement where the two sides have
867     // different shapes.
868     static char compatible_type[] = "most_specific_compatible_type";
869     static char argspec[] = "(O)";
870     Safe_PyObjectPtr struct_compatible(PyObject_CallMethod(
871         type_spec_1, compatible_type, argspec, type_spec_2));
872     if (PyErr_Occurred() || struct_compatible == nullptr) {
873       PyErr_Clear();
874       *is_type_error = false;
875       *error_msg = tensorflow::strings::StrCat(
876           "Incompatible CompositeTensor TypeSpecs: ",
877           PyObjectToString(type_spec_1), " vs. ",
878           PyObjectToString(type_spec_2));
879       return true;
880     }
881   }
882 
883   ValueIteratorPtr iter1 = value_iterator_getter(o1);
884   ValueIteratorPtr iter2 = value_iterator_getter(o2);
885 
886   if (!iter1->valid() || !iter2->valid()) return false;
887 
888   while (true) {
889     Safe_PyObjectPtr v1 = iter1->next();
890     Safe_PyObjectPtr v2 = iter2->next();
891     if (v1 && v2) {
892       if (Py_EnterRecursiveCall(" in assert_same_structure")) {
893         return false;
894       }
895       bool no_internal_errors = AssertSameStructureHelper(
896           v1.get(), v2.get(), check_types, error_msg, is_type_error,
897           is_sequence_helper, value_iterator_getter,
898           check_composite_tensor_type_spec);
899       Py_LeaveRecursiveCall();
900       if (!no_internal_errors) return false;
901       if (!error_msg->empty()) return true;
902     } else if (!v1 && !v2) {
903       // Done with all recursive calls. Structure matched.
904       return true;
905     } else {
906       *is_type_error = false;
907       *error_msg = tensorflow::strings::StrCat(
908           "The two structures don't have the same number of elements. ",
909           "First structure: ", PyObjectToString(o1),
910           ". Second structure: ", PyObjectToString(o2));
911       return true;
912     }
913   }
914 }
915 
916 }  // namespace
917 
IsSequence(PyObject * o)918 bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
IsMapping(PyObject * o)919 bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
IsMutableMapping(PyObject * o)920 bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; }
IsMappingView(PyObject * o)921 bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
IsAttrs(PyObject * o)922 bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
IsTensor(PyObject * o)923 bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
IsEagerTensorSlow(PyObject * o)924 bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; }
IsResourceVariable(PyObject * o)925 bool IsResourceVariable(PyObject* o) {
926   return IsResourceVariableHelper(o) == 1;
927 }
IsVariable(PyObject * o)928 bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
IsIndexedSlices(PyObject * o)929 bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
IsDispatchable(PyObject * o)930 bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; }
931 
IsTuple(PyObject * o)932 bool IsTuple(PyObject* o) {
933   tensorflow::Safe_PyObjectPtr wrapped;
934   if (IsObjectProxy(o)) {
935     wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
936     o = wrapped.get();
937   }
938   return PyTuple_Check(o);
939 }
940 
941 // Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
942 // and while we're at it give them consistent behavior by making sure the
943 // returned value is a list.
944 //
945 // As with PyMapping_Keys, returns a new reference.
946 //
947 // On failure, returns nullptr.
MappingKeys(PyObject * o)948 PyObject* MappingKeys(PyObject* o) {
949 #if PY_MAJOR_VERSION >= 3
950   return PyMapping_Keys(o);
951 #else
952   static char key_method_name[] = "keys";
953   Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
954   if (PyErr_Occurred() || raw_result.get() == nullptr) {
955     return nullptr;
956   }
957   return PySequence_Fast(
958       raw_result.get(),
959       "The '.keys()' method of a custom mapping returned a non-sequence.");
960 #endif
961 }
962 
Flatten(PyObject * nested,bool expand_composites)963 PyObject* Flatten(PyObject* nested, bool expand_composites) {
964   PyObject* list = PyList_New(0);
965   const std::function<int(PyObject*)>& is_sequence_helper =
966       expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
967   const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
968       expand_composites ? GetValueIteratorForComposite : GetValueIterator;
969   if (FlattenHelper(nested, list, is_sequence_helper, get_value_iterator)) {
970     return list;
971   } else {
972     Py_DECREF(list);
973     return nullptr;
974   }
975 }
976 
IsSequenceOrComposite(PyObject * o)977 bool IsSequenceOrComposite(PyObject* o) {
978   return IsSequenceOrCompositeHelper(o) == 1;
979 }
980 
IsCompositeTensor(PyObject * o)981 bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
982 
IsTypeSpec(PyObject * o)983 bool IsTypeSpec(PyObject* o) { return IsTypeSpecHelper(o) == 1; }
984 
IsSequenceForData(PyObject * o)985 bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
986 
FlattenForData(PyObject * nested)987 PyObject* FlattenForData(PyObject* nested) {
988   PyObject* list = PyList_New(0);
989   if (FlattenHelper(nested, list, IsSequenceForDataHelper,
990                     GetValueIteratorForData)) {
991     return list;
992   } else {
993     Py_DECREF(list);
994     return nullptr;
995   }
996 }
997 
IsNamedtuple(PyObject * o,bool strict)998 PyObject* IsNamedtuple(PyObject* o, bool strict) {
999   // Some low-level CPython calls do not work with wrapt.ObjectProxy, so they
1000   // require some unwrapping if we want to treat them like the objects they're
1001   // wrapping.
1002   tensorflow::Safe_PyObjectPtr o_wrapped;
1003   if (IsObjectProxy(o)) {
1004     o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
1005     o = o_wrapped.get();
1006   }
1007 
1008   // Must be subclass of tuple
1009   if (!PyTuple_Check(o)) {
1010     Py_RETURN_FALSE;
1011   }
1012 
1013   // If strict, o.__class__.__base__ must be tuple
1014   if (strict) {
1015     PyObject* klass = PyObject_GetAttrString(o, "__class__");
1016     if (klass == nullptr) return nullptr;
1017     PyObject* base = PyObject_GetAttrString(klass, "__base__");
1018     Py_DECREF(klass);
1019     if (base == nullptr) return nullptr;
1020 
1021     const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
1022     // built-in object types are singletons
1023     bool tuple_base = base_type == &PyTuple_Type;
1024     Py_DECREF(base);
1025     if (!tuple_base) {
1026       Py_RETURN_FALSE;
1027     }
1028   }
1029 
1030   // o must have attribute '_fields' and every element in
1031   // '_fields' must be a string.
1032   int has_fields = PyObject_HasAttrString(o, "_fields");
1033   if (!has_fields) {
1034     Py_RETURN_FALSE;
1035   }
1036 
1037   Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
1038   int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence");
1039   if (is_instance == 0) {
1040     Py_RETURN_FALSE;
1041   } else if (is_instance == -1) {
1042     return nullptr;
1043   }
1044 
1045   Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
1046   const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
1047   for (Py_ssize_t i = 0; i < s; ++i) {
1048     // PySequence_Fast_GET_ITEM returns borrowed ref
1049     PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
1050     if (!IsString(elem)) {
1051       Py_RETURN_FALSE;
1052     }
1053   }
1054 
1055   Py_RETURN_TRUE;
1056 }
1057 
SameNamedtuples(PyObject * o1,PyObject * o2)1058 PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
1059   Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
1060   Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
1061   if (f1 == nullptr || f2 == nullptr) {
1062     PyErr_SetString(
1063         PyExc_RuntimeError,
1064         "Expected namedtuple-like objects (that have _fields attr)");
1065     return nullptr;
1066   }
1067 
1068   if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
1069     Py_RETURN_FALSE;
1070   }
1071 
1072   if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
1073     Py_RETURN_TRUE;
1074   } else {
1075     Py_RETURN_FALSE;
1076   }
1077 }
1078 
AssertSameStructure(PyObject * o1,PyObject * o2,bool check_types,bool expand_composites)1079 PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
1080                               bool expand_composites) {
1081   const std::function<int(PyObject*)>& is_sequence_helper =
1082       expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
1083   const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
1084       expand_composites ? GetValueIteratorForComposite : GetValueIterator;
1085   const bool check_composite_tensor_type_spec = expand_composites;
1086   string error_msg;
1087   bool is_type_error = false;
1088   AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1089                             is_sequence_helper, get_value_iterator,
1090                             check_composite_tensor_type_spec);
1091   if (PyErr_Occurred()) {
1092     // Don't hide Python exceptions while checking (e.g. errors fetching keys
1093     // from custom mappings).
1094     return nullptr;
1095   }
1096   if (!error_msg.empty()) {
1097     PyErr_SetString(
1098         is_type_error ? PyExc_TypeError : PyExc_ValueError,
1099         tensorflow::strings::StrCat(
1100             "The two structures don't have the same nested structure.\n\n",
1101             "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1102             PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1103             .c_str());
1104     return nullptr;
1105   }
1106   Py_RETURN_NONE;
1107 }
1108 
AssertSameStructureForData(PyObject * o1,PyObject * o2,bool check_types)1109 PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
1110                                      bool check_types) {
1111   string error_msg;
1112   bool is_type_error = false;
1113   AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1114                             IsSequenceForDataHelper, GetValueIterator, false);
1115   if (PyErr_Occurred()) {
1116     // Don't hide Python exceptions while checking (e.g. errors fetching keys
1117     // from custom mappings).
1118     return nullptr;
1119   }
1120   if (!error_msg.empty()) {
1121     PyErr_SetString(
1122         is_type_error ? PyExc_TypeError : PyExc_ValueError,
1123         tensorflow::strings::StrCat(
1124             "The two structures don't have the same nested structure.\n\n",
1125             "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1126             PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1127             .c_str());
1128     return nullptr;
1129   }
1130   Py_RETURN_NONE;
1131 }
1132 
1133 }  // namespace swig
1134 }  // namespace tensorflow
1135