1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/util.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/python/lib/core/safe_ptr.h"
28 
29 namespace tensorflow {
30 namespace swig {
31 
PythonTypesMap()32 std::unordered_map<string, PyObject*>* PythonTypesMap() {
33   static auto* m = new std::unordered_map<string, PyObject*>();
34   return m;
35 }
36 
GetRegisteredType(const string & key)37 PyObject* GetRegisteredType(const string& key) {
38   auto* m = PythonTypesMap();
39   auto it = m->find(key);
40   if (it == m->end()) return nullptr;
41   return it->second;
42 }
43 
RegisterType(PyObject * type_name,PyObject * type)44 PyObject* RegisterType(PyObject* type_name, PyObject* type) {
45   if (!PyType_Check(type)) {
46     PyErr_SetString(PyExc_TypeError,
47                     tensorflow::strings::StrCat("Expecting a type, got ",
48                                                 Py_TYPE(type)->tp_name)
49                         .c_str());
50     return nullptr;
51   }
52 
53   string key;
54   if (PyBytes_Check(type_name)) {
55     key = PyBytes_AsString(type_name);
56   }
57 #if PY_MAJOR_VERSION >= 3
58   if (PyUnicode_Check(type_name)) {
59     key = PyUnicode_AsUTF8(type_name);
60   }
61 #endif
62 
63   if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
64     PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
65                                          "Type already registered for ", key)
66                                          .c_str());
67     return nullptr;
68   }
69 
70   Py_INCREF(type);
71   PythonTypesMap()->emplace(key, type);
72 
73   Py_RETURN_NONE;
74 }
75 
76 namespace {
// Upper bound on how many distinct Python types a single CachedTypeCheck
// remembers; once it is reached, new types are still checked but not cached.
constexpr int kMaxItemsInCache = 1024;

// Set after the "sets are not sequences" warning has been logged once, so it
// is not repeated on every call.
bool WarnedThatSetIsNotSequence = false;
80 
IsString(PyObject * o)81 bool IsString(PyObject* o) {
82   return PyBytes_Check(o) ||
83 #if PY_MAJOR_VERSION < 3
84          PyString_Check(o) ||
85 #endif
86          PyUnicode_Check(o);
87 }
88 
89 // Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
90 // and while we're at it give them consistent behavior by making sure the
91 // returned value is a list.
92 //
93 // As with PyMapping_Keys, returns a new reference.
94 //
95 // On failure, returns nullptr.
MappingKeys(PyObject * o)96 PyObject* MappingKeys(PyObject* o) {
97 #if PY_MAJOR_VERSION >= 3
98   return PyMapping_Keys(o);
99 #else
100   static char key_method_name[] = "keys";
101   Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
102   if (PyErr_Occurred() || raw_result.get() == nullptr) {
103     return nullptr;
104   }
105   return PySequence_Fast(
106       raw_result.get(),
107       "The '.keys()' method of a custom mapping returned a non-sequence.");
108 #endif
109 }
110 
111 // Equivalent to Python's 'o.__class__.__name__'
112 // Note that '__class__' attribute is set only in new-style classes.
113 // A lot of tensorflow code uses __class__ without checks, so it seems like
114 // we only support new-style classes.
GetClassName(PyObject * o)115 StringPiece GetClassName(PyObject* o) {
116   // __class__ is equivalent to type() for new style classes.
117   // type() is equivalent to PyObject_Type()
118   // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
119   // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
120   // we don't need here.
121   PyTypeObject* type = o->ob_type;
122 
123   // __name__ is the value of `tp_name` after the last '.'
124   // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
125   StringPiece name(type->tp_name);
126   size_t pos = name.rfind('.');
127   if (pos != StringPiece::npos) {
128     name.remove_prefix(pos + 1);
129   }
130   return name;
131 }
132 
PyObjectToString(PyObject * o)133 string PyObjectToString(PyObject* o) {
134   if (o == nullptr) {
135     return "<null object>";
136   }
137   PyObject* str = PyObject_Str(o);
138   if (str) {
139 #if PY_MAJOR_VERSION < 3
140     string s(PyString_AS_STRING(str));
141 #else
142     string s(PyUnicode_AsUTF8(str));
143 #endif
144     Py_DECREF(str);
145     return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
146   } else {
147     return "<failed to execute str() on object>";
148   }
149 }
150 
151 class CachedTypeCheck {
152  public:
CachedTypeCheck(std::function<int (PyObject *)> ternary_predicate)153   explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
154       : ternary_predicate_(std::move(ternary_predicate)) {}
155 
~CachedTypeCheck()156   ~CachedTypeCheck() {
157     mutex_lock l(type_to_sequence_map_mu_);
158     for (const auto& pair : type_to_sequence_map_) {
159       Py_DECREF(pair.first);
160     }
161   }
162 
163   // Caches successful executions of the one-argument (PyObject*) callable
164   // "ternary_predicate" based on the type of "o". -1 from the callable
165   // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
166   // does not match the predicate, and 1 indicates that it does. Used to avoid
167   // calling back into Python for expensive isinstance checks.
CachedLookup(PyObject * o)168   int CachedLookup(PyObject* o) {
169     // Try not to return to Python - see if the type has already been seen
170     // before.
171 
172     auto* type = Py_TYPE(o);
173 
174     {
175       tf_shared_lock l(type_to_sequence_map_mu_);
176       auto it = type_to_sequence_map_.find(type);
177       if (it != type_to_sequence_map_.end()) {
178         return it->second;
179       }
180     }
181 
182     int check_result = ternary_predicate_(o);
183 
184     if (check_result == -1) {
185       return -1;  // Type check error, not cached.
186     }
187 
188     // NOTE: This is never decref'd as long as the object lives, which is likely
189     // forever, but we don't want the type to get deleted as long as it is in
190     // the map. This should not be too much of a leak, as there should only be a
191     // relatively small number of types in the map, and an even smaller number
192     // that are eligible for decref. As a precaution, we limit the size of the
193     // map to 1024.
194     {
195       mutex_lock l(type_to_sequence_map_mu_);
196       if (type_to_sequence_map_.size() < kMaxItemsInCache) {
197         Py_INCREF(type);
198         auto insert_result = type_to_sequence_map_.insert({type, check_result});
199         if (!insert_result.second) {
200           // The type was added to the cache by a concurrent thread after we
201           // looked it up above.
202           Py_DECREF(type);
203         }
204       }
205     }
206 
207     return check_result;
208   }
209 
210  private:
211   std::function<int(PyObject*)> ternary_predicate_;
212   mutex type_to_sequence_map_mu_;
213   std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
214       GUARDED_BY(type_to_sequence_map_mu_);
215 };
216 
217 // Returns 1 if `o` is considered a mapping for the purposes of Flatten().
218 // Returns 0 otherwise.
219 // Returns -1 if an error occurred.
IsMappingHelper(PyObject * o)220 int IsMappingHelper(PyObject* o) {
221   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
222     PyObject* collections_mapping_type = GetRegisteredType("Mapping");
223     if (TF_PREDICT_FALSE(collections_mapping_type == nullptr)) {
224       PyErr_SetString(PyExc_RuntimeError,
225                       tensorflow::strings::StrCat(
226                           "collections.Mapping type has not been set. "
227                           "Please register the type with the identifier "
228                           "\"Mapping\" using RegisterType.")
229                           .c_str());
230       return -1;
231     }
232     return PyObject_IsInstance(to_check, collections_mapping_type);
233   });
234   if (PyDict_Check(o)) return true;
235   return check_cache->CachedLookup(o);
236 }
237 
238 // Returns 1 if `o` is an instance of attrs-decorated class.
239 // Returns 0 otherwise.
IsAttrsHelper(PyObject * o)240 int IsAttrsHelper(PyObject* o) {
241   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
242     Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
243     if (cls) {
244       return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
245     }
246 
247     // PyObject_GetAttrString returns null on error
248     PyErr_Clear();
249     return 0;
250   });
251   return check_cache->CachedLookup(o);
252 }
253 
254 // Returns 1 if `o` is an object of type IndexedSlices.
255 // Returns 0 otherwise.
256 // Returns -1 if an error occurred.
IsIndexedSlicesHelper(PyObject * o)257 int IsIndexedSlicesHelper(PyObject* o) {
258   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
259     PyObject* indexed_slices_type = GetRegisteredType("IndexedSlices");
260     if (TF_PREDICT_FALSE(indexed_slices_type == nullptr)) {
261       PyErr_SetString(PyExc_RuntimeError,
262                       tensorflow::strings::StrCat(
263                           "IndexedSlices type has not been set. "
264                           "Please register the type with the identifier "
265                           "\"IndexedSlices\" using RegisterType.")
266                           .c_str());
267       return -1;
268     }
269     return PyObject_IsInstance(to_check, indexed_slices_type);
270   });
271   return check_cache->CachedLookup(o);
272 }
273 
274 // Returns 1 if `o` is a Tensor.
275 // Returns 0 otherwise.
276 // Returns -1 if an error occurred.
IsTensorHelper(PyObject * o)277 int IsTensorHelper(PyObject* o) {
278   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
279     PyObject* tensor_type = GetRegisteredType("Tensor");
280     if (TF_PREDICT_FALSE(tensor_type == nullptr)) {
281       PyErr_SetString(PyExc_RuntimeError,
282                       tensorflow::strings::StrCat(
283                           "Tensor type has not been set. "
284                           "Please register the type with the identifier "
285                           "\"Tensor\" using RegisterType.")
286                           .c_str());
287       return -1;
288     }
289     return PyObject_IsInstance(to_check, tensor_type);
290   });
291   return check_cache->CachedLookup(o);
292 }
293 
294 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
295 // Returns 0 otherwise.
296 // Returns -1 if an error occurred.
IsSequenceHelper(PyObject * o)297 int IsSequenceHelper(PyObject* o) {
298   // We treat dicts and other mappings as special cases of sequences.
299   if (IsMappingHelper(o)) return true;
300   if (IsAttrsHelper(o)) return true;
301   if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
302     LOG(WARNING) << "Sets are not currently considered sequences, "
303                     "but this may change in the future, "
304                     "so consider avoiding using them.";
305     WarnedThatSetIsNotSequence = true;
306   }
307   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
308     PyObject* collections_sequence_type = GetRegisteredType("Sequence");
309     if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
310       PyErr_SetString(PyExc_RuntimeError,
311                       tensorflow::strings::StrCat(
312                           "collections.Sequence type has not been set. "
313                           "Please register the type with the identifier "
314                           "\"Sequence\" using RegisterType.")
315                           .c_str());
316       return -1;
317     }
318     int is_instance = PyObject_IsInstance(to_check, collections_sequence_type);
319 
320     // Don't cache a failed is_instance check.
321     if (is_instance == -1) return -1;
322 
323     return static_cast<int>(is_instance != 0 && !IsString(to_check));
324   });
325   return check_cache->CachedLookup(o);
326 }
327 
328 // ValueIterator interface
329 class ValueIterator {
330  public:
~ValueIterator()331   virtual ~ValueIterator() {}
332   virtual Safe_PyObjectPtr next() = 0;
333 
valid() const334   bool valid() const { return is_valid_; }
335 
336  protected:
invalidate()337   void invalidate() { is_valid_ = false; }
338 
339  private:
340   bool is_valid_ = true;
341 };
342 
343 using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
344 
345 // Iterate through dictionaries in a deterministic order by sorting the
346 // keys. Notice this means that we ignore the original order of
347 // `OrderedDict` instances. This is intentional, to avoid potential
348 // bugs caused by mixing ordered and plain dicts (e.g., flattening
349 // a dict but using a corresponding `OrderedDict` to pack it back).
350 class DictValueIterator : public ValueIterator {
351  public:
DictValueIterator(PyObject * dict)352   explicit DictValueIterator(PyObject* dict)
353       : dict_(dict), keys_(PyDict_Keys(dict)) {
354     if (PyList_Sort(keys_.get()) == -1) {
355       invalidate();
356     } else {
357       iter_.reset(PyObject_GetIter(keys_.get()));
358     }
359   }
360 
next()361   Safe_PyObjectPtr next() override {
362     Safe_PyObjectPtr result;
363     Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
364     if (key) {
365       // PyDict_GetItem returns a borrowed reference.
366       PyObject* elem = PyDict_GetItem(dict_, key.get());
367       if (elem) {
368         Py_INCREF(elem);
369         result.reset(elem);
370       } else {
371         PyErr_SetString(PyExc_RuntimeError,
372                         "Dictionary was modified during iteration over it");
373       }
374     }
375     return result;
376   }
377 
378  private:
379   PyObject* dict_;
380   Safe_PyObjectPtr keys_;
381   Safe_PyObjectPtr iter_;
382 };
383 
384 // Iterate over mapping objects by sorting the keys first
385 class MappingValueIterator : public ValueIterator {
386  public:
MappingValueIterator(PyObject * mapping)387   explicit MappingValueIterator(PyObject* mapping)
388       : mapping_(mapping), keys_(MappingKeys(mapping)) {
389     if (!keys_ || PyList_Sort(keys_.get()) == -1) {
390       invalidate();
391     } else {
392       iter_.reset(PyObject_GetIter(keys_.get()));
393     }
394   }
395 
next()396   Safe_PyObjectPtr next() override {
397     Safe_PyObjectPtr result;
398     Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
399     if (key) {
400       // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
401       PyObject* elem = PyObject_GetItem(mapping_, key.get());
402       if (elem) {
403         result.reset(elem);
404       } else {
405         PyErr_SetString(PyExc_RuntimeError,
406                         "Mapping was modified during iteration over it");
407       }
408     }
409     return result;
410   }
411 
412  private:
413   PyObject* mapping_;
414   Safe_PyObjectPtr keys_;
415   Safe_PyObjectPtr iter_;
416 };
417 
418 // Iterate over a sequence, by index.
419 class SequenceValueIterator : public ValueIterator {
420  public:
SequenceValueIterator(PyObject * iterable)421   explicit SequenceValueIterator(PyObject* iterable)
422       : seq_(PySequence_Fast(iterable, "")),
423         size_(PySequence_Fast_GET_SIZE(seq_.get())),
424         index_(0) {}
425 
next()426   Safe_PyObjectPtr next() override {
427     Safe_PyObjectPtr result;
428     if (index_ < size_) {
429       // PySequence_Fast_GET_ITEM returns a borrowed reference.
430       PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
431       ++index_;
432       Py_INCREF(elem);
433       result.reset(elem);
434     }
435 
436     return result;
437   }
438 
439  private:
440   Safe_PyObjectPtr seq_;
441   const Py_ssize_t size_;
442   Py_ssize_t index_;
443 };
444 
445 // Just return itself as a single item.
446 class SparseTensorValueIterator : public ValueIterator {
447  public:
SparseTensorValueIterator(PyObject * tensor)448   explicit SparseTensorValueIterator(PyObject* tensor) : tensor_(tensor) {
449     Py_INCREF(tensor);
450   }
451 
next()452   Safe_PyObjectPtr next() override { return std::move(tensor_); }
453 
454  private:
455   Safe_PyObjectPtr tensor_;
456 };
457 
458 // Returns nullptr (to raise an exception) when next() is called.  Caller
459 // should have already called PyErr_SetString.
460 class ErrorValueIterator : public ValueIterator {
461  public:
ErrorValueIterator()462   ErrorValueIterator() {}
next()463   Safe_PyObjectPtr next() override { return nullptr; }
464 };
465 
466 class AttrsValueIterator : public ValueIterator {
467  public:
AttrsValueIterator(PyObject * nested)468   explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
469     Py_INCREF(nested);
470     cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
471     if (cls_) {
472       attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
473       if (attrs_) {
474         iter_.reset(PyObject_GetIter(attrs_.get()));
475       }
476     }
477     if (!iter_ || PyErr_Occurred()) invalidate();
478   }
479 
next()480   Safe_PyObjectPtr next() override {
481     Safe_PyObjectPtr result;
482     Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
483     if (item) {
484       Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
485       result.reset(PyObject_GetAttr(nested_.get(), name.get()));
486     }
487 
488     return result;
489   }
490 
491  private:
492   Safe_PyObjectPtr nested_;
493   Safe_PyObjectPtr cls_;
494   Safe_PyObjectPtr attrs_;
495   Safe_PyObjectPtr iter_;
496 };
497 
IsSparseTensorValueType(PyObject * o)498 bool IsSparseTensorValueType(PyObject* o) {
499   PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
500   if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
501     return false;
502   }
503 
504   return PyObject_TypeCheck(
505              o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
506 }
507 
508 // Returns 1 if `o` is an instance of CompositeTensor.
509 // Returns 0 otherwise.
510 // Returns -1 if an error occurred.
IsCompositeTensorHelper(PyObject * o)511 bool IsCompositeTensorHelper(PyObject* o) {
512   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
513     PyObject* composite_tensor_type = GetRegisteredType("CompositeTensor");
514     if (TF_PREDICT_FALSE(composite_tensor_type == nullptr)) {
515       PyErr_SetString(PyExc_RuntimeError,
516                       tensorflow::strings::StrCat(
517                           "CompositeTensor type has not been set. "
518                           "Please register the type with the identifier "
519                           "\"CompositeTensor\" using RegisterType.")
520                           .c_str());
521       return -1;
522     }
523     int is_instance = PyObject_IsInstance(to_check, composite_tensor_type);
524 
525     // Don't cache a failed is_instance check.
526     if (is_instance == -1) return -1;
527 
528     return static_cast<int>(is_instance != 0);
529   });
530   return check_cache->CachedLookup(o);
531 }
532 
IsSequenceOrCompositeHelper(PyObject * o)533 int IsSequenceOrCompositeHelper(PyObject* o) {
534   return IsSequence(o) || IsCompositeTensor(o);
535 }
536 
IsSequenceForDataHelper(PyObject * o)537 int IsSequenceForDataHelper(PyObject* o) {
538   return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
539          !IsSparseTensorValueType(o);
540 }
541 
GetValueIterator(PyObject * nested)542 ValueIteratorPtr GetValueIterator(PyObject* nested) {
543   if (PyDict_Check(nested)) {
544     return absl::make_unique<DictValueIterator>(nested);
545   } else if (IsMappingHelper(nested)) {
546     return absl::make_unique<MappingValueIterator>(nested);
547   } else if (IsAttrsHelper(nested)) {
548     return absl::make_unique<AttrsValueIterator>(nested);
549   } else {
550     return absl::make_unique<SequenceValueIterator>(nested);
551   }
552 }
553 
554 // Similar to above, just specialized for the functions in the data package.
GetValueIteratorForData(PyObject * nested)555 ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
556   if (PyDict_Check(nested)) {
557     return absl::make_unique<DictValueIterator>(nested);
558   } else if (IsMappingHelper(nested)) {
559     return absl::make_unique<MappingValueIterator>(nested);
560   } else if (IsAttrsHelper(nested)) {
561     return absl::make_unique<AttrsValueIterator>(nested);
562   } else if (IsSparseTensorValueType(nested)) {
563     return absl::make_unique<SparseTensorValueIterator>(nested);
564   } else {
565     return absl::make_unique<SequenceValueIterator>(nested);
566   }
567 }
568 
569 // Similar to GetValueIterator above, but expands CompositeTensors.
GetValueIteratorForComposite(PyObject * nested)570 ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
571   if (IsCompositeTensor(nested)) {
572     static char expand_method_name[] = "_to_components";
573     nested = PyObject_CallMethod(nested, expand_method_name, nullptr);
574     if (PyErr_Occurred() || nested == nullptr) {
575       return absl::make_unique<ErrorValueIterator>();
576     }
577   }
578   return GetValueIterator(nested);
579 }
580 
FlattenHelper(PyObject * nested,PyObject * list,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)581 bool FlattenHelper(
582     PyObject* nested, PyObject* list,
583     const std::function<int(PyObject*)>& is_sequence_helper,
584     const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
585   // if nested is not a sequence, append itself and exit
586   int is_seq = is_sequence_helper(nested);
587   if (is_seq == -1) return false;
588   if (!is_seq) {
589     return PyList_Append(list, nested) != -1;
590   }
591 
592   ValueIteratorPtr iter = value_iterator_getter(nested);
593   if (!iter->valid()) return false;
594 
595   for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
596     if (Py_EnterRecursiveCall(" in flatten")) {
597       return false;
598     }
599     const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
600                                        value_iterator_getter);
601     Py_LeaveRecursiveCall();
602     if (!success) {
603       return false;
604     }
605   }
606   return true;
607 }
608 
609 // Sets error using keys of 'dict1' and 'dict2'.
610 // 'dict1' and 'dict2' are assumed to be Python dictionaries.
SetDifferentKeysError(PyObject * dict1,PyObject * dict2,string * error_msg,bool * is_type_error)611 void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
612                            bool* is_type_error) {
613   Safe_PyObjectPtr k1(MappingKeys(dict1));
614   if (PyErr_Occurred() || k1.get() == nullptr) {
615     *error_msg =
616         ("The two dictionaries don't have the same set of keys. Failed to "
617          "fetch keys.");
618     return;
619   }
620   Safe_PyObjectPtr k2(MappingKeys(dict2));
621   if (PyErr_Occurred() || k2.get() == nullptr) {
622     *error_msg =
623         ("The two dictionaries don't have the same set of keys. Failed to "
624          "fetch keys.");
625     return;
626   }
627   *is_type_error = false;
628   *error_msg = tensorflow::strings::StrCat(
629       "The two dictionaries don't have the same set of keys. "
630       "First structure has keys ",
631       PyObjectToString(k1.get()), ", while second structure has keys ",
632       PyObjectToString(k2.get()));
633 }
634 
635 // Returns true iff there were no "internal" errors. In other words,
636 // errors that has nothing to do with structure checking.
637 // If an "internal" error occurred, the appropriate Python error will be
638 // set and the caller can propage it directly to the user.
639 //
640 // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
641 // be empty.
642 // Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
643 // with appropriate error and sets `is_type_error` to true iff
644 // the error to be raised should be TypeError.
AssertSameStructureHelper(PyObject * o1,PyObject * o2,bool check_types,string * error_msg,bool * is_type_error,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)645 bool AssertSameStructureHelper(
646     PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
647     bool* is_type_error,
648     const std::function<int(PyObject*)>& is_sequence_helper,
649     const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
650   DCHECK(error_msg);
651   DCHECK(is_type_error);
652   const bool is_seq1 = is_sequence_helper(o1);
653   const bool is_seq2 = is_sequence_helper(o2);
654   if (PyErr_Occurred()) return false;
655   if (is_seq1 != is_seq2) {
656     string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
657     string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
658     *is_type_error = false;
659     *error_msg = tensorflow::strings::StrCat(
660         "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
661         non_seq_str, "\" is not");
662     return true;
663   }
664 
665   // Got to objects that are considered non-sequences. Note that in tf.data
666   // use case lists and sparse_tensors are not considered sequences. So finished
667   // checking, structures are the same.
668   if (!is_seq1) return true;
669 
670   if (check_types) {
671     const PyTypeObject* type1 = o1->ob_type;
672     const PyTypeObject* type2 = o2->ob_type;
673 
674     // We treat two different namedtuples with identical name and fields
675     // as having the same type.
676     const PyObject* o1_tuple = IsNamedtuple(o1, true);
677     if (o1_tuple == nullptr) return false;
678     const PyObject* o2_tuple = IsNamedtuple(o2, true);
679     if (o2_tuple == nullptr) {
680       Py_DECREF(o1_tuple);
681       return false;
682     }
683     bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
684     Py_DECREF(o1_tuple);
685     Py_DECREF(o2_tuple);
686 
687     if (both_tuples) {
688       const PyObject* same_tuples = SameNamedtuples(o1, o2);
689       if (same_tuples == nullptr) return false;
690       bool not_same_tuples = same_tuples != Py_True;
691       Py_DECREF(same_tuples);
692       if (not_same_tuples) {
693         *is_type_error = true;
694         *error_msg = tensorflow::strings::StrCat(
695             "The two namedtuples don't have the same sequence type. "
696             "First structure ",
697             PyObjectToString(o1), " has type ", type1->tp_name,
698             ", while second structure ", PyObjectToString(o2), " has type ",
699             type2->tp_name);
700         return true;
701       }
702     } else if (type1 != type2
703                /* If both sequences are list types, don't complain. This allows
704                   one to be a list subclass (e.g. _ListWrapper used for
705                   automatic dependency tracking.) */
706                && !(PyList_Check(o1) && PyList_Check(o2))
707                /* Two mapping types will also compare equal, making _DictWrapper
708                   and dict compare equal. */
709                && !(IsMappingHelper(o1) && IsMappingHelper(o2))) {
710       *is_type_error = true;
711       *error_msg = tensorflow::strings::StrCat(
712           "The two namedtuples don't have the same sequence type. "
713           "First structure ",
714           PyObjectToString(o1), " has type ", type1->tp_name,
715           ", while second structure ", PyObjectToString(o2), " has type ",
716           type2->tp_name);
717       return true;
718     }
719 
720     if (PyDict_Check(o1) && PyDict_Check(o2)) {
721       if (PyDict_Size(o1) != PyDict_Size(o2)) {
722         SetDifferentKeysError(o1, o2, error_msg, is_type_error);
723         return true;
724       }
725 
726       PyObject* key;
727       Py_ssize_t pos = 0;
728       while (PyDict_Next(o1, &pos, &key, nullptr)) {
729         if (PyDict_GetItem(o2, key) == nullptr) {
730           SetDifferentKeysError(o1, o2, error_msg, is_type_error);
731           return true;
732         }
733       }
734     } else if (IsMappingHelper(o1)) {
735       // Fallback for custom mapping types. Instead of using PyDict methods
736       // which stay in C, we call iter(o1).
737       if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
738         SetDifferentKeysError(o1, o2, error_msg, is_type_error);
739         return true;
740       }
741 
742       Safe_PyObjectPtr iter(PyObject_GetIter(o1));
743       PyObject* key;
744       while ((key = PyIter_Next(iter.get())) != nullptr) {
745         if (!PyMapping_HasKey(o2, key)) {
746           SetDifferentKeysError(o1, o2, error_msg, is_type_error);
747           Py_DECREF(key);
748           return true;
749         }
750         Py_DECREF(key);
751       }
752     }
753   }
754 
755   ValueIteratorPtr iter1 = value_iterator_getter(o1);
756   ValueIteratorPtr iter2 = value_iterator_getter(o2);
757 
758   if (!iter1->valid() || !iter2->valid()) return false;
759 
760   while (true) {
761     Safe_PyObjectPtr v1 = iter1->next();
762     Safe_PyObjectPtr v2 = iter2->next();
763     if (v1 && v2) {
764       if (Py_EnterRecursiveCall(" in assert_same_structure")) {
765         return false;
766       }
767       bool no_internal_errors = AssertSameStructureHelper(
768           v1.get(), v2.get(), check_types, error_msg, is_type_error,
769           is_sequence_helper, value_iterator_getter);
770       Py_LeaveRecursiveCall();
771       if (!no_internal_errors) return false;
772       if (!error_msg->empty()) return true;
773     } else if (!v1 && !v2) {
774       // Done with all recursive calls. Structure matched.
775       return true;
776     } else {
777       *is_type_error = false;
778       *error_msg = tensorflow::strings::StrCat(
779           "The two structures don't have the same number of elements. ",
780           "First structure: ", PyObjectToString(o1),
781           ". Second structure: ", PyObjectToString(o2));
782       return true;
783     }
784   }
785 }
786 
787 }  // namespace
788 
IsSequence(PyObject * o)789 bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
IsMapping(PyObject * o)790 bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
IsAttrs(PyObject * o)791 bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
IsTensor(PyObject * o)792 bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
IsIndexedSlices(PyObject * o)793 bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
794 
Flatten(PyObject * nested,bool expand_composites)795 PyObject* Flatten(PyObject* nested, bool expand_composites) {
796   PyObject* list = PyList_New(0);
797   const std::function<int(PyObject*)>& is_sequence_helper =
798       expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
799   const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
800       expand_composites ? GetValueIteratorForComposite : GetValueIterator;
801   if (FlattenHelper(nested, list, is_sequence_helper, get_value_iterator)) {
802     return list;
803   } else {
804     Py_DECREF(list);
805     return nullptr;
806   }
807 }
808 
IsSequenceOrComposite(PyObject * o)809 bool IsSequenceOrComposite(PyObject* o) {
810   return IsSequenceOrCompositeHelper(o) == 1;
811 }
812 
IsCompositeTensor(PyObject * o)813 bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
814 
IsSequenceForData(PyObject * o)815 bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
816 
FlattenForData(PyObject * nested)817 PyObject* FlattenForData(PyObject* nested) {
818   PyObject* list = PyList_New(0);
819   if (FlattenHelper(nested, list, IsSequenceForDataHelper,
820                     GetValueIteratorForData)) {
821     return list;
822   } else {
823     Py_DECREF(list);
824     return nullptr;
825   }
826 }
827 
IsNamedtuple(PyObject * o,bool strict)828 PyObject* IsNamedtuple(PyObject* o, bool strict) {
829   // Must be subclass of tuple
830   if (!PyTuple_Check(o)) {
831     Py_RETURN_FALSE;
832   }
833 
834   // If strict, o.__class__.__base__ must be tuple
835   if (strict) {
836     PyObject* klass = PyObject_GetAttrString(o, "__class__");
837     if (klass == nullptr) return nullptr;
838     PyObject* base = PyObject_GetAttrString(klass, "__base__");
839     Py_DECREF(klass);
840     if (base == nullptr) return nullptr;
841 
842     const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
843     // built-in object types are singletons
844     bool tuple_base = base_type == &PyTuple_Type;
845     Py_DECREF(base);
846     if (!tuple_base) {
847       Py_RETURN_FALSE;
848     }
849   }
850 
851   PyObject* collections_sequence_type = GetRegisteredType("Sequence");
852 
853   if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
854     PyErr_SetString(PyExc_RuntimeError,
855                     tensorflow::strings::StrCat(
856                         "collections.Sequence type has not been set. "
857                         "Please register the type with the identifier "
858                         "\"Sequence\" using RegisterType.")
859                         .c_str());
860     return nullptr;
861   }
862 
863   // o must have attribute '_fields' and every element in
864   // '_fields' must be a string.
865   int has_fields = PyObject_HasAttrString(o, "_fields");
866   if (!has_fields) {
867     Py_RETURN_FALSE;
868   }
869 
870   Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
871   int is_instance =
872       PyObject_IsInstance(fields.get(), collections_sequence_type);
873   if (is_instance == 0) {
874     Py_RETURN_FALSE;
875   } else if (is_instance == -1) {
876     return nullptr;
877   }
878 
879   Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
880   const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
881   for (Py_ssize_t i = 0; i < s; ++i) {
882     // PySequence_Fast_GET_ITEM returns borrowed ref
883     PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
884     if (!IsString(elem)) {
885       Py_RETURN_FALSE;
886     }
887   }
888 
889   Py_RETURN_TRUE;
890 }
891 
SameNamedtuples(PyObject * o1,PyObject * o2)892 PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
893   Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
894   Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
895   if (f1 == nullptr || f2 == nullptr) {
896     PyErr_SetString(
897         PyExc_RuntimeError,
898         "Expected namedtuple-like objects (that have _fields attr)");
899     return nullptr;
900   }
901 
902   if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
903     Py_RETURN_FALSE;
904   }
905 
906   if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
907     Py_RETURN_TRUE;
908   } else {
909     Py_RETURN_FALSE;
910   }
911 }
912 
AssertSameStructure(PyObject * o1,PyObject * o2,bool check_types,bool expand_composites)913 PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
914                               bool expand_composites) {
915   const std::function<int(PyObject*)>& is_sequence_helper =
916       expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
917   const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
918       expand_composites ? GetValueIteratorForComposite : GetValueIterator;
919   string error_msg;
920   bool is_type_error = false;
921   AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
922                             is_sequence_helper, get_value_iterator);
923   if (PyErr_Occurred()) {
924     // Don't hide Python exceptions while checking (e.g. errors fetching keys
925     // from custom mappings).
926     return nullptr;
927   }
928   if (!error_msg.empty()) {
929     PyErr_SetString(
930         is_type_error ? PyExc_TypeError : PyExc_ValueError,
931         tensorflow::strings::StrCat(
932             "The two structures don't have the same nested structure.\n\n",
933             "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
934             PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
935             .c_str());
936     return nullptr;
937   }
938   Py_RETURN_NONE;
939 }
940 
AssertSameStructureForData(PyObject * o1,PyObject * o2,bool check_types)941 PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
942                                      bool check_types) {
943   string error_msg;
944   bool is_type_error = false;
945   AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
946                             IsSequenceForDataHelper, GetValueIterator);
947   if (PyErr_Occurred()) {
948     // Don't hide Python exceptions while checking (e.g. errors fetching keys
949     // from custom mappings).
950     return nullptr;
951   }
952   if (!error_msg.empty()) {
953     PyErr_SetString(
954         is_type_error ? PyExc_TypeError : PyExc_ValueError,
955         tensorflow::strings::StrCat(
956             "The two structures don't have the same nested structure.\n\n",
957             "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
958             PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
959             .c_str());
960     return nullptr;
961   }
962   Py_RETURN_NONE;
963 }
964 
965 }  // namespace swig
966 }  // namespace tensorflow
967