1 #include "Python.h"
2 
3 #include "pycore_context.h"
4 #include "pycore_gc.h"            // _PyObject_GC_MAY_BE_TRACKED()
5 #include "pycore_hamt.h"
6 #include "pycore_object.h"
7 #include "pycore_pyerrors.h"
8 #include "pycore_pystate.h"       // _PyThreadState_GET()
9 #include "structmember.h"         // PyMemberDef
10 
11 
/* Free list of deallocated Context objects, reused by _context_alloc.
   While an object sits on the list, its ctx_weakreflist slot holds the
   pointer to the next list entry (see context_tp_dealloc).  At most
   CONTEXT_FREELIST_MAXLEN objects are retained; ctx_freelist_len is the
   current number of entries. */
#define CONTEXT_FREELIST_MAXLEN 255
static PyContext *ctx_freelist = NULL;
static int ctx_freelist_len = 0;
15 
16 
17 #include "clinic/context.c.h"
18 /*[clinic input]
19 module _contextvars
20 [clinic start generated code]*/
21 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/
22 
23 
/* Argument guards: if `o` is not *exactly* the expected type (subclasses
   are rejected), set TypeError and `return err_ret` from the enclosing
   function.  Each macro expands to a bare `if` statement, so call sites
   use them without a trailing semicolon. */

/* Guard: `o` must be exactly a Context. */
#define ENSURE_Context(o, err_ret)                                  \
    if (!PyContext_CheckExact(o)) {                                 \
        PyErr_SetString(PyExc_TypeError,                            \
                        "an instance of Context was expected");     \
        return err_ret;                                             \
    }

/* Guard: `o` must be exactly a ContextVar. */
#define ENSURE_ContextVar(o, err_ret)                               \
    if (!PyContextVar_CheckExact(o)) {                              \
        PyErr_SetString(PyExc_TypeError,                            \
                       "an instance of ContextVar was expected");   \
        return err_ret;                                             \
    }

/* Guard: `o` must be exactly a Token. */
#define ENSURE_ContextToken(o, err_ret)                             \
    if (!PyContextToken_CheckExact(o)) {                            \
        PyErr_SetString(PyExc_TypeError,                            \
                        "an instance of Token was expected");       \
        return err_ret;                                             \
    }
