1import contextlib
2import os
3import sys
4import tracemalloc
5import unittest
6from unittest.mock import patch
7from test.support.script_helper import (assert_python_ok, assert_python_failure,
8                                        interpreter_requires_environment)
9from test import support
10
11try:
12    import _testcapi
13except ImportError:
14    _testcapi = None
15
16
# Size of an empty bytes object; allocate_bytes() subtracts it from the
# requested size so that the allocated object's total size is the request.
EMPTY_STRING_SIZE = sys.getsizeof(b'')
# Traceback limits intended to be invalid (tracemalloc.start(-1) raises
# ValueError); not referenced in this part of the file -- presumably used by
# later tests, TODO confirm the role of 2**30.
INVALID_NFRAME = (-1, 2**30)
19
20
def get_frames(nframe, lineno_delta):
    """Return the caller's stack as a tuple of ``(filename, lineno)`` pairs.

    At most *nframe* entries are produced. The walk starts with the frame
    that called get_frames() and moves outwards through ``f_back``. It stops
    early once the outermost frame has been consumed. *lineno_delta* is
    added to the line number of the first entry only, so a caller can refer
    to a line near its call site.
    """
    pairs = []
    current = sys._getframe(1)
    for _ in range(nframe):
        if current is None:
            break
        offset = lineno_delta if not pairs else 0
        pairs.append((current.f_code.co_filename, current.f_lineno + offset))
        current = current.f_back
    return tuple(pairs)
33
def allocate_bytes(size):
    """Allocate a bytes object together with the traceback expected for it.

    Returns ``(data, traceback)``:

    - *data* is ``b'x'`` repeated ``size - EMPTY_STRING_SIZE`` times. Its
      total object size is therefore intended to equal *size*.
    - *traceback* is a ``tracemalloc.Traceback`` of the current stack,
      truncated to the current traceback limit. Its innermost frame points
      at the line that allocates *data*, and it is built with
      ``min(len(frames), nframe)`` as the total frame count.
    """
    nframe = tracemalloc.get_traceback_limit()
    bytes_len = (size - EMPTY_STRING_SIZE)
    # lineno_delta=1 makes the innermost expected frame the allocation on the
    # very next line. Keep these two statements adjacent.
    frames = get_frames(nframe, 1)
    data = b'x' * bytes_len
    return data, tracemalloc.Traceback(frames, min(len(frames), nframe))
40
def create_snapshots():
    """Return two synthetic ``tracemalloc.Snapshot`` objects for TestSnapshot.

    Both snapshots use a traceback limit of 2 and are built from hand-written
    raw traces instead of real allocations.
    """
    traceback_limit = 2

    # _tracemalloc._get_traces() returns a list of (domain, size,
    # traceback_frames, total_nframe) tuples. traceback_frames is a tuple of
    # (filename, line_number) tuples; total_nframe is the total number of
    # frames of the traceback, which may exceed len(traceback_frames).
    raw_traces = [
        (0, 10, (('a.py', 2), ('b.py', 4)), 3),
        (0, 10, (('a.py', 2), ('b.py', 4)), 3),
        (0, 10, (('a.py', 2), ('b.py', 4)), 3),

        (1, 2, (('a.py', 5), ('b.py', 4)), 3),

        (2, 66, (('b.py', 1),), 1),

        (3, 7, (('<unknown>', 0),), 1),
    ]
    snapshot = tracemalloc.Snapshot(raw_traces, traceback_limit)

    # Second snapshot: a.py:5 grows and moves to domain 2, b.py:1 and
    # <unknown> disappear, c.py:578 is new.
    raw_traces2 = [
        (0, 10, (('a.py', 2), ('b.py', 4)), 3),
        (0, 10, (('a.py', 2), ('b.py', 4)), 3),
        (0, 10, (('a.py', 2), ('b.py', 4)), 3),

        (2, 2, (('a.py', 5), ('b.py', 4)), 3),
        (2, 5000, (('a.py', 5), ('b.py', 4)), 3),

        (4, 400, (('c.py', 578),), 1),
    ]
    snapshot2 = tracemalloc.Snapshot(raw_traces2, traceback_limit)

    return (snapshot, snapshot2)
73
def frame(filename, lineno):
    """Return a ``tracemalloc.Frame`` for *filename*:*lineno*.

    The previous implementation referenced ``tracemalloc._Frame``, which the
    module does not define, so every call raised AttributeError. The public
    class is ``tracemalloc.Frame``.
    """
    return tracemalloc.Frame((filename, lineno))
76
def traceback(*frames):
    """Build a ``tracemalloc.Traceback`` from ``(filename, lineno)`` tuples.

    The tuples are passed in the same order as the frame tuples of a raw
    trace. No total frame count is supplied.
    """
    return tracemalloc.Traceback(frames, total_nframe=None)
79
def traceback_lineno(filename, lineno):
    """Return a one-frame Traceback located at *filename*:*lineno*."""
    only_frame = (filename, lineno)
    return traceback(only_frame)
82
def traceback_filename(filename):
    """Return a one-frame Traceback for *filename* with line number 0.

    Tests compare these against ``statistics('filename')`` results.
    """
    return traceback_lineno(filename, lineno=0)
