1 #include "Python.h"
2 
3 #include "structmember.h"
4 #include "internal/pystate.h"
5 #include "internal/context.h"
6 #include "internal/hamt.h"
7 
8 
/* Free list of deallocated Context objects, reused by _context_alloc().
   Entries are chained through their ctx_weakreflist field (see
   context_tp_dealloc); at most CONTEXT_FREELIST_MAXLEN entries are kept,
   beyond that dealloc releases the memory via tp_free. */
#define CONTEXT_FREELIST_MAXLEN 255
static PyContext *ctx_freelist = NULL;
static int ctx_freelist_len = 0;
12 
13 
14 #include "clinic/context.c.h"
15 /*[clinic input]
16 module _contextvars
17 [clinic start generated code]*/
18 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/
19 
20 
/* Exact-type argument checks for the public C API functions below.

   If `o` is not exactly the expected type, each macro sets a TypeError and
   makes the *enclosing function* return `err_ret`.  Note that the expansion
   is a bare `if` statement (not wrapped in do { } while (0)); call sites in
   this file use the macros without a trailing semicolon, so changing the
   expansion shape would require updating every caller. */
#define ENSURE_Context(o, err_ret)                                  \
    if (!PyContext_CheckExact(o)) {                                 \
        PyErr_SetString(PyExc_TypeError,                            \
                        "an instance of Context was expected");     \
        return err_ret;                                             \
    }

#define ENSURE_ContextVar(o, err_ret)                               \
    if (!PyContextVar_CheckExact(o)) {                              \
        PyErr_SetString(PyExc_TypeError,                            \
                       "an instance of ContextVar was expected");   \
        return err_ret;                                             \
    }

#define ENSURE_ContextToken(o, err_ret)                             \
    if (!PyContextToken_CheckExact(o)) {                            \
        PyErr_SetString(PyExc_TypeError,                            \
                        "an instance of Token was expected");       \
        return err_ret;                                             \
    }
41 
42 
43 /////////////////////////// Context API
44 
45 
/* Forward declarations of static helpers that are defined further down in
   this file but are already needed by the public API functions below. */
static PyContext *
context_new_empty(void);

static PyContext *
context_new_from_vars(PyHamtObject *vars);

static inline PyContext *
context_get(void);

static PyContextToken *
token_new(PyContext *ctx, PyContextVar *var, PyObject *val);

static PyContextVar *
contextvar_new(PyObject *name, PyObject *def);

static int
contextvar_set(PyContextVar *var, PyObject *val);