44 
45 
46 /////////////////////////// Context API
47 
48 
49 static PyContext *
50 context_new_empty(void);
51 
52 static PyContext *
53 context_new_from_vars(PyHamtObject *vars);
54 
55 static inline PyContext *
56 context_get(void);
57 
58 static PyContextToken *
59 token_new(PyContext *ctx, PyContextVar *var, PyObject *val);
60 
61 static PyContextVar *
62 contextvar_new(PyObject *name, PyObject *def);
63 
64 static int
65 contextvar_set(PyContextVar *var, PyObject *val);
66 
67 static int
68 contextvar_del(PyContextVar *var);
69 
70 
71 PyObject *
_PyContext_NewHamtForTests(void)72 _PyContext_NewHamtForTests(void)
73 {
74     return (PyObject *)_PyHamt_New();
75 }
76 
77 
78 PyObject *
PyContext_New(void)79 PyContext_New(void)
80 {
81     return (PyObject *)context_new_empty();
82 }
83 
84 
85 PyObject *
PyContext_Copy(PyObject * octx)86 PyContext_Copy(PyObject * octx)
87 {
88     ENSURE_Context(octx, NULL)
89     PyContext *ctx = (PyContext *)octx;
90     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
91 }
92 
93 
94 PyObject *
PyContext_CopyCurrent(void)95 PyContext_CopyCurrent(void)
96 {
97     PyContext *ctx = context_get();
98     if (ctx == NULL) {
99         return NULL;
100     }
101 
102     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
103 }
104 
105 
106 static int
_PyContext_Enter(PyThreadState * ts,PyObject * octx)107 _PyContext_Enter(PyThreadState *ts, PyObject *octx)
108 {
109     ENSURE_Context(octx, -1)
110     PyContext *ctx = (PyContext *)octx;
111 
112     if (ctx->ctx_entered) {
113         _PyErr_Format(ts, PyExc_RuntimeError,
114                       "cannot enter context: %R is already entered", ctx);
115         return -1;
116     }
117 
118     ctx->ctx_prev = (PyContext *)ts->context;  /* borrow */
119     ctx->ctx_entered = 1;
120 
121     Py_INCREF(ctx);
122     ts->context = (PyObject *)ctx;
123     ts->context_ver++;
124 
125     return 0;
126 }
127 
128 
129 int
PyContext_Enter(PyObject * octx)130 PyContext_Enter(PyObject *octx)
131 {
132     PyThreadState *ts = _PyThreadState_GET();
133     assert(ts != NULL);
134     return _PyContext_Enter(ts, octx);
135 }
136 
137 
138 static int
_PyContext_Exit(PyThreadState * ts,PyObject * octx)139 _PyContext_Exit(PyThreadState *ts, PyObject *octx)
140 {
141     ENSURE_Context(octx, -1)
142     PyContext *ctx = (PyContext *)octx;
143 
144     if (!ctx->ctx_entered) {
145         PyErr_Format(PyExc_RuntimeError,
146                      "cannot exit context: %R has not been entered", ctx);
147         return -1;
148     }
149 
150     if (ts->context != (PyObject *)ctx) {
151         /* Can only happen if someone misuses the C API */
152         PyErr_SetString(PyExc_RuntimeError,
153                         "cannot exit context: thread state references "
154                         "a different context object");
155         return -1;
156     }
157 
158     Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev);
159     ts->context_ver++;
160 
161     ctx->ctx_prev = NULL;
162     ctx->ctx_entered = 0;
163 
164     return 0;
165 }
166 
167 int
PyContext_Exit(PyObject * octx)168 PyContext_Exit(PyObject *octx)
169 {
170     PyThreadState *ts = _PyThreadState_GET();
171     assert(ts != NULL);
172     return _PyContext_Exit(ts, octx);
173 }
174 
175 
176 PyObject *
PyContextVar_New(const char * name,PyObject * def)177 PyContextVar_New(const char *name, PyObject *def)
178 {
179     PyObject *pyname = PyUnicode_FromString(name);
180     if (pyname == NULL) {
181         return NULL;
182     }
183     PyContextVar *var = contextvar_new(pyname, def);
184     Py_DECREF(pyname);
185     return (PyObject *)var;
186 }
187 
188 
189 int
PyContextVar_Get(PyObject * ovar,PyObject * def,PyObject ** val)190 PyContextVar_Get(PyObject *ovar, PyObject *def, PyObject **val)
191 {
192     ENSURE_ContextVar(ovar, -1)
193     PyContextVar *var = (PyContextVar *)ovar;
194 
195     PyThreadState *ts = _PyThreadState_GET();
196     assert(ts != NULL);
197     if (ts->context == NULL) {
198         goto not_found;
199     }
200 
201     if (var->var_cached != NULL &&
202             var->var_cached_tsid == ts->id &&
203             var->var_cached_tsver == ts->context_ver)
204     {
205         *val = var->var_cached;
206         goto found;
207     }
208 
209     assert(PyContext_CheckExact(ts->context));
210     PyHamtObject *vars = ((PyContext *)ts->context)->ctx_vars;
211 
212     PyObject *found = NULL;
213     int res = _PyHamt_Find(vars, (PyObject*)var, &found);
214     if (res < 0) {
215         goto error;
216     }
217     if (res == 1) {
218         assert(found != NULL);
219         var->var_cached = found;  /* borrow */
220         var->var_cached_tsid = ts->id;
221         var->var_cached_tsver = ts->context_ver;
222 
223         *val = found;
224         goto found;
225     }
226 
227 not_found:
228     if (def == NULL) {
229         if (var->var_default != NULL) {
230             *val = var->var_default;
231             goto found;
232         }
233 
234         *val = NULL;
235         goto found;
236     }
237     else {
238         *val = def;
239         goto found;
240    }
241 
242 found:
243     Py_XINCREF(*val);
244     return 0;
245 
246 error:
247     *val = NULL;
248     return -1;
249 }
250 
251 
252 PyObject *
PyContextVar_Set(PyObject * ovar,PyObject * val)253 PyContextVar_Set(PyObject *ovar, PyObject *val)
254 {
255     ENSURE_ContextVar(ovar, NULL)
256     PyContextVar *var = (PyContextVar *)ovar;
257 
258     if (!PyContextVar_CheckExact(var)) {
259         PyErr_SetString(
260             PyExc_TypeError, "an instance of ContextVar was expected");
261         return NULL;
262     }
263 
264     PyContext *ctx = context_get();
265     if (ctx == NULL) {
266         return NULL;
267     }
268 
269     PyObject *old_val = NULL;
270     int found = _PyHamt_Find(ctx->ctx_vars, (PyObject *)var, &old_val);
271     if (found < 0) {
272         return NULL;
273     }
274 
275     Py_XINCREF(old_val);
276     PyContextToken *tok = token_new(ctx, var, old_val);
277     Py_XDECREF(old_val);
278 
279     if (contextvar_set(var, val)) {
280         Py_DECREF(tok);
281         return NULL;
282     }
283 
284     return (PyObject *)tok;
285 }
286 
287 
288 int
PyContextVar_Reset(PyObject * ovar,PyObject * otok)289 PyContextVar_Reset(PyObject *ovar, PyObject *otok)
290 {
291     ENSURE_ContextVar(ovar, -1)
292     ENSURE_ContextToken(otok, -1)
293     PyContextVar *var = (PyContextVar *)ovar;
294     PyContextToken *tok = (PyContextToken *)otok;
295 
296     if (tok->tok_used) {
297         PyErr_Format(PyExc_RuntimeError,
298                      "%R has already been used once", tok);
299         return -1;
300     }
301 
302     if (var != tok->tok_var) {
303         PyErr_Format(PyExc_ValueError,
304                      "%R was created by a different ContextVar", tok);
305         return -1;
306     }
307 
308     PyContext *ctx = context_get();
309     if (ctx != tok->tok_ctx) {
310         PyErr_Format(PyExc_ValueError,
311                      "%R was created in a different Context", tok);
312         return -1;
313     }
314 
315     tok->tok_used = 1;
316 
317     if (tok->tok_oldval == NULL) {
318         return contextvar_del(var);
319     }
320     else {
321         return contextvar_set(var, tok->tok_oldval);
322     }
323 }
324 
325 
326 /////////////////////////// PyContext
327 
328 /*[clinic input]
329 class _contextvars.Context "PyContext *" "&PyContext_Type"
330 [clinic start generated code]*/
331 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=bdf87f8e0cb580e8]*/
332 
333 
/* Allocate a Context object, reusing one from the free list if available.

   Returns an object that is NOT yet GC-tracked, with ctx_vars, ctx_prev and
   ctx_weakreflist set to NULL and ctx_entered to 0; the caller must set
   ctx_vars and then track it (see context_new_from_vars).  Returns NULL if
   a fresh allocation fails. */