85
86
class TestTracemallocEnabled(unittest.TestCase):
    """Tests that exercise tracemalloc while it is actively tracing.

    setUp() skips the test if tracing is already active. Otherwise it starts
    tracing with a traceback limit of 1. tearDown() stops tracing again.
    """
    def setUp(self):
        if tracemalloc.is_tracing():
            self.skipTest("tracemalloc must be stopped before the test")

        tracemalloc.start(1)

    def tearDown(self):
        tracemalloc.stop()

    def test_get_tracemalloc_memory(self):
        """get_tracemalloc_memory() is non-negative and does not grow
        after clear_traces()."""
        # `data` keeps 1000 traced allocations alive while the first
        # measurement is taken.
        data = [allocate_bytes(123) for count in range(1000)]
        size = tracemalloc.get_tracemalloc_memory()
        self.assertGreaterEqual(size, 0)

        tracemalloc.clear_traces()
        size2 = tracemalloc.get_tracemalloc_memory()
        self.assertGreaterEqual(size2, 0)
        self.assertLessEqual(size2, size)

    def test_get_object_traceback(self):
        """get_object_traceback() returns the traceback of the allocation."""
        tracemalloc.clear_traces()
        obj_size = 12345
        obj, obj_traceback = allocate_bytes(obj_size)
        traceback = tracemalloc.get_object_traceback(obj)
        self.assertEqual(traceback, obj_traceback)

    def test_new_reference(self):
        """A list that (presumably) reuses a free-list slot is traced at
        the line where it is created again."""
        tracemalloc.clear_traces()
        # gc.collect() indirectly calls PyList_ClearFreeList()
        support.gc_collect()

        # Create a list and "destroy it": put it in the PyListObject free list
        obj = []
        obj = None

        # NOTE: get_frames(nframe, -3) below expects the following `obj = []`
        # line to be exactly three lines above the get_frames() call; do not
        # insert or remove lines between them.
        # Create a list which should reuse the previously created empty list
        obj = []

        nframe = tracemalloc.get_traceback_limit()
        frames = get_frames(nframe, -3)
        obj_traceback = tracemalloc.Traceback(frames, min(len(frames), nframe))

        traceback = tracemalloc.get_object_traceback(obj)
        self.assertIsNotNone(traceback)
        self.assertEqual(traceback, obj_traceback)

    def test_set_traceback_limit(self):
        """start(nframe) sets the traceback limit; a negative limit raises
        ValueError."""
        obj_size = 10

        tracemalloc.stop()
        self.assertRaises(ValueError, tracemalloc.start, -1)

        tracemalloc.stop()
        tracemalloc.start(10)
        obj2, obj2_traceback = allocate_bytes(obj_size)
        traceback = tracemalloc.get_object_traceback(obj2)
        self.assertEqual(len(traceback), 10)
        self.assertEqual(traceback, obj2_traceback)

        tracemalloc.stop()
        tracemalloc.start(1)
        obj, obj_traceback = allocate_bytes(obj_size)
        traceback = tracemalloc.get_object_traceback(obj)
        self.assertEqual(len(traceback), 1)
        self.assertEqual(traceback, obj_traceback)

    def find_trace(self, traces, traceback):
        """Return the first raw trace whose frames equal traceback._frames.

        *traces* is a list of raw trace tuples as returned by
        tracemalloc._get_traces(). The test fails if no trace matches.
        """
        for trace in traces:
            if trace[2] == traceback._frames:
                return trace

        self.fail("trace not found")

    def test_get_traces(self):
        """_get_traces() reports a traced allocation and is empty once
        tracing stops."""
        tracemalloc.clear_traces()
        obj_size = 12345
        obj, obj_traceback = allocate_bytes(obj_size)

        traces = tracemalloc._get_traces()
        trace = self.find_trace(traces, obj_traceback)

        self.assertIsInstance(trace, tuple)
        # raw trace layout: (domain, size, traceback frames, total_nframe)
        domain, size, traceback, length = trace
        self.assertEqual(size, obj_size)
        self.assertEqual(traceback, obj_traceback._frames)

        tracemalloc.stop()
        self.assertEqual(tracemalloc._get_traces(), [])

    def test_get_traces_intern_traceback(self):
        """Identical tracebacks are shared by the raw traces, not copied."""
        # dummy wrappers to get more useful and identical frames in the traceback
        def allocate_bytes2(size):
            return allocate_bytes(size)
        def allocate_bytes3(size):
            return allocate_bytes2(size)
        def allocate_bytes4(size):
            return allocate_bytes3(size)

        # Ensure that two identical tracebacks are not duplicated. With a
        # limit of 4 frames, both tracebacks only cover the allocate_bytes ..
        # allocate_bytes4 wrappers, so they are identical.
        tracemalloc.stop()
        tracemalloc.start(4)
        obj_size = 123
        obj1, obj1_traceback = allocate_bytes4(obj_size)
        obj2, obj2_traceback = allocate_bytes4(obj_size)

        traces = tracemalloc._get_traces()

        # Traceback objects appear to keep frames in the reverse order of the
        # raw _get_traces() tuples, so flip them before searching.
        obj1_traceback._frames = tuple(reversed(obj1_traceback._frames))
        obj2_traceback._frames = tuple(reversed(obj2_traceback._frames))

        trace1 = self.find_trace(traces, obj1_traceback)
        trace2 = self.find_trace(traces, obj2_traceback)
        domain1, size1, traceback1, length1 = trace1
        domain2, size2, traceback2, length2 = trace2
        # Same object, not merely equal: the traceback is stored only once.
        self.assertIs(traceback2, traceback1)

    def test_get_traced_memory(self):
        """get_traced_memory() tracks current and peak usage and is reset
        by clear_traces() and stop()."""
        # Python allocates some internals objects, so the test must tolerate
        # a small difference between the expected size and the real usage
        max_error = 2048

        # allocate one object
        obj_size = 1024 * 1024
        tracemalloc.clear_traces()
        obj, obj_traceback = allocate_bytes(obj_size)
        size, peak_size = tracemalloc.get_traced_memory()
        self.assertGreaterEqual(size, obj_size)
        self.assertGreaterEqual(peak_size, size)

        self.assertLessEqual(size - obj_size, max_error)
        self.assertLessEqual(peak_size - size, max_error)

        # destroy the object
        obj = None
        size2, peak_size2 = tracemalloc.get_traced_memory()
        self.assertLess(size2, size)
        self.assertGreaterEqual(size - size2, obj_size - max_error)
        self.assertGreaterEqual(peak_size2, peak_size)

        # clear_traces() must reset traced memory counters
        tracemalloc.clear_traces()
        self.assertEqual(tracemalloc.get_traced_memory(), (0, 0))

        # allocate another object
        obj, obj_traceback = allocate_bytes(obj_size)
        size, peak_size = tracemalloc.get_traced_memory()
        self.assertGreaterEqual(size, obj_size)

        # stop() also resets traced memory counters
        tracemalloc.stop()
        self.assertEqual(tracemalloc.get_traced_memory(), (0, 0))

    def test_clear_traces(self):
        """clear_traces() forgets the traceback of a still-alive object."""
        obj, obj_traceback = allocate_bytes(123)
        traceback = tracemalloc.get_object_traceback(obj)
        self.assertIsNotNone(traceback)

        tracemalloc.clear_traces()
        traceback2 = tracemalloc.get_object_traceback(obj)
        self.assertIsNone(traceback2)

    def test_reset_peak(self):
        """reset_peak() lowers the peak to the current traced memory, and
        the peak keeps tracking later allocations."""
        # Start from empty counters: clear_traces() resets traced memory.
        tracemalloc.clear_traces()

        # Example: allocate a large piece of memory, temporarily
        large_sum = sum(list(range(100000)))
        size1, peak1 = tracemalloc.get_traced_memory()

        # reset_peak() resets peak to traced memory: peak2 < peak1
        tracemalloc.reset_peak()
        size2, peak2 = tracemalloc.get_traced_memory()
        self.assertGreaterEqual(peak2, size2)
        self.assertLess(peak2, peak1)

        # check that peak continue to be updated if new memory is allocated:
        # peak3 > peak2
        obj_size = 1024 * 1024
        obj, obj_traceback = allocate_bytes(obj_size)
        size3, peak3 = tracemalloc.get_traced_memory()
        self.assertGreaterEqual(peak3, size3)
        self.assertGreater(peak3, peak2)
        self.assertGreaterEqual(peak3 - peak2, obj_size)

    def test_is_tracing(self):
        """is_tracing() reflects stop() and start()."""
        tracemalloc.stop()
        self.assertFalse(tracemalloc.is_tracing())

        tracemalloc.start()
        self.assertTrue(tracemalloc.is_tracing())

    def test_snapshot(self):
        """A snapshot survives a dump()/load() round trip; take_snapshot()
        requires tracing."""
        obj, source = allocate_bytes(123)

        # take a snapshot
        snapshot = tracemalloc.take_snapshot()

        # The traceback limit is 1, but total_nframe counts more frames. The
        # exact count presumably depends on the depth of the runner's call
        # stack, so only a lower bound is checked.
        self.assertGreater(snapshot.traces[1].traceback.total_nframe, 10)

        # write on disk
        snapshot.dump(support.TESTFN)
        self.addCleanup(support.unlink, support.TESTFN)

        # load from disk
        snapshot2 = tracemalloc.Snapshot.load(support.TESTFN)
        self.assertEqual(snapshot2.traces, snapshot.traces)

        # tracemalloc must be tracing memory allocations to take a snapshot
        tracemalloc.stop()
        with self.assertRaises(RuntimeError) as cm:
            tracemalloc.take_snapshot()
        self.assertEqual(str(cm.exception),
                         "the tracemalloc module must be tracing memory "
                         "allocations to take a snapshot")

    def test_snapshot_save_attr(self):
        """Extra attributes set on a snapshot survive dump()/load()."""
        # take a snapshot with a new attribute
        snapshot = tracemalloc.take_snapshot()
        snapshot.test_attr = "new"
        snapshot.dump(support.TESTFN)
        self.addCleanup(support.unlink, support.TESTFN)

        # load() should recreate the attribute
        snapshot2 = tracemalloc.Snapshot.load(support.TESTFN)
        self.assertEqual(snapshot2.test_attr, "new")

    def fork_child(self):
        """Body run in the forked child; returns the child's exit code.

        Exit codes:
        - 0: tracing still works.
        - 2: tracing was not active in the child.
        - 3: no traceback was recorded for a fresh allocation.
        test_fork uses 1 if this method raises.
        """
        if not tracemalloc.is_tracing():
            return 2

        obj_size = 12345
        obj, obj_traceback = allocate_bytes(obj_size)
        traceback = tracemalloc.get_object_traceback(obj)
        if traceback is None:
            return 3

        # everything is fine
        return 0

    @unittest.skipUnless(hasattr(os, 'fork'), 'need os.fork()')
    def test_fork(self):
        """tracemalloc keeps tracing in a forked child process."""
        # check that tracemalloc is still working after fork
        pid = os.fork()
        if not pid:
            # child: always leave via os._exit() so the child never returns
            # into the test runner; exit code 1 means fork_child() raised.
            exitcode = 1
            try:
                exitcode = self.fork_child()
            finally:
                os._exit(exitcode)
        else:
            support.wait_process(pid, exitcode=0)