static int
contextvar_del(PyContextVar *var);
66 
67 
68 PyObject *
_PyContext_NewHamtForTests(void)69 _PyContext_NewHamtForTests(void)
70 {
71     return (PyObject *)_PyHamt_New();
72 }
73 
74 
75 PyObject *
PyContext_New(void)76 PyContext_New(void)
77 {
78     return (PyObject *)context_new_empty();
79 }
80 
81 
82 PyObject *
PyContext_Copy(PyObject * octx)83 PyContext_Copy(PyObject * octx)
84 {
85     ENSURE_Context(octx, NULL)
86     PyContext *ctx = (PyContext *)octx;
87     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
88 }
89 
90 
91 PyObject *
PyContext_CopyCurrent(void)92 PyContext_CopyCurrent(void)
93 {
94     PyContext *ctx = context_get();
95     if (ctx == NULL) {
96         return NULL;
97     }
98 
99     return (PyObject *)context_new_from_vars(ctx->ctx_vars);
100 }
101 
102 
103 int
PyContext_Enter(PyObject * octx)104 PyContext_Enter(PyObject *octx)
105 {
106     ENSURE_Context(octx, -1)
107     PyContext *ctx = (PyContext *)octx;
108 
109     if (ctx->ctx_entered) {
110         PyErr_Format(PyExc_RuntimeError,
111                      "cannot enter context: %R is already entered", ctx);
112         return -1;
113     }
114 
115     PyThreadState *ts = PyThreadState_GET();
116     assert(ts != NULL);
117 
118     ctx->ctx_prev = (PyContext *)ts->context;  /* borrow */
119     ctx->ctx_entered = 1;
120 
121     Py_INCREF(ctx);
122     ts->context = (PyObject *)ctx;
123     ts->context_ver++;
124 
125     return 0;
126 }
127 
128 
129 int
PyContext_Exit(PyObject * octx)130 PyContext_Exit(PyObject *octx)
131 {
132     ENSURE_Context(octx, -1)
133     PyContext *ctx = (PyContext *)octx;
134 
135     if (!ctx->ctx_entered) {
136         PyErr_Format(PyExc_RuntimeError,
137                      "cannot exit context: %R has not been entered", ctx);
138         return -1;
139     }
140 
141     PyThreadState *ts = PyThreadState_GET();
142     assert(ts != NULL);
143 
144     if (ts->context != (PyObject *)ctx) {
145         /* Can only happen if someone misuses the C API */
146         PyErr_SetString(PyExc_RuntimeError,
147                         "cannot exit context: thread state references "
148                         "a different context object");
149         return -1;
150     }
151 
152     Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev);
153     ts->context_ver++;
154 
155     ctx->ctx_prev = NULL;
156     ctx->ctx_entered = 0;
157 
158     return 0;
159 }
160 
161 
162 PyObject *
PyContextVar_New(const char * name,PyObject * def)163 PyContextVar_New(const char *name, PyObject *def)
164 {
165     PyObject *pyname = PyUnicode_FromString(name);
166     if (pyname == NULL) {
167         return NULL;
168     }
169     PyContextVar *var = contextvar_new(pyname, def);
170     Py_DECREF(pyname);
171     return (PyObject *)var;
172 }
173 
174 
175 int
PyContextVar_Get(PyObject * ovar,PyObject * def,PyObject ** val)176 PyContextVar_Get(PyObject *ovar, PyObject *def, PyObject **val)
177 {
178     ENSURE_ContextVar(ovar, -1)
179     PyContextVar *var = (PyContextVar *)ovar;
180 
181     PyThreadState *ts = PyThreadState_GET();
182     assert(ts != NULL);
183     if (ts->context == NULL) {
184         goto not_found;
185     }
186 
187     if (var->var_cached != NULL &&
188             var->var_cached_tsid == ts->id &&
189             var->var_cached_tsver == ts->context_ver)
190     {
191         *val = var->var_cached;
192         goto found;
193     }
194 
195     assert(PyContext_CheckExact(ts->context));
196     PyHamtObject *vars = ((PyContext *)ts->context)->ctx_vars;
197 
198     PyObject *found = NULL;
199     int res = _PyHamt_Find(vars, (PyObject*)var, &found);
200     if (res < 0) {
201         goto error;
202     }
203     if (res == 1) {
204         assert(found != NULL);
205         var->var_cached = found;  /* borrow */
206         var->var_cached_tsid = ts->id;
207         var->var_cached_tsver = ts->context_ver;
208 
209         *val = found;
210         goto found;
211     }
212 
213 not_found:
214     if (def == NULL) {
215         if (var->var_default != NULL) {
216             *val = var->var_default;
217             goto found;
218         }
219 
220         *val = NULL;
221         goto found;
222     }
223     else {
224         *val = def;
225         goto found;
226    }
227 
228 found:
229     Py_XINCREF(*val);
230     return 0;
231 
232 error:
233     *val = NULL;
234     return -1;
235 }
236 
237 
238 PyObject *
PyContextVar_Set(PyObject * ovar,PyObject * val)239 PyContextVar_Set(PyObject *ovar, PyObject *val)
240 {
241     ENSURE_ContextVar(ovar, NULL)
242     PyContextVar *var = (PyContextVar *)ovar;
243 
244     if (!PyContextVar_CheckExact(var)) {
245         PyErr_SetString(
246             PyExc_TypeError, "an instance of ContextVar was expected");
247         return NULL;
248     }
249 
250     PyContext *ctx = context_get();
251     if (ctx == NULL) {
252         return NULL;
253     }
254 
255     PyObject *old_val = NULL;
256     int found = _PyHamt_Find(ctx->ctx_vars, (PyObject *)var, &old_val);
257     if (found < 0) {
258         return NULL;
259     }
260 
261     Py_XINCREF(old_val);
262     PyContextToken *tok = token_new(ctx, var, old_val);
263     Py_XDECREF(old_val);
264 
265     if (contextvar_set(var, val)) {
266         Py_DECREF(tok);
267         return NULL;
268     }
269 
270     return (PyObject *)tok;
271 }
272 
273 
274 int
PyContextVar_Reset(PyObject * ovar,PyObject * otok)275 PyContextVar_Reset(PyObject *ovar, PyObject *otok)
276 {
277     ENSURE_ContextVar(ovar, -1)
278     ENSURE_ContextToken(otok, -1)
279     PyContextVar *var = (PyContextVar *)ovar;
280     PyContextToken *tok = (PyContextToken *)otok;
281 
282     if (tok->tok_used) {
283         PyErr_Format(PyExc_RuntimeError,
284                      "%R has already been used once", tok);
285         return -1;
286     }
287 
288     if (var != tok->tok_var) {
289         PyErr_Format(PyExc_ValueError,
290                      "%R was created by a different ContextVar", tok);
291         return -1;
292     }
293 
294     PyContext *ctx = context_get();
295     if (ctx != tok->tok_ctx) {
296         PyErr_Format(PyExc_ValueError,
297                      "%R was created in a different Context", tok);
298         return -1;
299     }
300 
301     tok->tok_used = 1;
302 
303     if (tok->tok_oldval == NULL) {
304         return contextvar_del(var);
305     }
306     else {
307         return contextvar_set(var, tok->tok_oldval);
308     }
309 }
310 
311 
312 /////////////////////////// PyContext
313 
314 /*[clinic input]
315 class _contextvars.Context "PyContext *" "&PyContext_Type"
316 [clinic start generated code]*/
317 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=bdf87f8e0cb580e8]*/
318 
319 
320 static inline PyContext *
_context_alloc(void)321 _context_alloc(void)
322 {
323     PyContext *ctx;
324     if (ctx_freelist_len) {
325         ctx_freelist_len--;
326         ctx = ctx_freelist;
327         ctx_freelist = (PyContext *)ctx->ctx_weakreflist;
328         ctx->ctx_weakreflist = NULL;
329         _Py_NewReference((PyObject *)ctx);
330     }
331     else {
332         ctx = PyObject_GC_New(PyContext, &PyContext_Type);
333         if (ctx == NULL) {
334             return NULL;
335         }
336     }
337 
338     ctx->ctx_vars = NULL;
339     ctx->ctx_prev = NULL;
340     ctx->ctx_entered = 0;
341     ctx->ctx_weakreflist = NULL;
342 
343     return ctx;
344 }
345 
346 
347 static PyContext *
context_new_empty(void)348 context_new_empty(void)
349 {
350     PyContext *ctx = _context_alloc();
351     if (ctx == NULL) {
352         return NULL;
353     }
354 
355     ctx->ctx_vars = _PyHamt_New();
356     if (ctx->ctx_vars == NULL) {
357         Py_DECREF(ctx);
358         return NULL;
359     }
360 
361     _PyObject_GC_TRACK(ctx);
362     return ctx;
363 }
364 
365 
366 static PyContext *
context_new_from_vars(PyHamtObject * vars)367 context_new_from_vars(PyHamtObject *vars)
368 {
369     PyContext *ctx = _context_alloc();
370     if (ctx == NULL) {
371         return NULL;
372     }
373 
374     Py_INCREF(vars);
375     ctx->ctx_vars = vars;
376 
377     _PyObject_GC_TRACK(ctx);
378     return ctx;
379 }
380 
381 
382 static inline PyContext *
context_get(void)383 context_get(void)
384 {
385     PyThreadState *ts = PyThreadState_GET();
386     assert(ts != NULL);
387     PyContext *current_ctx = (PyContext *)ts->context;
388     if (current_ctx == NULL) {
389         current_ctx = context_new_empty();
390         if (current_ctx == NULL) {
391             return NULL;
392         }
393         ts->context = (PyObject *)current_ctx;
394     }
395     return current_ctx;
396 }
397 
398 static int
context_check_key_type(PyObject * key)399 context_check_key_type(PyObject *key)
400 {
401     if (!PyContextVar_CheckExact(key)) {
402         // abort();
403         PyErr_Format(PyExc_TypeError,
404                      "a ContextVar key was expected, got %R", key);
405         return -1;
406     }
407     return 0;
408 }
409 
410 static PyObject *
context_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)411 context_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
412 {
413     if (PyTuple_Size(args) || (kwds != NULL && PyDict_Size(kwds))) {
414         PyErr_SetString(
415             PyExc_TypeError, "Context() does not accept any arguments");
416         return NULL;
417     }
418     return PyContext_New();
419 }
420 
421 static int
context_tp_clear(PyContext * self)422 context_tp_clear(PyContext *self)
423 {
424     Py_CLEAR(self->ctx_prev);
425     Py_CLEAR(self->ctx_vars);
426     return 0;
427 }
428 
429 static int
context_tp_traverse(PyContext * self,visitproc visit,void * arg)430 context_tp_traverse(PyContext *self, visitproc visit, void *arg)
431 {
432     Py_VISIT(self->ctx_prev);
433     Py_VISIT(self->ctx_vars);
434     return 0;
435 }
436 
437 static void
context_tp_dealloc(PyContext * self)438 context_tp_dealloc(PyContext *self)
439 {
440     _PyObject_GC_UNTRACK(self);
441 
442     if (self->ctx_weakreflist != NULL) {
443         PyObject_ClearWeakRefs((PyObject*)self);
444     }
445     (void)context_tp_clear(self);
446 
447     if (ctx_freelist_len < CONTEXT_FREELIST_MAXLEN) {
448         ctx_freelist_len++;
449         self->ctx_weakreflist = (PyObject *)ctx_freelist;
450         ctx_freelist = self;
451     }
452     else {
453         Py_TYPE(self)->tp_free(self);
454     }
455 }
456 
457 static PyObject *
context_tp_iter(PyContext * self)458 context_tp_iter(PyContext *self)
459 {
460     return _PyHamt_NewIterKeys(self->ctx_vars);
461 }
462 
463 static PyObject *
context_tp_richcompare(PyObject * v,PyObject * w,int op)464 context_tp_richcompare(PyObject *v, PyObject *w, int op)
465 {
466     if (!PyContext_CheckExact(v) || !PyContext_CheckExact(w) ||
467             (op != Py_EQ && op != Py_NE))
468     {
469         Py_RETURN_NOTIMPLEMENTED;
470     }
471 
472     int res = _PyHamt_Eq(
473         ((PyContext *)v)->ctx_vars, ((PyContext *)w)->ctx_vars);
474     if (res < 0) {
475         return NULL;
476     }
477 
478     if (op == Py_NE) {
479         res = !res;
480     }
481 
482     if (res) {
483         Py_RETURN_TRUE;
484     }
485     else {
486         Py_RETURN_FALSE;
487     }
488 }
489 
490 static Py_ssize_t
context_tp_len(PyContext * self)491 context_tp_len(PyContext *self)
492 {
493     return _PyHamt_Len(self->ctx_vars);
494 }
495 
496 static PyObject *
context_tp_subscript(PyContext * self,PyObject * key)497 context_tp_subscript(PyContext *self, PyObject *key)
498 {
499     if (context_check_key_type(key)) {
500         return NULL;
501     }
502     PyObject *val = NULL;
503     int found = _PyHamt_Find(self->ctx_vars, key, &val);
504     if (found < 0) {
505         return NULL;
506     }
507     if (found == 0) {
508         PyErr_SetObject(PyExc_KeyError, key);
509         return NULL;
510     }
511     Py_INCREF(val);
512     return val;
513 }
514 
515 static int
context_tp_contains(PyContext * self,PyObject * key)516 context_tp_contains(PyContext *self, PyObject *key)
517 {
518     if (context_check_key_type(key)) {
519         return -1;
520     }
521     PyObject *val = NULL;
522     return _PyHamt_Find(self->ctx_vars, key, &val);
523 }
524 
525 
526 /*[clinic input]
527 _contextvars.Context.get
528     key: object
529     default: object = None
530     /
531 [clinic start generated code]*/
532 
533 static PyObject *
_contextvars_Context_get_impl(PyContext * self,PyObject * key,PyObject * default_value)534 _contextvars_Context_get_impl(PyContext *self, PyObject *key,
535                               PyObject *default_value)
536 /*[clinic end generated code: output=0c54aa7664268189 input=8d4c33c8ecd6d769]*/
537 {
538     if (context_check_key_type(key)) {
539         return NULL;
540     }
541 
542     PyObject *val = NULL;
543     int found = _PyHamt_Find(self->ctx_vars, key, &val);
544     if (found < 0) {
545         return NULL;
546     }
547     if (found == 0) {
548         Py_INCREF(default_value);
549         return default_value;
550     }
551     Py_INCREF(val);
552     return val;
553 }
554 
555 
556 /*[clinic input]
557 _contextvars.Context.items
558 [clinic start generated code]*/
559 
560 static PyObject *
_contextvars_Context_items_impl(PyContext * self)561 _contextvars_Context_items_impl(PyContext *self)
562 /*[clinic end generated code: output=fa1655c8a08502af input=2d570d1455004979]*/
563 {
564     return _PyHamt_NewIterItems(self->ctx_vars);
565 }
566 
567 
568 /*[clinic input]
569 _contextvars.Context.keys
570 [clinic start generated code]*/
571 
572 static PyObject *
_contextvars_Context_keys_impl(PyContext * self)573 _contextvars_Context_keys_impl(PyContext *self)
574 /*[clinic end generated code: output=177227c6b63ec0e2 input=13005e142fbbf37d]*/
575 {
576     return _PyHamt_NewIterKeys(self->ctx_vars);
577 }
578 
579 
580 /*[clinic input]
581 _contextvars.Context.values
582 [clinic start generated code]*/
583 
584 static PyObject *
_contextvars_Context_values_impl(PyContext * self)585 _contextvars_Context_values_impl(PyContext *self)
586 /*[clinic end generated code: output=d286dabfc8db6dde input=c2cbc40a4470e905]*/
587 {
588     return _PyHamt_NewIterValues(self->ctx_vars);
589 }
590 
591 
592 /*[clinic input]
593 _contextvars.Context.copy
594 [clinic start generated code]*/
595 
596 static PyObject *
_contextvars_Context_copy_impl(PyContext * self)597 _contextvars_Context_copy_impl(PyContext *self)
598 /*[clinic end generated code: output=30ba8896c4707a15 input=3e3fd72d598653ab]*/
599 {
600     return (PyObject *)context_new_from_vars(self->ctx_vars);
601 }
602 
603 
604 static PyObject *
context_run(PyContext * self,PyObject * const * args,Py_ssize_t nargs,PyObject * kwnames)605 context_run(PyContext *self, PyObject *const *args,
606             Py_ssize_t nargs, PyObject *kwnames)
607 {
608     if (nargs < 1) {
609         PyErr_SetString(PyExc_TypeError,
610                         "run() missing 1 required positional argument");
611         return NULL;
612     }
613 
614     if (PyContext_Enter((PyObject *)self)) {
615         return NULL;
616     }
617 
618     PyObject *call_result = _PyObject_FastCallKeywords(
619         args[0], args + 1, nargs - 1, kwnames);
620 
621     if (PyContext_Exit((PyObject *)self)) {
622         return NULL;
623     }
624 
625     return call_result;
626 }
627 
628 
629 static PyMethodDef PyContext_methods[] = {
630     _CONTEXTVARS_CONTEXT_GET_METHODDEF
631     _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF
632     _CONTEXTVARS_CONTEXT_KEYS_METHODDEF
633     _CONTEXTVARS_CONTEXT_VALUES_METHODDEF
634     _CONTEXTVARS_CONTEXT_COPY_METHODDEF
635     {"run", (PyCFunction)context_run, METH_FASTCALL | METH_KEYWORDS, NULL},
636     {NULL, NULL}
637 };
638 
639 static PySequenceMethods PyContext_as_sequence = {
640     0,                                   /* sq_length */
641     0,                                   /* sq_concat */
642     0,                                   /* sq_repeat */
643     0,                                   /* sq_item */
644     0,                                   /* sq_slice */
645     0,                                   /* sq_ass_item */
646     0,                                   /* sq_ass_slice */
647     (objobjproc)context_tp_contains,     /* sq_contains */
648     0,                                   /* sq_inplace_concat */
649     0,                                   /* sq_inplace_repeat */
650 };
651 
652 static PyMappingMethods PyContext_as_mapping = {
653     (lenfunc)context_tp_len,             /* mp_length */
654     (binaryfunc)context_tp_subscript,    /* mp_subscript */
655 };
656 
657 PyTypeObject PyContext_Type = {
658     PyVarObject_HEAD_INIT(&PyType_Type, 0)
659     "Context",
660     sizeof(PyContext),
661     .tp_methods = PyContext_methods,
662     .tp_as_mapping = &PyContext_as_mapping,
663     .tp_as_sequence = &PyContext_as_sequence,
664     .tp_iter = (getiterfunc)context_tp_iter,
665     .tp_dealloc = (destructor)context_tp_dealloc,
666     .tp_getattro = PyObject_GenericGetAttr,
667     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
668     .tp_richcompare = context_tp_richcompare,
669     .tp_traverse = (traverseproc)context_tp_traverse,
670     .tp_clear = (inquiry)context_tp_clear,
671     .tp_new = context_tp_new,
672     .tp_weaklistoffset = offsetof(PyContext, ctx_weakreflist),
673     .tp_hash = PyObject_HashNotImplemented,
674 };
675 
676 
677 /////////////////////////// ContextVar
678 
679 
680 static int
contextvar_set(PyContextVar * var,PyObject * val)681 contextvar_set(PyContextVar *var, PyObject *val)
682 {
683     var->var_cached = NULL;
684     PyThreadState *ts = PyThreadState_Get();
685 
686     PyContext *ctx = context_get();
687     if (ctx == NULL) {
688         return -1;
689     }
690 
691     PyHamtObject *new_vars = _PyHamt_Assoc(
692         ctx->ctx_vars, (PyObject *)var, val);
693     if (new_vars == NULL) {
694         return -1;
695     }
696 
697     Py_SETREF(ctx->ctx_vars, new_vars);
698 
699     var->var_cached = val;  /* borrow */
700     var->var_cached_tsid = ts->id;
701     var->var_cached_tsver = ts->context_ver;
702     return 0;
703 }
704 
705 static int
contextvar_del(PyContextVar * var)706 contextvar_del(PyContextVar *var)
707 {
708     var->var_cached = NULL;
709 
710     PyContext *ctx = context_get();
711     if (ctx == NULL) {
712         return -1;
713     }
714 
715     PyHamtObject *vars = ctx->ctx_vars;
716     PyHamtObject *new_vars = _PyHamt_Without(vars, (PyObject *)var);
717     if (new_vars == NULL) {
718         return -1;
719     }
720 
721     if (vars == new_vars) {
722         Py_DECREF(new_vars);
723         PyErr_SetObject(PyExc_LookupError, (PyObject *)var);
724         return -1;
725     }
726 
727     Py_SETREF(ctx->ctx_vars, new_vars);
728     return 0;
729 }
730 
731 static Py_hash_t
contextvar_generate_hash(void * addr,PyObject * name)732 contextvar_generate_hash(void *addr, PyObject *name)
733 {
734     /* Take hash of `name` and XOR it with the object's addr.
735 
736        The structure of the tree is encoded in objects' hashes, which
737        means that sufficiently similar hashes would result in tall trees
738        with many Collision nodes.  Which would, in turn, result in slower
739        get and set operations.
740 
741        The XORing helps to ensure that:
742 
743        (1) sequentially allocated ContextVar objects have
744            different hashes;
745 
746        (2) context variables with equal names have
747            different hashes.
748     */
749 
750     Py_hash_t name_hash = PyObject_Hash(name);
751     if (name_hash == -1) {
752         return -1;
753     }
754 
755     Py_hash_t res = _Py_HashPointer(addr) ^ name_hash;
756     return res == -1 ? -2 : res;
757 }
758 
759 static PyContextVar *
contextvar_new(PyObject * name,PyObject * def)760 contextvar_new(PyObject *name, PyObject *def)
761 {
762     if (!PyUnicode_Check(name)) {
763         PyErr_SetString(PyExc_TypeError,
764                         "context variable name must be a str");
765         return NULL;
766     }
767 
768     PyContextVar *var = PyObject_GC_New(PyContextVar, &PyContextVar_Type);
769     if (var == NULL) {
770         return NULL;
771     }
772 
773     var->var_hash = contextvar_generate_hash(var, name);
774     if (var->var_hash == -1) {
775         Py_DECREF(var);
776         return NULL;
777     }
778 
779     Py_INCREF(name);
780     var->var_name = name;
781 
782     Py_XINCREF(def);
783     var->var_default = def;
784 
785     var->var_cached = NULL;
786     var->var_cached_tsid = 0;
787     var->var_cached_tsver = 0;
788 
789     if (_PyObject_GC_MAY_BE_TRACKED(name) ||
790             (def != NULL && _PyObject_GC_MAY_BE_TRACKED(def)))
791     {
792         PyObject_GC_Track(var);
793     }
794     return var;
795 }
796 
797 
798 /*[clinic input]
799 class _contextvars.ContextVar "PyContextVar *" "&PyContextVar_Type"
800 [clinic start generated code]*/
801 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=445da935fa8883c3]*/
802 
803 
804 static PyObject *
contextvar_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)805 contextvar_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
806 {
807     static char *kwlist[] = {"", "default", NULL};
808     PyObject *name;
809     PyObject *def = NULL;
810 
811     if (!PyArg_ParseTupleAndKeywords(
812             args, kwds, "O|$O:ContextVar", kwlist, &name, &def))
813     {
814         return NULL;
815     }
816 
817     return (PyObject *)contextvar_new(name, def);
818 }
819 
820 static int
contextvar_tp_clear(PyContextVar * self)821 contextvar_tp_clear(PyContextVar *self)
822 {
823     Py_CLEAR(self->var_name);
824     Py_CLEAR(self->var_default);
825     self->var_cached = NULL;
826     self->var_cached_tsid = 0;
827     self->var_cached_tsver = 0;
828     return 0;
829 }
830 
831 static int
contextvar_tp_traverse(PyContextVar * self,visitproc visit,void * arg)832 contextvar_tp_traverse(PyContextVar *self, visitproc visit, void *arg)
833 {
834     Py_VISIT(self->var_name);
835     Py_VISIT(self->var_default);
836     return 0;
837 }
838 
839 static void
contextvar_tp_dealloc(PyContextVar * self)840 contextvar_tp_dealloc(PyContextVar *self)
841 {
842     PyObject_GC_UnTrack(self);
843     (void)contextvar_tp_clear(self);
844     Py_TYPE(self)->tp_free(self);
845 }
846 
847 static Py_hash_t
contextvar_tp_hash(PyContextVar * self)848 contextvar_tp_hash(PyContextVar *self)
849 {
850     return self->var_hash;
851 }
852 
853 static PyObject *
contextvar_tp_repr(PyContextVar * self)854 contextvar_tp_repr(PyContextVar *self)
855 {
856     _PyUnicodeWriter writer;
857 
858     _PyUnicodeWriter_Init(&writer);
859 
860     if (_PyUnicodeWriter_WriteASCIIString(
861             &writer, "<ContextVar name=", 17) < 0)
862     {
863         goto error;
864     }
865 
866     PyObject *name = PyObject_Repr(self->var_name);
867     if (name == NULL) {
868         goto error;
869     }
870     if (_PyUnicodeWriter_WriteStr(&writer, name) < 0) {
871         Py_DECREF(name);
872         goto error;
873     }
874     Py_DECREF(name);
875 
876     if (self->var_default != NULL) {
877         if (_PyUnicodeWriter_WriteASCIIString(&writer, " default=", 9) < 0) {
878             goto error;
879         }
880 
881         PyObject *def = PyObject_Repr(self->var_default);
882         if (def == NULL) {
883             goto error;
884         }
885         if (_PyUnicodeWriter_WriteStr(&writer, def) < 0) {
886             Py_DECREF(def);
887             goto error;
888         }
889         Py_DECREF(def);
890     }
891 
892     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
893     if (addr == NULL) {
894         goto error;
895     }
896     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
897         Py_DECREF(addr);
898         goto error;
899     }
900     Py_DECREF(addr);
901 
902     return _PyUnicodeWriter_Finish(&writer);
903 
904 error:
905     _PyUnicodeWriter_Dealloc(&writer);
906     return NULL;
907 }
908 
909 
910 /*[clinic input]
911 _contextvars.ContextVar.get
912     default: object = NULL
913     /
914 [clinic start generated code]*/
915 
916 static PyObject *
_contextvars_ContextVar_get_impl(PyContextVar * self,PyObject * default_value)917 _contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value)
918 /*[clinic end generated code: output=0746bd0aa2ced7bf input=8d002b02eebbb247]*/
919 {
920     if (!PyContextVar_CheckExact(self)) {
921         PyErr_SetString(
922             PyExc_TypeError, "an instance of ContextVar was expected");
923         return NULL;
924     }
925 
926     PyObject *val;
927     if (PyContextVar_Get((PyObject *)self, default_value, &val) < 0) {
928         return NULL;
929     }
930 
931     if (val == NULL) {
932         PyErr_SetObject(PyExc_LookupError, (PyObject *)self);
933         return NULL;
934     }
935 
936     return val;
937 }
938 
939 /*[clinic input]
940 _contextvars.ContextVar.set
941     value: object
942     /
943 [clinic start generated code]*/
944 
945 static PyObject *
_contextvars_ContextVar_set(PyContextVar * self,PyObject * value)946 _contextvars_ContextVar_set(PyContextVar *self, PyObject *value)
947 /*[clinic end generated code: output=446ed5e820d6d60b input=a2d88f57c6d86f7c]*/
948 {
949     return PyContextVar_Set((PyObject *)self, value);
950 }
951 
952 /*[clinic input]
953 _contextvars.ContextVar.reset
954     token: object
955     /
956 [clinic start generated code]*/
957 
958 static PyObject *
_contextvars_ContextVar_reset(PyContextVar * self,PyObject * token)959 _contextvars_ContextVar_reset(PyContextVar *self, PyObject *token)
960 /*[clinic end generated code: output=d4ee34d0742d62ee input=4c871b6f1f31a65f]*/
961 {
962     if (!PyContextToken_CheckExact(token)) {
963         PyErr_Format(PyExc_TypeError,
964                      "expected an instance of Token, got %R", token);
965         return NULL;
966     }
967 
968     if (PyContextVar_Reset((PyObject *)self, token)) {
969         return NULL;
970     }
971 
972     Py_RETURN_NONE;
973 }
974 
975 
976 static PyObject *
contextvar_cls_getitem(PyObject * self,PyObject * args)977 contextvar_cls_getitem(PyObject *self, PyObject *args)
978 {
979     Py_RETURN_NONE;
980 }
981 
982 static PyMemberDef PyContextVar_members[] = {
983     {"name", T_OBJECT, offsetof(PyContextVar, var_name), READONLY},
984     {NULL}
985 };
986 
987 static PyMethodDef PyContextVar_methods[] = {
988     _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF
989     _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF
990     _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF
991     {"__class_getitem__", contextvar_cls_getitem,
992         METH_VARARGS | METH_STATIC, NULL},
993     {NULL, NULL}
994 };
995 
996 PyTypeObject PyContextVar_Type = {
997     PyVarObject_HEAD_INIT(&PyType_Type, 0)
998     "ContextVar",
999     sizeof(PyContextVar),
1000     .tp_methods = PyContextVar_methods,
1001     .tp_members = PyContextVar_members,
1002     .tp_dealloc = (destructor)contextvar_tp_dealloc,
1003     .tp_getattro = PyObject_GenericGetAttr,
1004     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1005     .tp_traverse = (traverseproc)contextvar_tp_traverse,
1006     .tp_clear = (inquiry)contextvar_tp_clear,
1007     .tp_new = contextvar_tp_new,
1008     .tp_free = PyObject_GC_Del,
1009     .tp_hash = (hashfunc)contextvar_tp_hash,
1010     .tp_repr = (reprfunc)contextvar_tp_repr,
1011 };
1012 
1013 
1014 /////////////////////////// Token
1015 
/* Defined later in this file.  token_get_old_value() returns its result as
   Token.old_value when the variable had no previous value; presumably a new
   reference to the Token.MISSING marker — TODO confirm against definition. */