static inline PyContext *
_context_alloc(void)
{
    PyContext *ctx;
    if (ctx_freelist_len) {
        /* Pop the list head.  While on the free list, ctx_weakreflist holds
           the "next" pointer, so restore it to NULL after unlinking. */
        ctx_freelist_len--;
        ctx = ctx_freelist;
        ctx_freelist = (PyContext *)ctx->ctx_weakreflist;
        ctx->ctx_weakreflist = NULL;
        /* Recycled object: re-initialize its object header/refcount
           (presumably to a single fresh reference — see _Py_NewReference). */
        _Py_NewReference((PyObject *)ctx);
    }
    else {
        ctx = PyObject_GC_New(PyContext, &PyContext_Type);
        if (ctx == NULL) {
            return NULL;
        }
    }

    ctx->ctx_vars = NULL;
    ctx->ctx_prev = NULL;
    ctx->ctx_entered = 0;
    ctx->ctx_weakreflist = NULL;

    return ctx;
}
359 
360 
361 static PyContext *
context_new_empty(void)362 context_new_empty(void)
363 {
364     PyContext *ctx = _context_alloc();
365     if (ctx == NULL) {
366         return NULL;
367     }
368 
369     ctx->ctx_vars = _PyHamt_New();
370     if (ctx->ctx_vars == NULL) {
371         Py_DECREF(ctx);
372         return NULL;
373     }
374 
375     _PyObject_GC_TRACK(ctx);
376     return ctx;
377 }
378 
379 
380 static PyContext *
context_new_from_vars(PyHamtObject * vars)381 context_new_from_vars(PyHamtObject *vars)
382 {
383     PyContext *ctx = _context_alloc();
384     if (ctx == NULL) {
385         return NULL;
386     }
387 
388     Py_INCREF(vars);
389     ctx->ctx_vars = vars;
390 
391     _PyObject_GC_TRACK(ctx);
392     return ctx;
393 }
394 
395 
396 static inline PyContext *
context_get(void)397 context_get(void)
398 {
399     PyThreadState *ts = _PyThreadState_GET();
400     assert(ts != NULL);
401     PyContext *current_ctx = (PyContext *)ts->context;
402     if (current_ctx == NULL) {
403         current_ctx = context_new_empty();
404         if (current_ctx == NULL) {
405             return NULL;
406         }
407         ts->context = (PyObject *)current_ctx;
408     }
409     return current_ctx;
410 }
411 
412 static int
context_check_key_type(PyObject * key)413 context_check_key_type(PyObject *key)
414 {
415     if (!PyContextVar_CheckExact(key)) {
416         // abort();
417         PyErr_Format(PyExc_TypeError,
418                      "a ContextVar key was expected, got %R", key);
419         return -1;
420     }
421     return 0;
422 }
423 
424 static PyObject *
context_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)425 context_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
426 {
427     if (PyTuple_Size(args) || (kwds != NULL && PyDict_Size(kwds))) {
428         PyErr_SetString(
429             PyExc_TypeError, "Context() does not accept any arguments");
430         return NULL;
431     }
432     return PyContext_New();
433 }
434 
435 static int
context_tp_clear(PyContext * self)436 context_tp_clear(PyContext *self)
437 {
438     Py_CLEAR(self->ctx_prev);
439     Py_CLEAR(self->ctx_vars);
440     return 0;
441 }
442 
443 static int
context_tp_traverse(PyContext * self,visitproc visit,void * arg)444 context_tp_traverse(PyContext *self, visitproc visit, void *arg)
445 {
446     Py_VISIT(self->ctx_prev);
447     Py_VISIT(self->ctx_vars);
448     return 0;
449 }
450 
451 static void
context_tp_dealloc(PyContext * self)452 context_tp_dealloc(PyContext *self)
453 {
454     _PyObject_GC_UNTRACK(self);
455 
456     if (self->ctx_weakreflist != NULL) {
457         PyObject_ClearWeakRefs((PyObject*)self);
458     }
459     (void)context_tp_clear(self);
460 
461     if (ctx_freelist_len < CONTEXT_FREELIST_MAXLEN) {
462         ctx_freelist_len++;
463         self->ctx_weakreflist = (PyObject *)ctx_freelist;
464         ctx_freelist = self;
465     }
466     else {
467         Py_TYPE(self)->tp_free(self);
468     }
469 }
470 
471 static PyObject *
context_tp_iter(PyContext * self)472 context_tp_iter(PyContext *self)
473 {
474     return _PyHamt_NewIterKeys(self->ctx_vars);
475 }
476 
477 static PyObject *
context_tp_richcompare(PyObject * v,PyObject * w,int op)478 context_tp_richcompare(PyObject *v, PyObject *w, int op)
479 {
480     if (!PyContext_CheckExact(v) || !PyContext_CheckExact(w) ||
481             (op != Py_EQ && op != Py_NE))
482     {
483         Py_RETURN_NOTIMPLEMENTED;
484     }
485 
486     int res = _PyHamt_Eq(
487         ((PyContext *)v)->ctx_vars, ((PyContext *)w)->ctx_vars);
488     if (res < 0) {
489         return NULL;
490     }
491 
492     if (op == Py_NE) {
493         res = !res;
494     }
495 
496     if (res) {
497         Py_RETURN_TRUE;
498     }
499     else {
500         Py_RETURN_FALSE;
501     }
502 }
503 
504 static Py_ssize_t
context_tp_len(PyContext * self)505 context_tp_len(PyContext *self)
506 {
507     return _PyHamt_Len(self->ctx_vars);
508 }
509 
510 static PyObject *
context_tp_subscript(PyContext * self,PyObject * key)511 context_tp_subscript(PyContext *self, PyObject *key)
512 {
513     if (context_check_key_type(key)) {
514         return NULL;
515     }
516     PyObject *val = NULL;
517     int found = _PyHamt_Find(self->ctx_vars, key, &val);
518     if (found < 0) {
519         return NULL;
520     }
521     if (found == 0) {
522         PyErr_SetObject(PyExc_KeyError, key);
523         return NULL;
524     }
525     Py_INCREF(val);
526     return val;
527 }
528 
529 static int
context_tp_contains(PyContext * self,PyObject * key)530 context_tp_contains(PyContext *self, PyObject *key)
531 {
532     if (context_check_key_type(key)) {
533         return -1;
534     }
535     PyObject *val = NULL;
536     return _PyHamt_Find(self->ctx_vars, key, &val);
537 }
538 
539 
540 /*[clinic input]
541 _contextvars.Context.get
542     key: object
543     default: object = None
544     /
545 
546 Return the value for `key` if `key` has the value in the context object.
547 
548 If `key` does not exist, return `default`. If `default` is not given,
549 return None.
550 [clinic start generated code]*/
551 
552 static PyObject *
_contextvars_Context_get_impl(PyContext * self,PyObject * key,PyObject * default_value)553 _contextvars_Context_get_impl(PyContext *self, PyObject *key,
554                               PyObject *default_value)
555 /*[clinic end generated code: output=0c54aa7664268189 input=c8eeb81505023995]*/
556 {
557     if (context_check_key_type(key)) {
558         return NULL;
559     }
560 
561     PyObject *val = NULL;
562     int found = _PyHamt_Find(self->ctx_vars, key, &val);
563     if (found < 0) {
564         return NULL;
565     }
566     if (found == 0) {
567         Py_INCREF(default_value);
568         return default_value;
569     }
570     Py_INCREF(val);
571     return val;
572 }
573 
574 
575 /*[clinic input]
576 _contextvars.Context.items
577 
578 Return all variables and their values in the context object.
579 
580 The result is returned as a list of 2-tuples (variable, value).
581 [clinic start generated code]*/
582 
583 static PyObject *
_contextvars_Context_items_impl(PyContext * self)584 _contextvars_Context_items_impl(PyContext *self)
585 /*[clinic end generated code: output=fa1655c8a08502af input=00db64ae379f9f42]*/
586 {
587     return _PyHamt_NewIterItems(self->ctx_vars);
588 }
589 
590 
591 /*[clinic input]
592 _contextvars.Context.keys
593 
594 Return a list of all variables in the context object.
595 [clinic start generated code]*/
596 
597 static PyObject *
_contextvars_Context_keys_impl(PyContext * self)598 _contextvars_Context_keys_impl(PyContext *self)
599 /*[clinic end generated code: output=177227c6b63ec0e2 input=114b53aebca3449c]*/
600 {
601     return _PyHamt_NewIterKeys(self->ctx_vars);
602 }
603 
604 
605 /*[clinic input]
606 _contextvars.Context.values
607 
608 Return a list of all variables' values in the context object.
609 [clinic start generated code]*/
610 
611 static PyObject *
_contextvars_Context_values_impl(PyContext * self)612 _contextvars_Context_values_impl(PyContext *self)
613 /*[clinic end generated code: output=d286dabfc8db6dde input=ce8075d04a6ea526]*/
614 {
615     return _PyHamt_NewIterValues(self->ctx_vars);
616 }
617 
618 
619 /*[clinic input]
620 _contextvars.Context.copy
621 
622 Return a shallow copy of the context object.
623 [clinic start generated code]*/
624 
625 static PyObject *
_contextvars_Context_copy_impl(PyContext * self)626 _contextvars_Context_copy_impl(PyContext *self)
627 /*[clinic end generated code: output=30ba8896c4707a15 input=ebafdbdd9c72d592]*/
628 {
629     return (PyObject *)context_new_from_vars(self->ctx_vars);
630 }
631 
632 
633 static PyObject *
context_run(PyContext * self,PyObject * const * args,Py_ssize_t nargs,PyObject * kwnames)634 context_run(PyContext *self, PyObject *const *args,
635             Py_ssize_t nargs, PyObject *kwnames)
636 {
637     PyThreadState *ts = _PyThreadState_GET();
638 
639     if (nargs < 1) {
640         _PyErr_SetString(ts, PyExc_TypeError,
641                          "run() missing 1 required positional argument");
642         return NULL;
643     }
644 
645     if (_PyContext_Enter(ts, (PyObject *)self)) {
646         return NULL;
647     }
648 
649     PyObject *call_result = _PyObject_VectorcallTstate(
650         ts, args[0], args + 1, nargs - 1, kwnames);
651 
652     if (_PyContext_Exit(ts, (PyObject *)self)) {
653         return NULL;
654     }
655 
656     return call_result;
657 }
658 
659 
660 static PyMethodDef PyContext_methods[] = {
661     _CONTEXTVARS_CONTEXT_GET_METHODDEF
662     _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF
663     _CONTEXTVARS_CONTEXT_KEYS_METHODDEF
664     _CONTEXTVARS_CONTEXT_VALUES_METHODDEF
665     _CONTEXTVARS_CONTEXT_COPY_METHODDEF
666     {"run", (PyCFunction)(void(*)(void))context_run, METH_FASTCALL | METH_KEYWORDS, NULL},
667     {NULL, NULL}
668 };
669 
670 static PySequenceMethods PyContext_as_sequence = {
671     0,                                   /* sq_length */
672     0,                                   /* sq_concat */
673     0,                                   /* sq_repeat */
674     0,                                   /* sq_item */
675     0,                                   /* sq_slice */
676     0,                                   /* sq_ass_item */
677     0,                                   /* sq_ass_slice */
678     (objobjproc)context_tp_contains,     /* sq_contains */
679     0,                                   /* sq_inplace_concat */
680     0,                                   /* sq_inplace_repeat */
681 };
682 
683 static PyMappingMethods PyContext_as_mapping = {
684     (lenfunc)context_tp_len,             /* mp_length */
685     (binaryfunc)context_tp_subscript,    /* mp_subscript */
686 };
687 
688 PyTypeObject PyContext_Type = {
689     PyVarObject_HEAD_INIT(&PyType_Type, 0)
690     "Context",
691     sizeof(PyContext),
692     .tp_methods = PyContext_methods,
693     .tp_as_mapping = &PyContext_as_mapping,
694     .tp_as_sequence = &PyContext_as_sequence,
695     .tp_iter = (getiterfunc)context_tp_iter,
696     .tp_dealloc = (destructor)context_tp_dealloc,
697     .tp_getattro = PyObject_GenericGetAttr,
698     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
699     .tp_richcompare = context_tp_richcompare,
700     .tp_traverse = (traverseproc)context_tp_traverse,
701     .tp_clear = (inquiry)context_tp_clear,
702     .tp_new = context_tp_new,
703     .tp_weaklistoffset = offsetof(PyContext, ctx_weakreflist),
704     .tp_hash = PyObject_HashNotImplemented,
705 };
706 
707 
708 /////////////////////////// ContextVar
709 
710 
711 static int
contextvar_set(PyContextVar * var,PyObject * val)712 contextvar_set(PyContextVar *var, PyObject *val)
713 {
714     var->var_cached = NULL;
715     PyThreadState *ts = PyThreadState_Get();
716 
717     PyContext *ctx = context_get();
718     if (ctx == NULL) {
719         return -1;
720     }
721 
722     PyHamtObject *new_vars = _PyHamt_Assoc(
723         ctx->ctx_vars, (PyObject *)var, val);
724     if (new_vars == NULL) {
725         return -1;
726     }
727 
728     Py_SETREF(ctx->ctx_vars, new_vars);
729 
730     var->var_cached = val;  /* borrow */
731     var->var_cached_tsid = ts->id;
732     var->var_cached_tsver = ts->context_ver;
733     return 0;
734 }
735 
736 static int
contextvar_del(PyContextVar * var)737 contextvar_del(PyContextVar *var)
738 {
739     var->var_cached = NULL;
740 
741     PyContext *ctx = context_get();
742     if (ctx == NULL) {
743         return -1;
744     }
745 
746     PyHamtObject *vars = ctx->ctx_vars;
747     PyHamtObject *new_vars = _PyHamt_Without(vars, (PyObject *)var);
748     if (new_vars == NULL) {
749         return -1;
750     }
751 
752     if (vars == new_vars) {
753         Py_DECREF(new_vars);
754         PyErr_SetObject(PyExc_LookupError, (PyObject *)var);
755         return -1;
756     }
757 
758     Py_SETREF(ctx->ctx_vars, new_vars);
759     return 0;
760 }
761 
762 static Py_hash_t
contextvar_generate_hash(void * addr,PyObject * name)763 contextvar_generate_hash(void *addr, PyObject *name)
764 {
765     /* Take hash of `name` and XOR it with the object's addr.
766 
767        The structure of the tree is encoded in objects' hashes, which
768        means that sufficiently similar hashes would result in tall trees
769        with many Collision nodes.  Which would, in turn, result in slower
770        get and set operations.
771 
772        The XORing helps to ensure that:
773 
774        (1) sequentially allocated ContextVar objects have
775            different hashes;
776 
777        (2) context variables with equal names have
778            different hashes.
779     */
780 
781     Py_hash_t name_hash = PyObject_Hash(name);
782     if (name_hash == -1) {
783         return -1;
784     }
785 
786     Py_hash_t res = _Py_HashPointer(addr) ^ name_hash;
787     return res == -1 ? -2 : res;
788 }
789 
790 static PyContextVar *
contextvar_new(PyObject * name,PyObject * def)791 contextvar_new(PyObject *name, PyObject *def)
792 {
793     if (!PyUnicode_Check(name)) {
794         PyErr_SetString(PyExc_TypeError,
795                         "context variable name must be a str");
796         return NULL;
797     }
798 
799     PyContextVar *var = PyObject_GC_New(PyContextVar, &PyContextVar_Type);
800     if (var == NULL) {
801         return NULL;
802     }
803 
804     var->var_hash = contextvar_generate_hash(var, name);
805     if (var->var_hash == -1) {
806         Py_DECREF(var);
807         return NULL;
808     }
809 
810     Py_INCREF(name);
811     var->var_name = name;
812 
813     Py_XINCREF(def);
814     var->var_default = def;
815 
816     var->var_cached = NULL;
817     var->var_cached_tsid = 0;
818     var->var_cached_tsver = 0;
819 
820     if (_PyObject_GC_MAY_BE_TRACKED(name) ||
821             (def != NULL && _PyObject_GC_MAY_BE_TRACKED(def)))
822     {
823         PyObject_GC_Track(var);
824     }
825     return var;
826 }
827 
828 
829 /*[clinic input]
830 class _contextvars.ContextVar "PyContextVar *" "&PyContextVar_Type"
831 [clinic start generated code]*/
832 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=445da935fa8883c3]*/
833 
834 
835 static PyObject *
contextvar_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)836 contextvar_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
837 {
838     static char *kwlist[] = {"", "default", NULL};
839     PyObject *name;
840     PyObject *def = NULL;
841 
842     if (!PyArg_ParseTupleAndKeywords(
843             args, kwds, "O|$O:ContextVar", kwlist, &name, &def))
844     {
845         return NULL;
846     }
847 
848     return (PyObject *)contextvar_new(name, def);
849 }
850 
851 static int
contextvar_tp_clear(PyContextVar * self)852 contextvar_tp_clear(PyContextVar *self)
853 {
854     Py_CLEAR(self->var_name);
855     Py_CLEAR(self->var_default);
856     self->var_cached = NULL;
857     self->var_cached_tsid = 0;
858     self->var_cached_tsver = 0;
859     return 0;
860 }
861 
862 static int
contextvar_tp_traverse(PyContextVar * self,visitproc visit,void * arg)863 contextvar_tp_traverse(PyContextVar *self, visitproc visit, void *arg)
864 {
865     Py_VISIT(self->var_name);
866     Py_VISIT(self->var_default);
867     return 0;
868 }
869 
870 static void
contextvar_tp_dealloc(PyContextVar * self)871 contextvar_tp_dealloc(PyContextVar *self)
872 {
873     PyObject_GC_UnTrack(self);
874     (void)contextvar_tp_clear(self);
875     Py_TYPE(self)->tp_free(self);
876 }
877 
878 static Py_hash_t
contextvar_tp_hash(PyContextVar * self)879 contextvar_tp_hash(PyContextVar *self)
880 {
881     return self->var_hash;
882 }
883 
884 static PyObject *
contextvar_tp_repr(PyContextVar * self)885 contextvar_tp_repr(PyContextVar *self)
886 {
887     _PyUnicodeWriter writer;
888 
889     _PyUnicodeWriter_Init(&writer);
890 
891     if (_PyUnicodeWriter_WriteASCIIString(
892             &writer, "<ContextVar name=", 17) < 0)
893     {
894         goto error;
895     }
896 
897     PyObject *name = PyObject_Repr(self->var_name);
898     if (name == NULL) {
899         goto error;
900     }
901     if (_PyUnicodeWriter_WriteStr(&writer, name) < 0) {
902         Py_DECREF(name);
903         goto error;
904     }
905     Py_DECREF(name);
906 
907     if (self->var_default != NULL) {
908         if (_PyUnicodeWriter_WriteASCIIString(&writer, " default=", 9) < 0) {
909             goto error;
910         }
911 
912         PyObject *def = PyObject_Repr(self->var_default);
913         if (def == NULL) {
914             goto error;
915         }
916         if (_PyUnicodeWriter_WriteStr(&writer, def) < 0) {
917             Py_DECREF(def);
918             goto error;
919         }
920         Py_DECREF(def);
921     }
922 
923     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
924     if (addr == NULL) {
925         goto error;
926     }
927     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
928         Py_DECREF(addr);
929         goto error;
930     }
931     Py_DECREF(addr);
932 
933     return _PyUnicodeWriter_Finish(&writer);
934 
935 error:
936     _PyUnicodeWriter_Dealloc(&writer);
937     return NULL;
938 }
939 
940 
941 /*[clinic input]
942 _contextvars.ContextVar.get
943     default: object = NULL
944     /
945 
946 Return a value for the context variable for the current context.
947 
948 If there is no value for the variable in the current context, the method will:
949  * return the value of the default argument of the method, if provided; or
950  * return the default value for the context variable, if it was created
951    with one; or
952  * raise a LookupError.
953 [clinic start generated code]*/
954 
955 static PyObject *
_contextvars_ContextVar_get_impl(PyContextVar * self,PyObject * default_value)956 _contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value)
957 /*[clinic end generated code: output=0746bd0aa2ced7bf input=30aa2ab9e433e401]*/
958 {
959     if (!PyContextVar_CheckExact(self)) {
960         PyErr_SetString(
961             PyExc_TypeError, "an instance of ContextVar was expected");
962         return NULL;
963     }
964 
965     PyObject *val;
966     if (PyContextVar_Get((PyObject *)self, default_value, &val) < 0) {
967         return NULL;
968     }
969 
970     if (val == NULL) {
971         PyErr_SetObject(PyExc_LookupError, (PyObject *)self);
972         return NULL;
973     }
974 
975     return val;
976 }
977 
978 /*[clinic input]
979 _contextvars.ContextVar.set
980     value: object
981     /
982 
983 Call to set a new value for the context variable in the current context.
984 
985 The required value argument is the new value for the context variable.
986 
987 Returns a Token object that can be used to restore the variable to its previous
988 value via the `ContextVar.reset()` method.
989 [clinic start generated code]*/
990 
991 static PyObject *
_contextvars_ContextVar_set(PyContextVar * self,PyObject * value)992 _contextvars_ContextVar_set(PyContextVar *self, PyObject *value)
993 /*[clinic end generated code: output=446ed5e820d6d60b input=c0a6887154227453]*/
994 {
995     return PyContextVar_Set((PyObject *)self, value);
996 }
997 
998 /*[clinic input]
999 _contextvars.ContextVar.reset
1000     token: object
1001     /
1002 
1003 Reset the context variable.
1004 
1005 The variable is reset to the value it had before the `ContextVar.set()` that
1006 created the token was used.
1007 [clinic start generated code]*/
1008 
1009 static PyObject *
_contextvars_ContextVar_reset(PyContextVar * self,PyObject * token)1010 _contextvars_ContextVar_reset(PyContextVar *self, PyObject *token)
1011 /*[clinic end generated code: output=d4ee34d0742d62ee input=ebe2881e5af4ffda]*/
1012 {
1013     if (!PyContextToken_CheckExact(token)) {
1014         PyErr_Format(PyExc_TypeError,
1015                      "expected an instance of Token, got %R", token);
1016         return NULL;
1017     }
1018 
1019     if (PyContextVar_Reset((PyObject *)self, token)) {
1020         return NULL;
1021     }
1022 
1023     Py_RETURN_NONE;
1024 }
1025 
1026 
1027 static PyMemberDef PyContextVar_members[] = {
1028     {"name", T_OBJECT, offsetof(PyContextVar, var_name), READONLY},
1029     {NULL}
1030 };
1031 
1032 static PyMethodDef PyContextVar_methods[] = {
1033     _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF
1034     _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF
1035     _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF
1036     {"__class_getitem__", (PyCFunction)Py_GenericAlias,
1037     METH_O|METH_CLASS,       PyDoc_STR("See PEP 585")},
1038     {NULL, NULL}
1039 };
1040 
1041 PyTypeObject PyContextVar_Type = {
1042     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1043     "ContextVar",
1044     sizeof(PyContextVar),
1045     .tp_methods = PyContextVar_methods,
1046     .tp_members = PyContextVar_members,
1047     .tp_dealloc = (destructor)contextvar_tp_dealloc,
1048     .tp_getattro = PyObject_GenericGetAttr,
1049     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1050     .tp_traverse = (traverseproc)contextvar_tp_traverse,
1051     .tp_clear = (inquiry)contextvar_tp_clear,
1052     .tp_new = contextvar_tp_new,
1053     .tp_free = PyObject_GC_Del,
1054     .tp_hash = (hashfunc)contextvar_tp_hash,
1055     .tp_repr = (reprfunc)contextvar_tp_repr,
1056 };
1057 
1058 
1059 /////////////////////////// Token
1060 
/* Forward declaration; the definition comes later in the file.  Presumably
   returns the Token.MISSING marker object — confirm against the definition. */