342
343
class TestSnapshot(unittest.TestCase):
    """Tests for Snapshot, Statistic and Traceback built from fake traces.

    Snapshots come from create_snapshots() or from patched tracemalloc
    functions, so these tests do not depend on real allocations.
    """
    maxDiff = 4000

    def test_create_snapshot(self):
        """take_snapshot() builds traces from _get_traces() and the limit."""
        raw_traces = [(0, 5, (('a.py', 2),), 10)]

        # Patch the tracing state so take_snapshot() sees a fixed traceback
        # limit and a fixed list of raw traces.
        with contextlib.ExitStack() as stack:
            stack.enter_context(patch.object(tracemalloc, 'is_tracing',
                                             return_value=True))
            stack.enter_context(patch.object(tracemalloc, 'get_traceback_limit',
                                             return_value=5))
            stack.enter_context(patch.object(tracemalloc, '_get_traces',
                                             return_value=raw_traces))

            snapshot = tracemalloc.take_snapshot()
            self.assertEqual(snapshot.traceback_limit, 5)
            self.assertEqual(len(snapshot.traces), 1)
            trace = snapshot.traces[0]
            self.assertEqual(trace.size, 5)
            self.assertEqual(trace.traceback.total_nframe, 10)
            self.assertEqual(len(trace.traceback), 1)
            self.assertEqual(trace.traceback[0].filename, 'a.py')
            self.assertEqual(trace.traceback[0].lineno, 2)

    def test_filter_traces(self):
        """filter_traces() applies Filter objects to the most recent frame."""
        snapshot, snapshot2 = create_snapshots()
        filter1 = tracemalloc.Filter(False, "b.py")
        filter2 = tracemalloc.Filter(True, "a.py", 2)
        filter3 = tracemalloc.Filter(True, "a.py", 5)

        original_traces = list(snapshot.traces._traces)

        # exclude b.py
        snapshot3 = snapshot.filter_traces((filter1,))
        self.assertEqual(snapshot3.traces._traces, [
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (1, 2, (('a.py', 5), ('b.py', 4)), 3),
            (3, 7, (('<unknown>', 0),), 1),
        ])

        # filter_traces() must not touch the original snapshot
        self.assertEqual(snapshot.traces._traces, original_traces)

        # only include two lines of a.py
        snapshot4 = snapshot3.filter_traces((filter2, filter3))
        self.assertEqual(snapshot4.traces._traces, [
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (1, 2, (('a.py', 5), ('b.py', 4)), 3),
        ])

        # No filter: just duplicate the snapshot
        snapshot5 = snapshot.filter_traces(())
        self.assertIsNot(snapshot5, snapshot)
        self.assertIsNot(snapshot5.traces, snapshot.traces)
        self.assertEqual(snapshot5.traces, snapshot.traces)

        self.assertRaises(TypeError, snapshot.filter_traces, filter1)

    def test_filter_traces_domain(self):
        """A Filter with a domain only applies to traces of that domain."""
        snapshot, snapshot2 = create_snapshots()
        filter1 = tracemalloc.Filter(False, "a.py", domain=1)
        filter2 = tracemalloc.Filter(True, "a.py", domain=1)

        original_traces = list(snapshot.traces._traces)

        # exclude a.py of domain 1
        snapshot3 = snapshot.filter_traces((filter1,))
        self.assertEqual(snapshot3.traces._traces, [
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (2, 66, (('b.py', 1),), 1),
            (3, 7, (('<unknown>', 0),), 1),
        ])

        # include only a.py of domain 1. This step used to re-run filter1,
        # so the inclusive domain filter was never exercised.
        snapshot4 = snapshot.filter_traces((filter2,))
        self.assertEqual(snapshot4.traces._traces, [
            (1, 2, (('a.py', 5), ('b.py', 4)), 3),
        ])

        # filter_traces() must not touch the original snapshot
        self.assertEqual(snapshot.traces._traces, original_traces)

    def test_filter_traces_domain_filter(self):
        """DomainFilter keeps or drops traces based only on their domain."""
        snapshot, snapshot2 = create_snapshots()
        filter1 = tracemalloc.DomainFilter(False, domain=3)
        filter2 = tracemalloc.DomainFilter(True, domain=3)

        # exclude domain 3
        snapshot3 = snapshot.filter_traces((filter1,))
        self.assertEqual(snapshot3.traces._traces, [
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (0, 10, (('a.py', 2), ('b.py', 4)), 3),
            (1, 2, (('a.py', 5), ('b.py', 4)), 3),
            (2, 66, (('b.py', 1),), 1),
        ])

        # include domain 3
        snapshot3 = snapshot.filter_traces((filter2,))
        self.assertEqual(snapshot3.traces._traces, [
            (3, 7, (('<unknown>', 0),), 1),
        ])

    def test_snapshot_group_by_line(self):
        """statistics('lineno') and compare_to() group by file and line."""
        snapshot, snapshot2 = create_snapshots()
        tb_0 = traceback_lineno('<unknown>', 0)
        tb_a_2 = traceback_lineno('a.py', 2)
        tb_a_5 = traceback_lineno('a.py', 5)
        tb_b_1 = traceback_lineno('b.py', 1)
        tb_c_578 = traceback_lineno('c.py', 578)

        # stats per file and line
        stats1 = snapshot.statistics('lineno')
        self.assertEqual(stats1, [
            tracemalloc.Statistic(tb_b_1, 66, 1),
            tracemalloc.Statistic(tb_a_2, 30, 3),
            tracemalloc.Statistic(tb_0, 7, 1),
            tracemalloc.Statistic(tb_a_5, 2, 1),
        ])

        # stats per file and line (2)
        stats2 = snapshot2.statistics('lineno')
        self.assertEqual(stats2, [
            tracemalloc.Statistic(tb_a_5, 5002, 2),
            tracemalloc.Statistic(tb_c_578, 400, 1),
            tracemalloc.Statistic(tb_a_2, 30, 3),
        ])

        # stats diff per file and line
        statistics = snapshot2.compare_to(snapshot, 'lineno')
        self.assertEqual(statistics, [
            tracemalloc.StatisticDiff(tb_a_5, 5002, 5000, 2, 1),
            tracemalloc.StatisticDiff(tb_c_578, 400, 400, 1, 1),
            tracemalloc.StatisticDiff(tb_b_1, 0, -66, 0, -1),
            tracemalloc.StatisticDiff(tb_0, 0, -7, 0, -1),
            tracemalloc.StatisticDiff(tb_a_2, 30, 0, 3, 0),
        ])

    def test_snapshot_group_by_file(self):
        """statistics('filename') and compare_to() group by file only."""
        snapshot, snapshot2 = create_snapshots()
        tb_0 = traceback_filename('<unknown>')
        tb_a = traceback_filename('a.py')
        tb_b = traceback_filename('b.py')
        tb_c = traceback_filename('c.py')

        # stats per file
        stats1 = snapshot.statistics('filename')
        self.assertEqual(stats1, [
            tracemalloc.Statistic(tb_b, 66, 1),
            tracemalloc.Statistic(tb_a, 32, 4),
            tracemalloc.Statistic(tb_0, 7, 1),
        ])

        # stats per file (2)
        stats2 = snapshot2.statistics('filename')
        self.assertEqual(stats2, [
            tracemalloc.Statistic(tb_a, 5032, 5),
            tracemalloc.Statistic(tb_c, 400, 1),
        ])

        # stats diff per file
        diff = snapshot2.compare_to(snapshot, 'filename')
        self.assertEqual(diff, [
            tracemalloc.StatisticDiff(tb_a, 5032, 5000, 5, 1),
            tracemalloc.StatisticDiff(tb_c, 400, 400, 1, 1),
            tracemalloc.StatisticDiff(tb_b, 0, -66, 0, -1),
            tracemalloc.StatisticDiff(tb_0, 0, -7, 0, -1),
        ])

    def test_snapshot_group_by_traceback(self):
        """statistics('traceback') groups by whole traceback; cumulative
        mode is rejected."""
        snapshot, snapshot2 = create_snapshots()

        # stats per file
        tb1 = traceback(('a.py', 2), ('b.py', 4))
        tb2 = traceback(('a.py', 5), ('b.py', 4))
        tb3 = traceback(('b.py', 1))
        tb4 = traceback(('<unknown>', 0))
        stats1 = snapshot.statistics('traceback')
        self.assertEqual(stats1, [
            tracemalloc.Statistic(tb3, 66, 1),
            tracemalloc.Statistic(tb1, 30, 3),
            tracemalloc.Statistic(tb4, 7, 1),
            tracemalloc.Statistic(tb2, 2, 1),
        ])

        # stats per file (2)
        tb5 = traceback(('c.py', 578))
        stats2 = snapshot2.statistics('traceback')
        self.assertEqual(stats2, [
            tracemalloc.Statistic(tb2, 5002, 2),
            tracemalloc.Statistic(tb5, 400, 1),
            tracemalloc.Statistic(tb1, 30, 3),
        ])

        # stats diff per file
        diff = snapshot2.compare_to(snapshot, 'traceback')
        self.assertEqual(diff, [
            tracemalloc.StatisticDiff(tb2, 5002, 5000, 2, 1),
            tracemalloc.StatisticDiff(tb5, 400, 400, 1, 1),
            tracemalloc.StatisticDiff(tb3, 0, -66, 0, -1),
            tracemalloc.StatisticDiff(tb4, 0, -7, 0, -1),
            tracemalloc.StatisticDiff(tb1, 30, 0, 3, 0),
        ])

        self.assertRaises(ValueError,
                          snapshot.statistics, 'traceback', cumulative=True)

    def test_snapshot_group_by_cumulative(self):
        """Cumulative statistics count every frame, not only the most
        recent."""
        snapshot, snapshot2 = create_snapshots()
        tb_0 = traceback_filename('<unknown>')
        tb_a = traceback_filename('a.py')
        tb_b = traceback_filename('b.py')
        tb_a_2 = traceback_lineno('a.py', 2)
        tb_a_5 = traceback_lineno('a.py', 5)
        tb_b_1 = traceback_lineno('b.py', 1)
        tb_b_4 = traceback_lineno('b.py', 4)

        # per file
        stats = snapshot.statistics('filename', True)
        self.assertEqual(stats, [
            tracemalloc.Statistic(tb_b, 98, 5),
            tracemalloc.Statistic(tb_a, 32, 4),
            tracemalloc.Statistic(tb_0, 7, 1),
        ])

        # per line
        stats = snapshot.statistics('lineno', True)
        self.assertEqual(stats, [
            tracemalloc.Statistic(tb_b_1, 66, 1),
            tracemalloc.Statistic(tb_b_4, 32, 4),
            tracemalloc.Statistic(tb_a_2, 30, 3),
            tracemalloc.Statistic(tb_0, 7, 1),
            tracemalloc.Statistic(tb_a_5, 2, 1),
        ])

    def test_trace_format(self):
        """str() of Trace, Traceback and Frame."""
        snapshot, snapshot2 = create_snapshots()
        trace = snapshot.traces[0]
        self.assertEqual(str(trace), 'b.py:4: 10 B')
        # local names avoid shadowing the module-level helpers
        # traceback() and frame()
        tb = trace.traceback
        self.assertEqual(str(tb), 'b.py:4')
        first_frame = tb[0]
        self.assertEqual(str(first_frame), 'b.py:4')

    def test_statistic_format(self):
        """str() of a Statistic."""
        snapshot, snapshot2 = create_snapshots()
        stats = snapshot.statistics('lineno')
        stat = stats[0]
        self.assertEqual(str(stat),
                         'b.py:1: size=66 B, count=1, average=66 B')

    def test_statistic_diff_format(self):
        """str() of a StatisticDiff."""
        snapshot, snapshot2 = create_snapshots()
        stats = snapshot2.compare_to(snapshot, 'lineno')
        stat = stats[0]
        self.assertEqual(str(stat),
                         'a.py:5: size=5002 B (+5000 B), count=2 (+1), average=2501 B')

    def test_slices(self):
        """traces and Traceback support slicing into tuples."""
        snapshot, snapshot2 = create_snapshots()
        self.assertEqual(snapshot.traces[:2],
                         (snapshot.traces[0], snapshot.traces[1]))

        tb = snapshot.traces[0].traceback
        self.assertEqual(tb[:2],
                         (tb[0], tb[1]))

    def test_format_traceback(self):
        """Traceback.format() honours limit and most_recent_first."""
        snapshot, snapshot2 = create_snapshots()
        def getline(filename, lineno):
            return '  <%s, %s>' % (filename, lineno)
        # fake source lines so the output does not depend on real files
        with patch('tracemalloc.linecache.getline', side_effect=getline):
            tb = snapshot.traces[0].traceback
            self.assertEqual(tb.format(),
                             ['  File "b.py", line 4',
                              '    <b.py, 4>',
                              '  File "a.py", line 2',
                              '    <a.py, 2>'])

            self.assertEqual(tb.format(limit=1),
                             ['  File "a.py", line 2',
                              '    <a.py, 2>'])

            self.assertEqual(tb.format(limit=-1),
                             ['  File "b.py", line 4',
                              '    <b.py, 4>'])

            self.assertEqual(tb.format(most_recent_first=True),
                             ['  File "a.py", line 2',
                              '    <a.py, 2>',
                              '  File "b.py", line 4',
                              '    <b.py, 4>'])

            self.assertEqual(tb.format(limit=1, most_recent_first=True),
                             ['  File "a.py", line 2',
                              '    <a.py, 2>'])

            self.assertEqual(tb.format(limit=-1, most_recent_first=True),
                             ['  File "b.py", line 4',
                              '    <b.py, 4>'])