static PyObject * get_token_missing(void);
1017 
1018 
1019 /*[clinic input]
1020 class _contextvars.Token "PyContextToken *" "&PyContextToken_Type"
1021 [clinic start generated code]*/
1022 /*[clinic end generated code: output=da39a3ee5e6b4b0d input=338a5e2db13d3f5b]*/
1023 
1024 
1025 static PyObject *
token_tp_new(PyTypeObject * type,PyObject * args,PyObject * kwds)1026 token_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
1027 {
1028     PyErr_SetString(PyExc_RuntimeError,
1029                     "Tokens can only be created by ContextVars");
1030     return NULL;
1031 }
1032 
1033 static int
token_tp_clear(PyContextToken * self)1034 token_tp_clear(PyContextToken *self)
1035 {
1036     Py_CLEAR(self->tok_ctx);
1037     Py_CLEAR(self->tok_var);
1038     Py_CLEAR(self->tok_oldval);
1039     return 0;
1040 }
1041 
1042 static int
token_tp_traverse(PyContextToken * self,visitproc visit,void * arg)1043 token_tp_traverse(PyContextToken *self, visitproc visit, void *arg)
1044 {
1045     Py_VISIT(self->tok_ctx);
1046     Py_VISIT(self->tok_var);
1047     Py_VISIT(self->tok_oldval);
1048     return 0;
1049 }
1050 
1051 static void
token_tp_dealloc(PyContextToken * self)1052 token_tp_dealloc(PyContextToken *self)
1053 {
1054     PyObject_GC_UnTrack(self);
1055     (void)token_tp_clear(self);
1056     Py_TYPE(self)->tp_free(self);
1057 }
1058 
1059 static PyObject *
token_tp_repr(PyContextToken * self)1060 token_tp_repr(PyContextToken *self)
1061 {
1062     _PyUnicodeWriter writer;
1063 
1064     _PyUnicodeWriter_Init(&writer);
1065 
1066     if (_PyUnicodeWriter_WriteASCIIString(&writer, "<Token", 6) < 0) {
1067         goto error;
1068     }
1069 
1070     if (self->tok_used) {
1071         if (_PyUnicodeWriter_WriteASCIIString(&writer, " used", 5) < 0) {
1072             goto error;
1073         }
1074     }
1075 
1076     if (_PyUnicodeWriter_WriteASCIIString(&writer, " var=", 5) < 0) {
1077         goto error;
1078     }
1079 
1080     PyObject *var = PyObject_Repr((PyObject *)self->tok_var);
1081     if (var == NULL) {
1082         goto error;
1083     }
1084     if (_PyUnicodeWriter_WriteStr(&writer, var) < 0) {
1085         Py_DECREF(var);
1086         goto error;
1087     }
1088     Py_DECREF(var);
1089 
1090     PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
1091     if (addr == NULL) {
1092         goto error;
1093     }
1094     if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
1095         Py_DECREF(addr);
1096         goto error;
1097     }
1098     Py_DECREF(addr);
1099 
1100     return _PyUnicodeWriter_Finish(&writer);
1101 
1102 error:
1103     _PyUnicodeWriter_Dealloc(&writer);
1104     return NULL;
1105 }
1106 
1107 static PyObject *
token_get_var(PyContextToken * self,void * Py_UNUSED (ignored))1108 token_get_var(PyContextToken *self, void *Py_UNUSED(ignored))
1109 {
1110     Py_INCREF(self->tok_var);
1111     return (PyObject *)self->tok_var;
1112 }
1113 
1114 static PyObject *
token_get_old_value(PyContextToken * self,void * Py_UNUSED (ignored))1115 token_get_old_value(PyContextToken *self, void *Py_UNUSED(ignored))
1116 {
1117     if (self->tok_oldval == NULL) {
1118         return get_token_missing();
1119     }
1120 
1121     Py_INCREF(self->tok_oldval);
1122     return self->tok_oldval;
1123 }
1124 
1125 static PyGetSetDef PyContextTokenType_getsetlist[] = {
1126     {"var", (getter)token_get_var, NULL, NULL},
1127     {"old_value", (getter)token_get_old_value, NULL, NULL},
1128     {NULL}
1129 };
1130 
1131 PyTypeObject PyContextToken_Type = {
1132     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1133     "Token",
1134     sizeof(PyContextToken),
1135     .tp_getset = PyContextTokenType_getsetlist,
1136     .tp_dealloc = (destructor)token_tp_dealloc,
1137     .tp_getattro = PyObject_GenericGetAttr,
1138     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
1139     .tp_traverse = (traverseproc)token_tp_traverse,
1140     .tp_clear = (inquiry)token_tp_clear,
1141     .tp_new = token_tp_new,
1142     .tp_free = PyObject_GC_Del,
1143     .tp_hash = PyObject_HashNotImplemented,
1144     .tp_repr = (reprfunc)token_tp_repr,
1145 };
1146 
1147 static PyContextToken *
token_new(PyContext * ctx,PyContextVar * var,PyObject * val)1148 token_new(PyContext *ctx, PyContextVar *var, PyObject *val)
1149 {
1150     PyContextToken *tok = PyObject_GC_New(PyContextToken, &PyContextToken_Type);
1151     if (tok == NULL) {
1152         return NULL;
1153     }
1154 
1155     Py_INCREF(ctx);
1156     tok->tok_ctx = ctx;
1157 
1158     Py_INCREF(var);
1159     tok->tok_var = var;
1160 
1161     Py_XINCREF(val);
1162     tok->tok_oldval = val;
1163 
1164     tok->tok_used = 0;
1165 
1166     PyObject_GC_Track(tok);
1167     return tok;
1168 }
1169 
1170 
1171 /////////////////////////// Token.MISSING
1172 
1173 
/* Cached Token.MISSING singleton.  Starts out NULL, is created on first use
   by get_token_missing(), and this reference is dropped in
   _PyContext_Fini(). */