static PyObject * get_token_missing(void);
1062 
1063 
1064 /*[clinic input]
1065 class _contextvars.Token "PyContextToken *" "&PyContextToken_Type"
1066 [clinic start generated code]*/
1067 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=338a5e2db13d3f5b]*/
1068 
1069 
1070 static PyObject *
token_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)1071 token_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
1072 {
1073     PyErr_SetString(PyExc_RuntimeError,
1074                     "Tokens can only be created by ContextVars");
1075     return NULL;
1076 }
1077 
/* tp_clear for Token: drops the three strong references the token owns
   (context, variable, previous value).  Py_CLEAR sets each field to NULL
   before decref'ing it, so the token is left in a consistent, empty state.
   The clearing order fixes the order of the resulting decrefs, so keep it
   stable.  Always returns 0. */
static int
token_tp_clear(PyContextToken *self)
{
    Py_CLEAR(self->tok_ctx);
    Py_CLEAR(self->tok_var);
    Py_CLEAR(self->tok_oldval);
    return 0;
}
1086 
1087 static int
token_tp_traverse(PyContextToken * self,visitproc visit,void * arg)1088 token_tp_traverse(PyContextToken *self, visitproc visit, void *arg)
1089 {
1090     Py_VISIT(self->tok_ctx);
1091     Py_VISIT(self->tok_var);
1092     Py_VISIT(self->tok_oldval);
1093     return 0;
1094 }
1095 
1096 static void
token_tp_dealloc(PyContextToken * self)1097 token_tp_dealloc(PyContextToken *self)
1098 {
1099     PyObject_GC_UnTrack(self);
1100     (void)token_tp_clear(self);
1101     Py_TYPE(self)->tp_free(self);
1102 }
1103 
1104 static PyObject *
token_tp_repr(PyContextToken * self)1105 token_tp_repr(PyContextToken *self)
1106 {
1107     _PyUnicodeWriter writer;
1108 
1109     _PyUnicodeWriter_Init(&writer);
1110 
1111     if (_PyUnicodeWriter_WriteASCIIString(&writer, "<Token", 6) < 0) {
1112         goto error;
1113     }
1114 
1115     if (self->tok_used) {
1116         if (_PyUnicodeWriter_WriteASCIIString(&writer, " used", 5) < 0) {
1117             goto error;
1118         }
1119     }
1120 
1121     if (_PyUnicodeWriter_WriteASCIIString(&writer, " var=", 5) < 0) {
1122         goto error;
1123     }
1124 
1125     PyObject *var = PyObject_Repr((PyObject *)self->tok_var);
1126     if (var == NULL) {
1127         goto error;
1128     }
1129     if (_PyUnicodeWriter_WriteStr(&writer, var) < 0) {
1130         Py_DECREF(var);
1131         goto error;
1132     }
1133     Py_DECREF(var);
1134 
1135     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
1136     if (addr == NULL) {
1137         goto error;
1138     }
1139     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
1140         Py_DECREF(addr);
1141         goto error;
1142     }
1143     Py_DECREF(addr);
1144 
1145     return _PyUnicodeWriter_Finish(&writer);
1146 
1147 error:
1148     _PyUnicodeWriter_Dealloc(&writer);
1149     return NULL;
1150 }
1151 
1152 static PyObject *
token_get_var(PyContextToken * self,void * Py_UNUSED (ignored))1153 token_get_var(PyContextToken *self, void *Py_UNUSED(ignored))
1154 {
1155     Py_INCREF(self->tok_var);
1156     return (PyObject *)self->tok_var;
1157 }
1158 
1159 static PyObject *
token_get_old_value(PyContextToken * self,void * Py_UNUSED (ignored))1160 token_get_old_value(PyContextToken *self, void *Py_UNUSED(ignored))
1161 {
1162     if (self->tok_oldval == NULL) {
1163         return get_token_missing();
1164     }
1165 
1166     Py_INCREF(self->tok_oldval);
1167     return self->tok_oldval;
1168 }
1169 
1170 static PyGetSetDef PyContextTokenType_getsetlist[] = {
1171     {"var", (getter)token_get_var, NULL, NULL},
1172     {"old_value", (getter)token_get_old_value, NULL, NULL},
1173     {NULL}
1174 };
1175 
1176 static PyMethodDef PyContextTokenType_methods[] = {
1177     {"__class_getitem__",    (PyCFunction)Py_GenericAlias,
1178     METH_O|METH_CLASS,       PyDoc_STR("See PEP 585")},
1179     {NULL}
1180 };
1181 
1182 PyTypeObject PyContextToken_Type = {
1183     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1184     "Token",
1185     sizeof(PyContextToken),
1186     .tp_methods = PyContextTokenType_methods,
1187     .tp_getset = PyContextTokenType_getsetlist,
1188     .tp_dealloc = (destructor)token_tp_dealloc,
1189     .tp_getattro = PyObject_GenericGetAttr,
1190     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1191     .tp_traverse = (traverseproc)token_tp_traverse,
1192     .tp_clear = (inquiry)token_tp_clear,
1193     .tp_new = token_tp_new,
1194     .tp_free = PyObject_GC_Del,
1195     .tp_hash = PyObject_HashNotImplemented,
1196     .tp_repr = (reprfunc)token_tp_repr,
1197 };
1198 
1199 static PyContextToken *
token_new(PyContext * ctx,PyContextVar * var,PyObject * val)1200 token_new(PyContext *ctx, PyContextVar *var, PyObject *val)
1201 {
1202     PyContextToken *tok = PyObject_GC_New(PyContextToken, &PyContextToken_Type);
1203     if (tok == NULL) {
1204         return NULL;
1205     }
1206 
1207     Py_INCREF(ctx);
1208     tok->tok_ctx = ctx;
1209 
1210     Py_INCREF(var);
1211     tok->tok_var = var;
1212 
1213     Py_XINCREF(val);
1214     tok->tok_oldval = val;
1215 
1216     tok->tok_used = 0;
1217 
1218     PyObject_GC_Track(tok);
1219     return tok;
1220 }
1221 
1222 
1223 /////////////////////////// Token.MISSING
1224 
1225 
/* Cached Token.MISSING singleton: created lazily by get_token_missing()
   and released by Py_CLEAR in _PyContext_Fini(). */
