1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: haberman@google.com (Josh Haberman)
32 
33 #include <google/protobuf/pyext/map_container.h>
34 
35 #include <memory>
36 #ifndef _SHARED_PTR_H
37 #include <google/protobuf/stubs/shared_ptr.h>
38 #endif
39 
40 #include <google/protobuf/stubs/logging.h>
41 #include <google/protobuf/stubs/common.h>
42 #include <google/protobuf/stubs/scoped_ptr.h>
43 #include <google/protobuf/map_field.h>
44 #include <google/protobuf/map.h>
45 #include <google/protobuf/message.h>
46 #include <google/protobuf/pyext/message.h>
47 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
48 
49 #if PY_MAJOR_VERSION >= 3
50   #define PyInt_FromLong PyLong_FromLong
51   #define PyInt_FromSize_t PyLong_FromSize_t
52 #endif
53 
54 namespace google {
55 namespace protobuf {
56 namespace python {
57 
58 // Functions that need access to map reflection functionality.
59 // They need to be contained in this class because it is friended.
60 class MapReflectionFriend {
61  public:
62   // Methods that are in common between the map types.
63   static PyObject* Contains(PyObject* _self, PyObject* key);
64   static Py_ssize_t Length(PyObject* _self);
65   static PyObject* GetIterator(PyObject *_self);
66   static PyObject* IterNext(PyObject* _self);
67 
68   // Methods that differ between the map types.
69   static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
70   static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
71   static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
72   static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
73 };
74 
75 struct MapIterator {
76   PyObject_HEAD;
77 
78   google::protobuf::scoped_ptr< ::google::protobuf::MapIterator> iter;
79 
80   // A pointer back to the container, so we can notice changes to the version.
81   // We own a ref on this.
82   MapContainer* container;
83 
84   // We need to keep a ref on the Message* too, because
85   // MapIterator::~MapIterator() accesses it.  Normally this would be ok because
86   // the ref on container (above) would guarantee outlive semantics.  However in
87   // the case of ClearField(), InitializeAndCopyToParentContainer() resets the
88   // message pointer (and the owner) to a different message, a copy of the
89   // original.  But our iterator still points to the original, which could now
90   // get deleted before us.
91   //
92   // To prevent this, we ensure that the Message will always stay alive as long
93   // as this iterator does.  This is solely for the benefit of the MapIterator
94   // destructor -- we should never actually access the iterator in this state
95   // except to delete it.
96   shared_ptr<Message> owner;
97 
98   // The version of the map when we took the iterator to it.
99   //
100   // We store this so that if the map is modified during iteration we can throw
101   // an error.
102   uint64 version;
103 
104   // True if the container is empty.  We signal this separately to avoid calling
105   // any of the iteration methods, which are non-const.
106   bool empty;
107 };
108 
GetMutableMessage()109 Message* MapContainer::GetMutableMessage() {
110   cmessage::AssureWritable(parent);
111   return const_cast<Message*>(message);
112 }
113 
114 // Consumes a reference on the Python string object.
PyStringToSTL(PyObject * py_string,string * stl_string)115 static bool PyStringToSTL(PyObject* py_string, string* stl_string) {
116   char *value;
117   Py_ssize_t value_len;
118 
119   if (!py_string) {
120     return false;
121   }
122   if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
123     Py_DECREF(py_string);
124     return false;
125   } else {
126     stl_string->assign(value, value_len);
127     Py_DECREF(py_string);
128     return true;
129   }
130 }
131 
PythonToMapKey(PyObject * obj,const FieldDescriptor * field_descriptor,MapKey * key)132 static bool PythonToMapKey(PyObject* obj,
133                            const FieldDescriptor* field_descriptor,
134                            MapKey* key) {
135   switch (field_descriptor->cpp_type()) {
136     case FieldDescriptor::CPPTYPE_INT32: {
137       GOOGLE_CHECK_GET_INT32(obj, value, false);
138       key->SetInt32Value(value);
139       break;
140     }
141     case FieldDescriptor::CPPTYPE_INT64: {
142       GOOGLE_CHECK_GET_INT64(obj, value, false);
143       key->SetInt64Value(value);
144       break;
145     }
146     case FieldDescriptor::CPPTYPE_UINT32: {
147       GOOGLE_CHECK_GET_UINT32(obj, value, false);
148       key->SetUInt32Value(value);
149       break;
150     }
151     case FieldDescriptor::CPPTYPE_UINT64: {
152       GOOGLE_CHECK_GET_UINT64(obj, value, false);
153       key->SetUInt64Value(value);
154       break;
155     }
156     case FieldDescriptor::CPPTYPE_BOOL: {
157       GOOGLE_CHECK_GET_BOOL(obj, value, false);
158       key->SetBoolValue(value);
159       break;
160     }
161     case FieldDescriptor::CPPTYPE_STRING: {
162       string str;
163       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
164         return false;
165       }
166       key->SetStringValue(str);
167       break;
168     }
169     default:
170       PyErr_Format(
171           PyExc_SystemError, "Type %d cannot be a map key",
172           field_descriptor->cpp_type());
173       return false;
174   }
175   return true;
176 }
177 
MapKeyToPython(const FieldDescriptor * field_descriptor,const MapKey & key)178 static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor,
179                                 const MapKey& key) {
180   switch (field_descriptor->cpp_type()) {
181     case FieldDescriptor::CPPTYPE_INT32:
182       return PyInt_FromLong(key.GetInt32Value());
183     case FieldDescriptor::CPPTYPE_INT64:
184       return PyLong_FromLongLong(key.GetInt64Value());
185     case FieldDescriptor::CPPTYPE_UINT32:
186       return PyInt_FromSize_t(key.GetUInt32Value());
187     case FieldDescriptor::CPPTYPE_UINT64:
188       return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
189     case FieldDescriptor::CPPTYPE_BOOL:
190       return PyBool_FromLong(key.GetBoolValue());
191     case FieldDescriptor::CPPTYPE_STRING:
192       return ToStringObject(field_descriptor, key.GetStringValue());
193     default:
194       PyErr_Format(
195           PyExc_SystemError, "Couldn't convert type %d to value",
196           field_descriptor->cpp_type());
197       return NULL;
198   }
199 }
200 
201 // This is only used for ScalarMap, so we don't need to handle the
202 // CPPTYPE_MESSAGE case.
MapValueRefToPython(const FieldDescriptor * field_descriptor,MapValueRef * value)203 PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor,
204                               MapValueRef* value) {
205   switch (field_descriptor->cpp_type()) {
206     case FieldDescriptor::CPPTYPE_INT32:
207       return PyInt_FromLong(value->GetInt32Value());
208     case FieldDescriptor::CPPTYPE_INT64:
209       return PyLong_FromLongLong(value->GetInt64Value());
210     case FieldDescriptor::CPPTYPE_UINT32:
211       return PyInt_FromSize_t(value->GetUInt32Value());
212     case FieldDescriptor::CPPTYPE_UINT64:
213       return PyLong_FromUnsignedLongLong(value->GetUInt64Value());
214     case FieldDescriptor::CPPTYPE_FLOAT:
215       return PyFloat_FromDouble(value->GetFloatValue());
216     case FieldDescriptor::CPPTYPE_DOUBLE:
217       return PyFloat_FromDouble(value->GetDoubleValue());
218     case FieldDescriptor::CPPTYPE_BOOL:
219       return PyBool_FromLong(value->GetBoolValue());
220     case FieldDescriptor::CPPTYPE_STRING:
221       return ToStringObject(field_descriptor, value->GetStringValue());
222     case FieldDescriptor::CPPTYPE_ENUM:
223       return PyInt_FromLong(value->GetEnumValue());
224     default:
225       PyErr_Format(
226           PyExc_SystemError, "Couldn't convert type %d to value",
227           field_descriptor->cpp_type());
228       return NULL;
229   }
230 }
231 
232 // This is only used for ScalarMap, so we don't need to handle the
233 // CPPTYPE_MESSAGE case.
PythonToMapValueRef(PyObject * obj,const FieldDescriptor * field_descriptor,bool allow_unknown_enum_values,MapValueRef * value_ref)234 static bool PythonToMapValueRef(PyObject* obj,
235                                 const FieldDescriptor* field_descriptor,
236                                 bool allow_unknown_enum_values,
237                                 MapValueRef* value_ref) {
238   switch (field_descriptor->cpp_type()) {
239     case FieldDescriptor::CPPTYPE_INT32: {
240       GOOGLE_CHECK_GET_INT32(obj, value, false);
241       value_ref->SetInt32Value(value);
242       return true;
243     }
244     case FieldDescriptor::CPPTYPE_INT64: {
245       GOOGLE_CHECK_GET_INT64(obj, value, false);
246       value_ref->SetInt64Value(value);
247       return true;
248     }
249     case FieldDescriptor::CPPTYPE_UINT32: {
250       GOOGLE_CHECK_GET_UINT32(obj, value, false);
251       value_ref->SetUInt32Value(value);
252       return true;
253     }
254     case FieldDescriptor::CPPTYPE_UINT64: {
255       GOOGLE_CHECK_GET_UINT64(obj, value, false);
256       value_ref->SetUInt64Value(value);
257       return true;
258     }
259     case FieldDescriptor::CPPTYPE_FLOAT: {
260       GOOGLE_CHECK_GET_FLOAT(obj, value, false);
261       value_ref->SetFloatValue(value);
262       return true;
263     }
264     case FieldDescriptor::CPPTYPE_DOUBLE: {
265       GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
266       value_ref->SetDoubleValue(value);
267       return true;
268     }
269     case FieldDescriptor::CPPTYPE_BOOL: {
270       GOOGLE_CHECK_GET_BOOL(obj, value, false);
271       value_ref->SetBoolValue(value);
272       return true;;
273     }
274     case FieldDescriptor::CPPTYPE_STRING: {
275       string str;
276       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
277         return false;
278       }
279       value_ref->SetStringValue(str);
280       return true;
281     }
282     case FieldDescriptor::CPPTYPE_ENUM: {
283       GOOGLE_CHECK_GET_INT32(obj, value, false);
284       if (allow_unknown_enum_values) {
285         value_ref->SetEnumValue(value);
286         return true;
287       } else {
288         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
289         const EnumValueDescriptor* enum_value =
290             enum_descriptor->FindValueByNumber(value);
291         if (enum_value != NULL) {
292           value_ref->SetEnumValue(value);
293           return true;
294         } else {
295           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
296           return false;
297         }
298       }
299       break;
300     }
301     default:
302       PyErr_Format(
303           PyExc_SystemError, "Setting value to a field of unknown type %d",
304           field_descriptor->cpp_type());
305       return false;
306   }
307 }
308 
309 // Map methods common to ScalarMap and MessageMap //////////////////////////////
310 
GetMap(PyObject * obj)311 static MapContainer* GetMap(PyObject* obj) {
312   return reinterpret_cast<MapContainer*>(obj);
313 }
314 
Length(PyObject * _self)315 Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
316   MapContainer* self = GetMap(_self);
317   const google::protobuf::Message* message = self->message;
318   return message->GetReflection()->MapSize(*message,
319                                            self->parent_field_descriptor);
320 }
321 
Clear(PyObject * _self)322 PyObject* Clear(PyObject* _self) {
323   MapContainer* self = GetMap(_self);
324   Message* message = self->GetMutableMessage();
325   const Reflection* reflection = message->GetReflection();
326 
327   reflection->ClearField(message, self->parent_field_descriptor);
328 
329   Py_RETURN_NONE;
330 }
331 
Contains(PyObject * _self,PyObject * key)332 PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
333   MapContainer* self = GetMap(_self);
334 
335   const Message* message = self->message;
336   const Reflection* reflection = message->GetReflection();
337   MapKey map_key;
338 
339   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
340     return NULL;
341   }
342 
343   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
344                                  map_key)) {
345     Py_RETURN_TRUE;
346   } else {
347     Py_RETURN_FALSE;
348   }
349 }
350 
// Detaches a map container from its parent message.
//
// Creates a fresh Message of the same type as "from"'s message, moves the map
// field into it (only when the map is non-empty), and then makes "to" own and
// point at that new message, with its parent pointer cleared.  A container is
// released by passing it as both from and to; that is the only combination
// supported right now (asserted with GOOGLE_DCHECK below).
//
// Also bumps to->version so that existing iterators are invalidated.
// Always returns 0.
static int InitializeAndCopyToParentContainer(MapContainer* from,
                                              MapContainer* to) {
  // For now we require from == to, re-evaluate if we want to support deep copy
  // as in repeated_scalar_container.cc.
  GOOGLE_DCHECK(from == to);
  Message* new_message = from->message->New();

  if (MapReflectionFriend::Length(reinterpret_cast<PyObject*>(from)) > 0) {
    // A somewhat roundabout way of copying just one field from old_message to
    // new_message.  This is the best we can do with what Reflection gives us.
    Message* mutable_old = from->GetMutableMessage();
    vector<const FieldDescriptor*> fields;
    fields.push_back(from->parent_field_descriptor);

    // Move the map field into the new message.
    mutable_old->GetReflection()->SwapFields(mutable_old, new_message, fields);

    // If/when we support from != to, this will be required also to copy the
    // map field back into the existing message:
    // mutable_old->MergeFrom(*new_message);
  }

  // If from == to this could delete old_message.
  to->owner.reset(new_message);

  // From here on the container stands alone: no parent, and its message is
  // the newly created one.
  to->parent = NULL;
  to->parent_field_descriptor = from->parent_field_descriptor;
  to->message = new_message;

  // Invalidate iterators, since they point to the old copy of the field.
  to->version++;

  return 0;
}
389 
Release()390 int MapContainer::Release() {
391   return InitializeAndCopyToParentContainer(this, this);
392 }
393 
394 
395 // ScalarMap ///////////////////////////////////////////////////////////////////
396 
NewScalarMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor)397 PyObject *NewScalarMapContainer(
398     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
399   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
400     return NULL;
401   }
402 
403 #if PY_MAJOR_VERSION >= 3
404   ScopedPyObjectPtr obj(PyType_GenericAlloc(
405         reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0));
406 #else
407   ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
408 #endif
409   if (obj.get() == NULL) {
410     return PyErr_Format(PyExc_RuntimeError,
411                         "Could not allocate new container.");
412   }
413 
414   MapContainer* self = GetMap(obj.get());
415 
416   self->message = parent->message;
417   self->parent = parent;
418   self->parent_field_descriptor = parent_field_descriptor;
419   self->owner = parent->owner;
420   self->version = 0;
421 
422   self->key_field_descriptor =
423       parent_field_descriptor->message_type()->FindFieldByName("key");
424   self->value_field_descriptor =
425       parent_field_descriptor->message_type()->FindFieldByName("value");
426 
427   if (self->key_field_descriptor == NULL ||
428       self->value_field_descriptor == NULL) {
429     return PyErr_Format(PyExc_KeyError,
430                         "Map entry descriptor did not have key/value fields");
431   }
432 
433   return obj.release();
434 }
435 
ScalarMapGetItem(PyObject * _self,PyObject * key)436 PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
437                                                 PyObject* key) {
438   MapContainer* self = GetMap(_self);
439 
440   Message* message = self->GetMutableMessage();
441   const Reflection* reflection = message->GetReflection();
442   MapKey map_key;
443   MapValueRef value;
444 
445   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
446     return NULL;
447   }
448 
449   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
450                                          map_key, &value)) {
451     self->version++;
452   }
453 
454   return MapValueRefToPython(self->value_field_descriptor, &value);
455 }
456 
ScalarMapSetItem(PyObject * _self,PyObject * key,PyObject * v)457 int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
458                                           PyObject* v) {
459   MapContainer* self = GetMap(_self);
460 
461   Message* message = self->GetMutableMessage();
462   const Reflection* reflection = message->GetReflection();
463   MapKey map_key;
464   MapValueRef value;
465 
466   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
467     return -1;
468   }
469 
470   self->version++;
471 
472   if (v) {
473     // Set item to v.
474     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
475                                        map_key, &value);
476 
477     return PythonToMapValueRef(v, self->value_field_descriptor,
478                                reflection->SupportsUnknownEnumValues(), &value)
479                ? 0
480                : -1;
481   } else {
482     // Delete key from map.
483     if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
484                                    map_key)) {
485       return 0;
486     } else {
487       PyErr_Format(PyExc_KeyError, "Key not present in map");
488       return -1;
489     }
490   }
491 }
492 
ScalarMapGet(PyObject * self,PyObject * args)493 static PyObject* ScalarMapGet(PyObject* self, PyObject* args) {
494   PyObject* key;
495   PyObject* default_value = NULL;
496   if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
497     return NULL;
498   }
499 
500   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
501   if (is_present.get() == NULL) {
502     return NULL;
503   }
504 
505   if (PyObject_IsTrue(is_present.get())) {
506     return MapReflectionFriend::ScalarMapGetItem(self, key);
507   } else {
508     if (default_value != NULL) {
509       Py_INCREF(default_value);
510       return default_value;
511     } else {
512       Py_RETURN_NONE;
513     }
514   }
515 }
516 
ScalarMapDealloc(PyObject * _self)517 static void ScalarMapDealloc(PyObject* _self) {
518   MapContainer* self = GetMap(_self);
519   self->owner.reset();
520   Py_TYPE(_self)->tp_free(_self);
521 }
522 
523 static PyMethodDef ScalarMapMethods[] = {
524   { "__contains__", MapReflectionFriend::Contains, METH_O,
525     "Tests whether a key is a member of the map." },
526   { "clear", (PyCFunction)Clear, METH_NOARGS,
527     "Removes all elements from the map." },
528   { "get", ScalarMapGet, METH_VARARGS,
529     "Gets the value for the given key if present, or otherwise a default" },
530   /*
531   { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
532     "Makes a deep copy of the class." },
533   { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
534     "Outputs picklable representation of the repeated field." },
535   */
536   {NULL, NULL},
537 };
538 
// Python type definition for ScalarMapContainer.  Python 3 builds describe
// the type with a slot table plus a PyType_Spec; older builds use a statically
// initialized PyTypeObject.
#if PY_MAJOR_VERSION >= 3
  // Slot table referenced by ScalarMapContainer_Type_spec below; terminated by
  // the {0, 0} entry.
  static PyType_Slot ScalarMapContainer_Type_slots[] = {
      {Py_tp_dealloc, (void *)ScalarMapDealloc},
      {Py_mp_length, (void *)MapReflectionFriend::Length},
      {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
      {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
      {Py_tp_methods, (void *)ScalarMapMethods},
      {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
      {0, 0},
  };

  // Fields, in order: name, basicsize, itemsize, flags, slots.
  PyType_Spec ScalarMapContainer_Type_spec = {
      FULL_MODULE_NAME ".ScalarMapContainer",
      sizeof(MapContainer),
      0,
      Py_TPFLAGS_DEFAULT,
      ScalarMapContainer_Type_slots
  };
  // The type object itself, used by NewScalarMapContainer().
  // NOTE(review): presumably created from the spec above during module
  // initialization -- that code is not in this part of the file; confirm.
  PyObject *ScalarMapContainer_Type;
#else
  // Mapping protocol: len(), subscript read, subscript write/delete.
  static PyMappingMethods ScalarMapMappingMethods = {
    MapReflectionFriend::Length,             // mp_length
    MapReflectionFriend::ScalarMapGetItem,   // mp_subscript
    MapReflectionFriend::ScalarMapSetItem,   // mp_ass_subscript
  };

  // Positional initializer: the entries must stay in PyTypeObject field order.
  PyTypeObject ScalarMapContainer_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    FULL_MODULE_NAME ".ScalarMapContainer",  //  tp_name
    sizeof(MapContainer),                //  tp_basicsize
    0,                                   //  tp_itemsize
    ScalarMapDealloc,                    //  tp_dealloc
    0,                                   //  tp_print
    0,                                   //  tp_getattr
    0,                                   //  tp_setattr
    0,                                   //  tp_compare
    0,                                   //  tp_repr
    0,                                   //  tp_as_number
    0,                                   //  tp_as_sequence
    &ScalarMapMappingMethods,            //  tp_as_mapping
    0,                                   //  tp_hash
    0,                                   //  tp_call
    0,                                   //  tp_str
    0,                                   //  tp_getattro
    0,                                   //  tp_setattro
    0,                                   //  tp_as_buffer
    Py_TPFLAGS_DEFAULT,                  //  tp_flags
    "A scalar map container",            //  tp_doc
    0,                                   //  tp_traverse
    0,                                   //  tp_clear
    0,                                   //  tp_richcompare
    0,                                   //  tp_weaklistoffset
    MapReflectionFriend::GetIterator,    //  tp_iter
    0,                                   //  tp_iternext
    ScalarMapMethods,                    //  tp_methods
    0,                                   //  tp_members
    0,                                   //  tp_getset
    0,                                   //  tp_base
    0,                                   //  tp_dict
    0,                                   //  tp_descr_get
    0,                                   //  tp_descr_set
    0,                                   //  tp_dictoffset
    0,                                   //  tp_init
  };
#endif
604 
605 
606 // MessageMap //////////////////////////////////////////////////////////////////
607 
GetMessageMap(PyObject * obj)608 static MessageMapContainer* GetMessageMap(PyObject* obj) {
609   return reinterpret_cast<MessageMapContainer*>(obj);
610 }
611 
GetCMessage(MessageMapContainer * self,Message * message)612 static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
613   // Get or create the CMessage object corresponding to this message.
614   ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
615   PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
616 
617   if (ret == NULL) {
618     CMessage* cmsg = cmessage::NewEmptyMessage(self->message_class);
619     ret = reinterpret_cast<PyObject*>(cmsg);
620 
621     if (cmsg == NULL) {
622       return NULL;
623     }
624     cmsg->owner = self->owner;
625     cmsg->message = message;
626     cmsg->parent = self->parent;
627 
628     if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
629       Py_DECREF(ret);
630       return NULL;
631     }
632   } else {
633     Py_INCREF(ret);
634   }
635 
636   return ret;
637 }
638 
NewMessageMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor,CMessageClass * message_class)639 PyObject* NewMessageMapContainer(
640     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
641     CMessageClass* message_class) {
642   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
643     return NULL;
644   }
645 
646 #if PY_MAJOR_VERSION >= 3
647   PyObject* obj = PyType_GenericAlloc(
648         reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
649 #else
650   PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
651 #endif
652   if (obj == NULL) {
653     return PyErr_Format(PyExc_RuntimeError,
654                         "Could not allocate new container.");
655   }
656 
657   MessageMapContainer* self = GetMessageMap(obj);
658 
659   self->message = parent->message;
660   self->parent = parent;
661   self->parent_field_descriptor = parent_field_descriptor;
662   self->owner = parent->owner;
663   self->version = 0;
664 
665   self->key_field_descriptor =
666       parent_field_descriptor->message_type()->FindFieldByName("key");
667   self->value_field_descriptor =
668       parent_field_descriptor->message_type()->FindFieldByName("value");
669 
670   self->message_dict = PyDict_New();
671   if (self->message_dict == NULL) {
672     return PyErr_Format(PyExc_RuntimeError,
673                         "Could not allocate message dict.");
674   }
675 
676   Py_INCREF(message_class);
677   self->message_class = message_class;
678 
679   if (self->key_field_descriptor == NULL ||
680       self->value_field_descriptor == NULL) {
681     Py_DECREF(obj);
682     return PyErr_Format(PyExc_KeyError,
683                         "Map entry descriptor did not have key/value fields");
684   }
685 
686   return obj;
687 }
688 
MessageMapSetItem(PyObject * _self,PyObject * key,PyObject * v)689 int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
690                                            PyObject* v) {
691   if (v) {
692     PyErr_Format(PyExc_ValueError,
693                  "Direct assignment of submessage not allowed");
694     return -1;
695   }
696 
697   // Now we know that this is a delete, not a set.
698 
699   MessageMapContainer* self = GetMessageMap(_self);
700   Message* message = self->GetMutableMessage();
701   const Reflection* reflection = message->GetReflection();
702   MapKey map_key;
703   MapValueRef value;
704 
705   self->version++;
706 
707   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
708     return -1;
709   }
710 
711   // Delete key from map.
712   if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
713                                  map_key)) {
714     return 0;
715   } else {
716     PyErr_Format(PyExc_KeyError, "Key not present in map");
717     return -1;
718   }
719 }
720 
MessageMapGetItem(PyObject * _self,PyObject * key)721 PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
722                                                  PyObject* key) {
723   MessageMapContainer* self = GetMessageMap(_self);
724 
725   Message* message = self->GetMutableMessage();
726   const Reflection* reflection = message->GetReflection();
727   MapKey map_key;
728   MapValueRef value;
729 
730   if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
731     return NULL;
732   }
733 
734   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
735                                          map_key, &value)) {
736     self->version++;
737   }
738 
739   return GetCMessage(self, value.MutableMessageValue());
740 }
741 
MessageMapGet(PyObject * self,PyObject * args)742 PyObject* MessageMapGet(PyObject* self, PyObject* args) {
743   PyObject* key;
744   PyObject* default_value = NULL;
745   if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
746     return NULL;
747   }
748 
749   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
750   if (is_present.get() == NULL) {
751     return NULL;
752   }
753 
754   if (PyObject_IsTrue(is_present.get())) {
755     return MapReflectionFriend::MessageMapGetItem(self, key);
756   } else {
757     if (default_value != NULL) {
758       Py_INCREF(default_value);
759       return default_value;
760     } else {
761       Py_RETURN_NONE;
762     }
763   }
764 }
765 
MessageMapDealloc(PyObject * _self)766 static void MessageMapDealloc(PyObject* _self) {
767   MessageMapContainer* self = GetMessageMap(_self);
768   self->owner.reset();
769   Py_DECREF(self->message_dict);
770   Py_DECREF(self->message_class);
771   Py_TYPE(_self)->tp_free(_self);
772 }
773 
774 static PyMethodDef MessageMapMethods[] = {
775   { "__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
776     "Tests whether the map contains this element."},
777   { "clear", (PyCFunction)Clear, METH_NOARGS,
778     "Removes all elements from the map."},
779   { "get", MessageMapGet, METH_VARARGS,
780     "Gets the value for the given key if present, or otherwise a default" },
781   { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
782     "Alias for getitem, useful to make explicit that the map is mutated." },
783   /*
784   { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
785     "Makes a deep copy of the class." },
786   { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
787     "Outputs picklable representation of the repeated field." },
788   */
789   {NULL, NULL},
790 };
791 
792 #if PY_MAJOR_VERSION >= 3
  // Slot table for the MessageMapContainer type in Python 3 builds: dealloc,
  // the mapping protocol (length, subscript read, subscript write/delete), the
  // method table and the iterator factory.  Terminated by the {0, 0} entry.
  static PyType_Slot MessageMapContainer_Type_slots[] = {
      {Py_tp_dealloc, (void *)MessageMapDealloc},
      {Py_mp_length, (void *)MapReflectionFriend::Length},
      {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
      {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
      {Py_tp_methods, (void *)MessageMapMethods},
      {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
      {0, 0}
  };
802 
  // Type spec for MessageMapContainer in Python 3 builds.
  // Fields, in order: name, basicsize, itemsize, flags, slots.
  PyType_Spec MessageMapContainer_Type_spec = {
      FULL_MODULE_NAME ".MessageMapContainer",
      sizeof(MessageMapContainer),
      0,
      Py_TPFLAGS_DEFAULT,
      MessageMapContainer_Type_slots
  };

  // The type object itself, used by NewMessageMapContainer().
  // NOTE(review): presumably created from the spec above during module
  // initialization -- that code is not in this part of the file; confirm.
  PyObject *MessageMapContainer_Type;
812 #else
813   static PyMappingMethods MessageMapMappingMethods = {
814     MapReflectionFriend::Length,              // mp_length
815     MapReflectionFriend::MessageMapGetItem,   // mp_subscript
816     MapReflectionFriend::MessageMapSetItem,   // mp_ass_subscript
817   };
818 
819   PyTypeObject MessageMapContainer_Type = {
820     PyVarObject_HEAD_INIT(&PyType_Type, 0)
821     FULL_MODULE_NAME ".MessageMapContainer",  //  tp_name
822     sizeof(MessageMapContainer),         //  tp_basicsize
823     0,                                   //  tp_itemsize
824     MessageMapDealloc,                   //  tp_dealloc
825     0,                                   //  tp_print
826     0,                                   //  tp_getattr
827     0,                                   //  tp_setattr
828     0,                                   //  tp_compare
829     0,                                   //  tp_repr
830     0,                                   //  tp_as_number
831     0,                                   //  tp_as_sequence
832     &MessageMapMappingMethods,           //  tp_as_mapping
833     0,                                   //  tp_hash
834     0,                                   //  tp_call
835     0,                                   //  tp_str
836     0,                                   //  tp_getattro
837     0,                                   //  tp_setattro
838     0,                                   //  tp_as_buffer
839     Py_TPFLAGS_DEFAULT,                  //  tp_flags
840     "A map container for message",       //  tp_doc
841     0,                                   //  tp_traverse
842     0,                                   //  tp_clear
843     0,                                   //  tp_richcompare
844     0,                                   //  tp_weaklistoffset
845     MapReflectionFriend::GetIterator,    //  tp_iter
846     0,                                   //  tp_iternext
847     MessageMapMethods,                   //  tp_methods
848     0,                                   //  tp_members
849     0,                                   //  tp_getset
850     0,                                   //  tp_base
851     0,                                   //  tp_dict
852     0,                                   //  tp_descr_get
853     0,                                   //  tp_descr_set
854     0,                                   //  tp_dictoffset
855     0,                                   //  tp_init
856   };
857 #endif
858 
859 // MapIterator /////////////////////////////////////////////////////////////////
860 
GetIter(PyObject * obj)861 static MapIterator* GetIter(PyObject* obj) {
862   return reinterpret_cast<MapIterator*>(obj);
863 }
864 
GetIterator(PyObject * _self)865 PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
866   MapContainer* self = GetMap(_self);
867 
868   ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
869   if (obj == NULL) {
870     return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
871   }
872 
873   MapIterator* iter = GetIter(obj.get());
874 
875   Py_INCREF(self);
876   iter->container = self;
877   iter->version = self->version;
878   iter->owner = self->owner;
879 
880   if (MapReflectionFriend::Length(_self) > 0) {
881     Message* message = self->GetMutableMessage();
882     const Reflection* reflection = message->GetReflection();
883 
884     iter->iter.reset(new ::google::protobuf::MapIterator(
885         reflection->MapBegin(message, self->parent_field_descriptor)));
886   }
887 
888   return obj.release();
889 }
890 
IterNext(PyObject * _self)891 PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
892   MapIterator* self = GetIter(_self);
893 
894   // This won't catch mutations to the map performed by MergeFrom(); no easy way
895   // to address that.
896   if (self->version != self->container->version) {
897     return PyErr_Format(PyExc_RuntimeError,
898                         "Map modified during iteration.");
899   }
900 
901   if (self->iter.get() == NULL) {
902     return NULL;
903   }
904 
905   Message* message = self->container->GetMutableMessage();
906   const Reflection* reflection = message->GetReflection();
907 
908   if (*self->iter ==
909       reflection->MapEnd(message, self->container->parent_field_descriptor)) {
910     return NULL;
911   }
912 
913   PyObject* ret = MapKeyToPython(self->container->key_field_descriptor,
914                                  self->iter->GetKey());
915 
916   ++(*self->iter);
917 
918   return ret;
919 }
920 
DeallocMapIterator(PyObject * _self)921 static void DeallocMapIterator(PyObject* _self) {
922   MapIterator* self = GetIter(_self);
923   self->iter.reset();
924   self->owner.reset();
925   Py_XDECREF(self->container);
926   Py_TYPE(_self)->tp_free(_self);
927 }
928 
929 PyTypeObject MapIterator_Type = {
930   PyVarObject_HEAD_INIT(&PyType_Type, 0)
931   FULL_MODULE_NAME ".MapIterator",     //  tp_name
932   sizeof(MapIterator),                 //  tp_basicsize
933   0,                                   //  tp_itemsize
934   DeallocMapIterator,                  //  tp_dealloc
935   0,                                   //  tp_print
936   0,                                   //  tp_getattr
937   0,                                   //  tp_setattr
938   0,                                   //  tp_compare
939   0,                                   //  tp_repr
940   0,                                   //  tp_as_number
941   0,                                   //  tp_as_sequence
942   0,                                   //  tp_as_mapping
943   0,                                   //  tp_hash
944   0,                                   //  tp_call
945   0,                                   //  tp_str
946   0,                                   //  tp_getattro
947   0,                                   //  tp_setattro
948   0,                                   //  tp_as_buffer
949   Py_TPFLAGS_DEFAULT,                  //  tp_flags
950   "A scalar map iterator",             //  tp_doc
951   0,                                   //  tp_traverse
952   0,                                   //  tp_clear
953   0,                                   //  tp_richcompare
954   0,                                   //  tp_weaklistoffset
955   PyObject_SelfIter,                   //  tp_iter
956   MapReflectionFriend::IterNext,       //  tp_iternext
957   0,                                   //  tp_methods
958   0,                                   //  tp_members
959   0,                                   //  tp_getset
960   0,                                   //  tp_base
961   0,                                   //  tp_dict
962   0,                                   //  tp_descr_get
963   0,                                   //  tp_descr_set
964   0,                                   //  tp_dictoffset
965   0,                                   //  tp_init
966 };
967 
968 }  // namespace python
969 }  // namespace protobuf
970 }  // namespace google
971