static PyObject *_token_missing;
1175 
1176 
/* Instance layout of the Token.MISSING marker object: just the standard
   object header, with no additional fields. */
typedef struct {
    PyObject_HEAD
} PyContextTokenMissing;
1180 
1181 
1182 static PyObject *
context_token_missing_tp_repr(PyObject * self)1183 context_token_missing_tp_repr(PyObject *self)
1184 {
1185     return PyUnicode_FromString("<Token.MISSING>");
1186 }
1187 
1188 
1189 PyTypeObject PyContextTokenMissing_Type = {
1190     PyVarObject_HEAD_INIT(&PyType_Type, 0)
1191     "Token.MISSING",
1192     sizeof(PyContextTokenMissing),
1193     .tp_getattro = PyObject_GenericGetAttr,
1194     .tp_flags = Py_TPFLAGS_DEFAULT,
1195     .tp_repr = context_token_missing_tp_repr,
1196 };
1197 
1198 
1199 static PyObject *
get_token_missing(void)1200 get_token_missing(void)
1201 {
1202     if (_token_missing != NULL) {
1203         Py_INCREF(_token_missing);
1204         return _token_missing;
1205     }
1206 
1207     _token_missing = (PyObject *)PyObject_New(
1208         PyContextTokenMissing, &PyContextTokenMissing_Type);
1209     if (_token_missing == NULL) {
1210         return NULL;
1211     }
1212 
1213     Py_INCREF(_token_missing);
1214     return _token_missing;
1215 }
1216 
1217 
1218 ///////////////////////////
1219 
1220 
1221 int
PyContext_ClearFreeList(void)1222 PyContext_ClearFreeList(void)
1223 {
1224     int size = ctx_freelist_len;
1225     while (ctx_freelist_len) {
1226         PyContext *ctx = ctx_freelist;
1227         ctx_freelist = (PyContext *)ctx->ctx_weakreflist;
1228         ctx->ctx_weakreflist = NULL;
1229         PyObject_GC_Del(ctx);
1230         ctx_freelist_len--;
1231     }
1232     return size;
1233 }
1234 
1235 
/* Finalization counterpart of _PyContext_Init (presumably run during
   interpreter shutdown -- confirm against the caller).  Drops this
   module's cached reference to Token.MISSING, empties the context free
   list, then finalizes the HAMT implementation; helper return values are
   deliberately ignored.  Note: _PyContext_Init also stored MISSING in
   PyContextToken_Type.tp_dict, and that dict holds its own reference, so
   clearing the cache here does not by itself free the singleton. */
void
_PyContext_Fini(void)
{
    Py_CLEAR(_token_missing);
    (void)PyContext_ClearFreeList();
    (void)_PyHamt_Fini();
}
1243 
1244 
1245 int
_PyContext_Init(void)1246 _PyContext_Init(void)
1247 {
1248     if (!_PyHamt_Init()) {
1249         return 0;
1250     }
1251 
1252     if ((PyType_Ready(&PyContext_Type) < 0) ||
1253         (PyType_Ready(&PyContextVar_Type) < 0) ||
1254         (PyType_Ready(&PyContextToken_Type) < 0) ||
1255         (PyType_Ready(&PyContextTokenMissing_Type) < 0))
1256     {
1257         return 0;
1258     }
1259 
1260     PyObject *missing = get_token_missing();
1261     if (PyDict_SetItemString(
1262         PyContextToken_Type.tp_dict, "MISSING", missing))
1263     {
1264         Py_DECREF(missing);
1265         return 0;
1266     }
1267     Py_DECREF(missing);
1268 
1269     return 1;
1270 }
1271