static PyObject *_token_missing;
1227 
1228 
/* Instance layout of the Token.MISSING sentinel: just an object header,
   with no additional fields. */
typedef struct {
    PyObject_HEAD
} PyContextTokenMissing;
1232 
1233 
1234 static PyObject *
context_token_missing_tp_repr(PyObject * self)1235 context_token_missing_tp_repr(PyObject *self)
1236 {
1237     return PyUnicode_FromString("<Token.MISSING>");
1238 }
1239 
1240 
1241 PyTypeObject PyContextTokenMissing_Type = {
1242     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1243     "Token.MISSING",
1244     sizeof(PyContextTokenMissing),
1245     .tp_getattro = PyObject_GenericGetAttr,
1246     .tp_flags = Py_TPFLAGS_DEFAULT,
1247     .tp_repr = context_token_missing_tp_repr,
1248 };
1249 
1250 
1251 static PyObject *
get_token_missing(void)1252 get_token_missing(void)
1253 {
1254     if (_token_missing != NULL) {
1255         Py_INCREF(_token_missing);
1256         return _token_missing;
1257     }
1258 
1259     _token_missing = (PyObject *)PyObject_New(
1260         PyContextTokenMissing, &PyContextTokenMissing_Type);
1261     if (_token_missing == NULL) {
1262         return NULL;
1263     }
1264 
1265     Py_INCREF(_token_missing);
1266     return _token_missing;
1267 }
1268 
1269 
1270 ///////////////////////////
1271 
1272 
1273 void
_PyContext_ClearFreeList(void)1274 _PyContext_ClearFreeList(void)
1275 {
1276     for (; ctx_freelist_len; ctx_freelist_len--) {
1277         PyContext *ctx = ctx_freelist;
1278         ctx_freelist = (PyContext *)ctx->ctx_weakreflist;
1279         ctx->ctx_weakreflist = NULL;
1280         PyObject_GC_Del(ctx);
1281     }
1282 }
1283 
1284 
/* Tear down this module's static state: drop the cached Token.MISSING
   singleton, free the Context free list, and then finalize HAMT state via
   _PyHamt_Fini().  HAMT teardown runs last, after the context objects have
   been released. */
void
_PyContext_Fini(void)
{
    Py_CLEAR(_token_missing);
    _PyContext_ClearFreeList();
    _PyHamt_Fini();
}
1292 
1293 
1294 int
_PyContext_Init(void)1295 _PyContext_Init(void)
1296 {
1297     if (!_PyHamt_Init()) {
1298         return 0;
1299     }
1300 
1301     if ((PyType_Ready(&PyContext_Type) < 0) ||
1302         (PyType_Ready(&PyContextVar_Type) < 0) ||
1303         (PyType_Ready(&PyContextToken_Type) < 0) ||
1304         (PyType_Ready(&PyContextTokenMissing_Type) < 0))
1305     {
1306         return 0;
1307     }
1308 
1309     PyObject *missing = get_token_missing();
1310     if (PyDict_SetItemString(
1311         PyContextToken_Type.tp_dict, "MISSING", missing))
1312     {
1313         Py_DECREF(missing);
1314         return 0;
1315     }
1316     Py_DECREF(missing);
1317 
1318     return 1;
1319 }
1320