652
653
654class TestFilters(unittest.TestCase):
655    maxDiff = 2048
656
657    def test_filter_attributes(self):
658        # test default values
659        f = tracemalloc.Filter(True, "abc")
660        self.assertEqual(f.inclusive, True)
661        self.assertEqual(f.filename_pattern, "abc")
662        self.assertIsNone(f.lineno)
663        self.assertEqual(f.all_frames, False)
664
665        # test custom values
666        f = tracemalloc.Filter(False, "test.py", 123, True)
667        self.assertEqual(f.inclusive, False)
668        self.assertEqual(f.filename_pattern, "test.py")
669        self.assertEqual(f.lineno, 123)
670        self.assertEqual(f.all_frames, True)
671
672        # parameters passed by keyword
673        f = tracemalloc.Filter(inclusive=False, filename_pattern="test.py", lineno=123, all_frames=True)
674        self.assertEqual(f.inclusive, False)
675        self.assertEqual(f.filename_pattern, "test.py")
676        self.assertEqual(f.lineno, 123)
677        self.assertEqual(f.all_frames, True)
678
679        # read-only attribute
680        self.assertRaises(AttributeError, setattr, f, "filename_pattern", "abc")
681
682    def test_filter_match(self):
683        # filter without line number
684        f = tracemalloc.Filter(True, "abc")
685        self.assertTrue(f._match_frame("abc", 0))
686        self.assertTrue(f._match_frame("abc", 5))
687        self.assertTrue(f._match_frame("abc", 10))
688        self.assertFalse(f._match_frame("12356", 0))
689        self.assertFalse(f._match_frame("12356", 5))
690        self.assertFalse(f._match_frame("12356", 10))
691
692        f = tracemalloc.Filter(False, "abc")
693        self.assertFalse(f._match_frame("abc", 0))
694        self.assertFalse(f._match_frame("abc", 5))
695        self.assertFalse(f._match_frame("abc", 10))
696        self.assertTrue(f._match_frame("12356", 0))
697        self.assertTrue(f._match_frame("12356", 5))
698        self.assertTrue(f._match_frame("12356", 10))
699
700        # filter with line number > 0
701        f = tracemalloc.Filter(True, "abc", 5)
702        self.assertFalse(f._match_frame("abc", 0))
703        self.assertTrue(f._match_frame("abc", 5))
704        self.assertFalse(f._match_frame("abc", 10))
705        self.assertFalse(f._match_frame("12356", 0))
706        self.assertFalse(f._match_frame("12356", 5))
707        self.assertFalse(f._match_frame("12356", 10))
708
709        f = tracemalloc.Filter(False, "abc", 5)
710        self.assertTrue(f._match_frame("abc", 0))
711        self.assertFalse(f._match_frame("abc", 5))
712        self.assertTrue(f._match_frame("abc", 10))
713        self.assertTrue(f._match_frame("12356", 0))
714        self.assertTrue(f._match_frame("12356", 5))
715        self.assertTrue(f._match_frame("12356", 10))
716
717        # filter with line number 0
718        f = tracemalloc.Filter(True, "abc", 0)
719        self.assertTrue(f._match_frame("abc", 0))
720        self.assertFalse(f._match_frame("abc", 5))
721        self.assertFalse(f._match_frame("abc", 10))
722        self.assertFalse(f._match_frame("12356", 0))
723        self.assertFalse(f._match_frame("12356", 5))
724        self.assertFalse(f._match_frame("12356", 10))
725
726        f = tracemalloc.Filter(False, "abc", 0)
727        self.assertFalse(f._match_frame("abc", 0))
728        self.assertTrue(f._match_frame("abc", 5))
729        self.assertTrue(f._match_frame("abc", 10))
730        self.assertTrue(f._match_frame("12356", 0))
731        self.assertTrue(f._match_frame("12356", 5))
732        self.assertTrue(f._match_frame("12356", 10))
733
734    def test_filter_match_filename(self):
735        def fnmatch(inclusive, filename, pattern):
736            f = tracemalloc.Filter(inclusive, pattern)
737            return f._match_frame(filename, 0)
738
739        self.assertTrue(fnmatch(True, "abc", "abc"))
740        self.assertFalse(fnmatch(True, "12356", "abc"))
741        self.assertFalse(fnmatch(True, "<unknown>", "abc"))
742
743        self.assertFalse(fnmatch(False, "abc", "abc"))
744        self.assertTrue(fnmatch(False, "12356", "abc"))
745        self.assertTrue(fnmatch(False, "<unknown>", "abc"))
746
747    def test_filter_match_filename_joker(self):
748        def fnmatch(filename, pattern):
749            filter = tracemalloc.Filter(True, pattern)
750            return filter._match_frame(filename, 0)
751
752        # empty string
753        self.assertFalse(fnmatch('abc', ''))
754        self.assertFalse(fnmatch('', 'abc'))
755        self.assertTrue(fnmatch('', ''))
756        self.assertTrue(fnmatch('', '*'))
757
758        # no *
759        self.assertTrue(fnmatch('abc', 'abc'))
760        self.assertFalse(fnmatch('abc', 'abcd'))
761        self.assertFalse(fnmatch('abc', 'def'))
762
763        # a*
764        self.assertTrue(fnmatch('abc', 'a*'))
765        self.assertTrue(fnmatch('abc', 'abc*'))
766        self.assertFalse(fnmatch('abc', 'b*'))
767        self.assertFalse(fnmatch('abc', 'abcd*'))
768
769        # a*b
770        self.assertTrue(fnmatch('abc', 'a*c'))
771        self.assertTrue(fnmatch('abcdcx', 'a*cx'))
772        self.assertFalse(fnmatch('abb', 'a*c'))
773        self.assertFalse(fnmatch('abcdce', 'a*cx'))
774
775        # a*b*c
776        self.assertTrue(fnmatch('abcde', 'a*c*e'))
777        self.assertTrue(fnmatch('abcbdefeg', 'a*bd*eg'))
778        self.assertFalse(fnmatch('abcdd', 'a*c*e'))
779        self.assertFalse(fnmatch('abcbdefef', 'a*bd*eg'))
780
781        # replace .pyc suffix with .py
782        self.assertTrue(fnmatch('a.pyc', 'a.py'))
783        self.assertTrue(fnmatch('a.py', 'a.pyc'))
784
785        if os.name == 'nt':
786            # case insensitive
787            self.assertTrue(fnmatch('aBC', 'ABc'))
788            self.assertTrue(fnmatch('aBcDe', 'Ab*dE'))
789
790            self.assertTrue(fnmatch('a.pyc', 'a.PY'))
791            self.assertTrue(fnmatch('a.py', 'a.PYC'))
792        else:
793            # case sensitive
794            self.assertFalse(fnmatch('aBC', 'ABc'))
795            self.assertFalse(fnmatch('aBcDe', 'Ab*dE'))
796
797            self.assertFalse(fnmatch('a.pyc', 'a.PY'))
798            self.assertFalse(fnmatch('a.py', 'a.PYC'))
799
800        if os.name == 'nt':
801            # normalize alternate separator "/" to the standard separator "\"
802            self.assertTrue(fnmatch(r'a/b', r'a\b'))
803            self.assertTrue(fnmatch(r'a\b', r'a/b'))
804            self.assertTrue(fnmatch(r'a/b\c', r'a\b/c'))
805            self.assertTrue(fnmatch(r'a/b/c', r'a\b\c'))
806        else:
807            # there is no alternate separator
808            self.assertFalse(fnmatch(r'a/b', r'a\b'))
809            self.assertFalse(fnmatch(r'a\b', r'a/b'))
810            self.assertFalse(fnmatch(r'a/b\c', r'a\b/c'))
811            self.assertFalse(fnmatch(r'a/b/c', r'a\b\c'))
812
813        # as of 3.5, .pyo is no longer munged to .py
814        self.assertFalse(fnmatch('a.pyo', 'a.py'))
815
816    def test_filter_match_trace(self):
817        t1 = (("a.py", 2), ("b.py", 3))
818        t2 = (("b.py", 4), ("b.py", 5))
819        t3 = (("c.py", 5), ('<unknown>', 0))
820        unknown = (('<unknown>', 0),)
821
822        f = tracemalloc.Filter(True, "b.py", all_frames=True)
823        self.assertTrue(f._match_traceback(t1))
824        self.assertTrue(f._match_traceback(t2))
825        self.assertFalse(f._match_traceback(t3))
826        self.assertFalse(f._match_traceback(unknown))
827
828        f = tracemalloc.Filter(True, "b.py", all_frames=False)
829        self.assertFalse(f._match_traceback(t1))
830        self.assertTrue(f._match_traceback(t2))
831        self.assertFalse(f._match_traceback(t3))
832        self.assertFalse(f._match_traceback(unknown))
833
834        f = tracemalloc.Filter(False, "b.py", all_frames=True)
835        self.assertFalse(f._match_traceback(t1))
836        self.assertFalse(f._match_traceback(t2))
837        self.assertTrue(f._match_traceback(t3))
838        self.assertTrue(f._match_traceback(unknown))
839
840        f = tracemalloc.Filter(False, "b.py", all_frames=False)
841        self.assertTrue(f._match_traceback(t1))
842        self.assertFalse(f._match_traceback(t2))
843        self.assertTrue(f._match_traceback(t3))
844        self.assertTrue(f._match_traceback(unknown))
845
846        f = tracemalloc.Filter(False, "<unknown>", all_frames=False)
847        self.assertTrue(f._match_traceback(t1))
848        self.assertTrue(f._match_traceback(t2))
849        self.assertTrue(f._match_traceback(t3))
850        self.assertFalse(f._match_traceback(unknown))
851
852        f = tracemalloc.Filter(True, "<unknown>", all_frames=True)
853        self.assertFalse(f._match_traceback(t1))
854        self.assertFalse(f._match_traceback(t2))
855        self.assertTrue(f._match_traceback(t3))
856        self.assertTrue(f._match_traceback(unknown))
857
858        f = tracemalloc.Filter(False, "<unknown>", all_frames=True)
859        self.assertTrue(f._match_traceback(t1))
860        self.assertTrue(f._match_traceback(t2))
861        self.assertFalse(f._match_traceback(t3))
862        self.assertFalse(f._match_traceback(unknown))
863
864
class TestCommandLine(unittest.TestCase):
    """Check how the PYTHONTRACEMALLOC environment variable and the
    ``-X tracemalloc`` option configure tracemalloc in a child interpreter."""

    # Generic error text accepted by both invalid-frame-count checks.
    _NFRAME_RANGE_ERROR = b'ValueError: the number of frames must be in range'

    def _python_stdout(self, *args, **env_vars):
        """Run a child interpreter that must exit successfully.

        *args* are its command-line arguments and *env_vars* extra
        environment variables; both are forwarded to assert_python_ok().
        Return the child's stdout with trailing whitespace stripped.
        """
        _, stdout, _ = assert_python_ok(*args, **env_vars)
        return stdout.rstrip()

    def _check_invalid_nframe_output(self, stderr, specific_message):
        """Pass if *stderr* reports an invalid number of frames.

        Either the generic ValueError text or *specific_message* (the
        option-specific startup error) is accepted; any other output fails
        the test and shows the child's stderr.
        """
        for message in (self._NFRAME_RANGE_ERROR, specific_message):
            if message in stderr:
                return
        self.fail(f"unexpected output: {stderr!a}")

    def test_env_var_disabled_by_default(self):
        # without PYTHONTRACEMALLOC, tracemalloc is not tracing at startup
        code = 'import tracemalloc; print(tracemalloc.is_tracing())'
        stdout = self._python_stdout('-c', code)
        self.assertEqual(stdout, b'False')

    @unittest.skipIf(interpreter_requires_environment(),
                     'Cannot run -E tests when PYTHON env vars are required.')
    def test_env_var_ignored_with_E(self):
        """PYTHON* environment variables must be ignored when -E is present."""
        code = 'import tracemalloc; print(tracemalloc.is_tracing())'
        stdout = self._python_stdout('-E', '-c', code, PYTHONTRACEMALLOC='1')
        self.assertEqual(stdout, b'False')

    def test_env_var_disabled(self):
        # PYTHONTRACEMALLOC=0 explicitly disables tracing at startup
        code = 'import tracemalloc; print(tracemalloc.is_tracing())'
        stdout = self._python_stdout('-c', code, PYTHONTRACEMALLOC='0')
        self.assertEqual(stdout, b'False')

    def test_env_var_enabled_at_startup(self):
        # PYTHONTRACEMALLOC=1 starts tracing at startup
        code = 'import tracemalloc; print(tracemalloc.is_tracing())'
        stdout = self._python_stdout('-c', code, PYTHONTRACEMALLOC='1')
        self.assertEqual(stdout, b'True')

    def test_env_limit(self):
        # start tracing and set the number of frames (traceback limit)
        code = 'import tracemalloc; print(tracemalloc.get_traceback_limit())'
        stdout = self._python_stdout('-c', code, PYTHONTRACEMALLOC='10')
        self.assertEqual(stdout, b'10')

    def check_env_var_invalid(self, nframe):
        """Assert that PYTHONTRACEMALLOC=<nframe> makes the child fail with
        an invalid-number-of-frames error."""
        # NOTE(review): presumably the child may abort at startup; suppress
        # the platform crash report for that case.
        with support.SuppressCrashReport():
            _, stdout, stderr = assert_python_failure(
                '-c', 'pass',
                PYTHONTRACEMALLOC=str(nframe))

        self._check_invalid_nframe_output(
            stderr, b'PYTHONTRACEMALLOC: invalid number of frames')

    def test_env_var_invalid(self):
        for nframe in INVALID_NFRAME:
            with self.subTest(nframe=nframe):
                self.check_env_var_invalid(nframe)

    def test_sys_xoptions(self):
        for xoptions, nframe in (
            ('tracemalloc', 1),
            ('tracemalloc=1', 1),
            ('tracemalloc=15', 15),
        ):
            with self.subTest(xoptions=xoptions, nframe=nframe):
                code = 'import tracemalloc; print(tracemalloc.get_traceback_limit())'
                stdout = self._python_stdout('-X', xoptions, '-c', code)
                self.assertEqual(stdout, str(nframe).encode('ascii'))

    def check_sys_xoptions_invalid(self, nframe):
        """Assert that ``-X tracemalloc=<nframe>`` makes the child fail with
        an invalid-number-of-frames error."""
        args = ('-X', 'tracemalloc=%s' % nframe, '-c', 'pass')
        # NOTE(review): presumably the child may abort at startup; suppress
        # the platform crash report for that case.
        with support.SuppressCrashReport():
            _, stdout, stderr = assert_python_failure(*args)

        self._check_invalid_nframe_output(
            stderr, b'-X tracemalloc=NFRAME: invalid number of frames')

    def test_sys_xoptions_invalid(self):
        for nframe in INVALID_NFRAME:
            with self.subTest(nframe=nframe):
                self.check_sys_xoptions_invalid(nframe)

    @unittest.skipIf(_testcapi is None, 'need _testcapi')
    def test_pymem_alloc0(self):
        # Issue #21639: Check that PyMem_Malloc(0) with tracemalloc enabled
        # does not crash.
        code = 'import _testcapi; _testcapi.test_pymem_alloc0(); 1'
        assert_python_ok('-X', 'tracemalloc', '-c', code)
955
956
@unittest.skipIf(_testcapi is None, 'need _testcapi')
class TestCAPI(unittest.TestCase):
    """Test tracking and untracking memory blocks through the _testcapi
    wrappers tracemalloc_track(), tracemalloc_untrack() and
    tracemalloc_get_traceback()."""
    maxDiff = 80 * 20

    def setUp(self):
        # Each test starts tracemalloc itself, so it must not already run.
        if tracemalloc.is_tracing():
            self.skipTest("tracemalloc must be stopped before the test")

        self.domain = 5
        self.size = 123
        # allocate_bytes() returns (data, traceback); only the data is kept.
        self.obj = allocate_bytes(self.size)[0]

        # self.obj is a bytes object: id(obj) is the address of its memory
        # block, and bytes objects are not tracked by the garbage collector.
        self.ptr = id(self.obj)

    def tearDown(self):
        tracemalloc.stop()

    def get_traceback(self):
        """Return the Traceback recorded for (self.domain, self.ptr), or
        None when _testcapi reports no frames for that block."""
        frames = _testcapi.tracemalloc_get_traceback(self.domain, self.ptr)
        if frames is not None:
            return tracemalloc.Traceback(frames)
        else:
            return None

    def track(self, release_gil=False, nframe=1):
        """Track self.ptr (self.size bytes) in self.domain via the C API.

        Return the frames the new trace is expected to carry. get_frames()
        is called with lineno_delta=1 so that the innermost frame points at
        the line right below it, where tracemalloc_track() is called. Keep
        those two statements adjacent.
        """
        frames = get_frames(nframe, 1)
        _testcapi.tracemalloc_track(self.domain, self.ptr, self.size,
                                    release_gil)
        return frames

    def untrack(self):
        """Untrack (self.domain, self.ptr) via the C API."""
        _testcapi.tracemalloc_untrack(self.domain, self.ptr)

    def get_traced_memory(self):
        """Return the total size of the traces in self.domain, taken from a
        fresh snapshot."""
        snapshot = tracemalloc.take_snapshot()
        domain_filter = tracemalloc.DomainFilter(True, self.domain)
        snapshot = snapshot.filter_traces([domain_filter])
        return sum(trace.size for trace in snapshot.traces)

    def check_track(self, release_gil):
        """Track the block with a 5-frame limit and check its traceback and
        the traced size of the domain."""
        nframe = 5
        tracemalloc.start(nframe)

        frames = self.track(release_gil, nframe)
        self.assertEqual(self.get_traceback(),
                         tracemalloc.Traceback(frames))

        self.assertEqual(self.get_traced_memory(), self.size)

    def test_track(self):
        self.check_track(False)

    def test_track_without_gil(self):
        # check that calling _PyTraceMalloc_Track() without holding the GIL
        # works too
        self.check_track(True)

    def test_track_already_tracked(self):
        nframe = 5
        tracemalloc.start(nframe)

        # track a first time
        self.track()

        # calling _PyTraceMalloc_Track() must remove the old trace and add
        # a new trace with the new traceback
        frames = self.track(nframe=nframe)
        self.assertEqual(self.get_traceback(),
                         tracemalloc.Traceback(frames))

    def test_untrack(self):
        tracemalloc.start()

        self.track()
        self.assertIsNotNone(self.get_traceback())
        self.assertEqual(self.get_traced_memory(), self.size)

        # untrack must remove the trace
        self.untrack()
        self.assertIsNone(self.get_traceback())
        self.assertEqual(self.get_traced_memory(), 0)

        # calling _PyTraceMalloc_Untrack() multiple times must not crash
        self.untrack()
        self.untrack()

    def test_stop_track(self):
        # tracking while tracemalloc is stopped must fail and record nothing
        tracemalloc.start()
        tracemalloc.stop()

        with self.assertRaises(RuntimeError):
            self.track()
        self.assertIsNone(self.get_traceback())

    def test_stop_untrack(self):
        # untracking after tracemalloc was stopped must fail
        tracemalloc.start()
        self.track()

        tracemalloc.stop()
        with self.assertRaises(RuntimeError):
            self.untrack()
1063
1064
def test_main():
    """Run this module's test case classes via support.run_unittest()."""
    test_cases = (
        TestTracemallocEnabled,
        TestSnapshot,
        TestFilters,
        TestCommandLine,
        TestCAPI,
    )
    support.run_unittest(*test_cases)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    test_main()
1076