1"""A flow graph representation for Python bytecode"""
2
3import dis
4import types
5import sys
6
7from compiler import misc
8from compiler.consts \
9     import CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS
10
11class FlowGraph:
12    def __init__(self):
13        self.current = self.entry = Block()
14        self.exit = Block("exit")
15        self.blocks = misc.Set()
16        self.blocks.add(self.entry)
17        self.blocks.add(self.exit)
18
19    def startBlock(self, block):
20        if self._debug:
21            if self.current:
22                print "end", repr(self.current)
23                print "    next", self.current.next
24                print "    prev", self.current.prev
25                print "   ", self.current.get_children()
26            print repr(block)
27        self.current = block
28
29    def nextBlock(self, block=None):
30        # XXX think we need to specify when there is implicit transfer
31        # from one block to the next.  might be better to represent this
32        # with explicit JUMP_ABSOLUTE instructions that are optimized
33        # out when they are unnecessary.
34        #
35        # I think this strategy works: each block has a child
36        # designated as "next" which is returned as the last of the
37        # children.  because the nodes in a graph are emitted in
38        # reverse post order, the "next" block will always be emitted
39        # immediately after its parent.
40        # Worry: maintaining this invariant could be tricky
41        if block is None:
42            block = self.newBlock()
43
44        # Note: If the current block ends with an unconditional control
45        # transfer, then it is techically incorrect to add an implicit
46        # transfer to the block graph. Doing so results in code generation
47        # for unreachable blocks.  That doesn't appear to be very common
48        # with Python code and since the built-in compiler doesn't optimize
49        # it out we don't either.
50        self.current.addNext(block)
51        self.startBlock(block)
52
53    def newBlock(self):
54        b = Block()
55        self.blocks.add(b)
56        return b
57
58    def startExitBlock(self):
59        self.startBlock(self.exit)
60
61    _debug = 0
62
63    def _enable_debug(self):
64        self._debug = 1
65
66    def _disable_debug(self):
67        self._debug = 0
68
69    def emit(self, *inst):
70        if self._debug:
71            print "\t", inst
72        if len(inst) == 2 and isinstance(inst[1], Block):
73            self.current.addOutEdge(inst[1])
74        self.current.emit(inst)
75
76    def getBlocksInOrder(self):
77        """Return the blocks in reverse postorder
78
79        i.e. each node appears before all of its successors
80        """
81        order = order_blocks(self.entry, self.exit)
82        return order
83
84    def getBlocks(self):
85        return self.blocks.elements()
86
87    def getRoot(self):
88        """Return nodes appropriate for use with dominator"""
89        return self.entry
90
91    def getContainedGraphs(self):
92        l = []
93        for b in self.getBlocks():
94            l.extend(b.getContainedGraphs())
95        return l
96
97
def order_blocks(start_block, exit_block):
    """Return the blocks reachable from *start_block* in emission order.

    Rules:
    - when a block has a ``next`` block, that block is emitted immediately
      after it;
    - a block is emitted before every block in its get_followers() set
      (e.g. targets of its relative jumps);
    - every block reachable from *start_block* is emitted.

    Whenever a block has no ``next`` block, is not *exit_block* and does not
    end with an unconditional transfer, *exit_block* is appended right after
    it, so execution cannot run off its end.  This can happen more than
    once, so *exit_block* may appear several times in the result.

    Raises AssertionError if the ordering constraints are circular.  This
    is an explicit raise, so the check still works under ``python -O``.
    """
    order = []

    # Collect every block reachable from start_block.
    remaining = set()
    todo = [start_block]
    while todo:
        b = todo.pop()
        if b in remaining:
            continue
        remaining.add(b)
        for c in b.get_children():
            if c not in remaining:
                todo.append(c)

    # dominators[b] is the set of blocks that must be emitted before b.
    dominators = {}
    for b in remaining:
        if __debug__ and b.next:
            assert b is b.next[0].prev[0], (b, b.next)
        # Make sure every block appears in dominators, even if no
        # other block must precede it.
        dominators.setdefault(b, set())
        # Preceding blocks dominate following blocks.
        for c in b.get_followers():
            while 1:
                dominators.setdefault(c, set()).add(b)
                # Any block whose next pointer leads to c is emitted as part
                # of the same chain, so it is dominated too.  Walk backwards
                # and add them all.
                if c.prev and c.prev[0] is not b:
                    c = c.prev[0]
                else:
                    break

    def find_next():
        # Return a remaining block none of whose dominators is still pending.
        for b in remaining:
            if dominators[b].isdisjoint(remaining):
                return b
        # Not ``assert 0``: under -O that would be stripped and None would
        # be returned, crashing the caller with an unrelated AttributeError.
        raise AssertionError('circular dependency, cannot find next block')

    b = start_block
    while 1:
        order.append(b)
        remaining.discard(b)
        if b.next:
            # The fall-through successor must come immediately after b.
            b = b.next[0]
            continue
        elif b is not exit_block and not b.has_unconditional_transfer():
            order.append(exit_block)
        if not remaining:
            break
        b = find_next()
    return order
163
164
class Block:
    """A basic block: a straight-line list of instruction tuples.

    Instructions are ``(opname,)`` or ``(opname, arg)`` tuples.  ``next``
    holds at most one block that must be emitted immediately after this one
    (implicit fall-through), and ``prev`` is the matching back-pointer.
    ``outEdges`` holds the blocks registered via addOutEdge(); FlowGraph.emit
    adds any Block used as an instruction argument.
    """
    _count = 0  # source of unique, increasing block ids

    def __init__(self, label=''):
        self.insts = []
        self.outEdges = set()
        self.label = label
        self.bid = Block._count
        self.next = []
        self.prev = []
        Block._count = Block._count + 1

    def __repr__(self):
        if self.label:
            return "<block %s id=%d>" % (self.label, self.bid)
        else:
            return "<block id=%d>" % (self.bid)

    def __str__(self):
        insts = map(str, self.insts)
        return "<block %s %d:\n%s>" % (self.label, self.bid,
                                       '\n'.join(insts))

    def emit(self, inst):
        """Append the instruction tuple *inst* to this block."""
        self.insts.append(inst)

    def getInstructions(self):
        return self.insts

    def addOutEdge(self, block):
        self.outEdges.add(block)

    def addNext(self, block):
        """Make *block* the single fall-through successor of this block.

        Asserts that neither block already has a next/prev link.
        """
        self.next.append(block)
        assert len(self.next) == 1, map(str, self.next)
        block.prev.append(self)
        assert len(block.prev) == 1, map(str, block.prev)

    _uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS',
                        'JUMP_ABSOLUTE', 'JUMP_FORWARD', 'CONTINUE_LOOP',
                        )

    def has_unconditional_transfer(self):
        """Return True if the last instruction of this block is an
        unconditional transfer of control (see _uncond_transfer), meaning
        execution can never run past the end of this block; otherwise
        return False (including for an empty block)."""
        try:
            # Only the opcode matters.  Argument-less instructions such as
            # ('RETURN_VALUE',) are 1-tuples, so they must not be unpacked
            # as (op, arg).
            op = self.insts[-1][0]
        except IndexError:
            return False
        return op in self._uncond_transfer

    def get_children(self):
        """Return the out-edge blocks followed by the next block, if any."""
        return list(self.outEdges) + self.next

    def get_followers(self):
        """Get the whole list of followers, including the next block."""
        followers = set(self.next)
        # Blocks that must be emitted *after* this one, because of
        # bytecode offsets (e.g. relative jumps) pointing to them.
        for inst in self.insts:
            if inst[0] in PyFlowGraph.hasjrel:
                followers.add(inst[1])
        return followers

    def getContainedGraphs(self):
        """Return all graphs contained within this block.

        For example, a MAKE_FUNCTION block will contain a reference to
        the graph for the function body: every instruction argument with
        a ``graph`` attribute contributes that attribute.
        """
        contained = []
        for inst in self.insts:
            if len(inst) == 1:
                continue
            arg = inst[1]
            if hasattr(arg, 'graph'):
                contained.append(arg.graph)
        return contained
245
# Stages of a PyFlowGraph.  The graph is transformed in place, and its
# ``stage`` attribute records how far the transformation has got:
#   RAW  - instructions are being emitted into blocks
#   FLAT - blocks laid out in order and jump targets resolved to offsets
#   CONV - symbolic instruction arguments converted to table indices
#   DONE - bytecode and line-number table generated
RAW, FLAT, CONV, DONE = "RAW", "FLAT", "CONV", "DONE"
253
254class PyFlowGraph(FlowGraph):
255    super_init = FlowGraph.__init__
256
257    def __init__(self, name, filename, args=(), optimized=0, klass=None):
258        self.super_init()
259        self.name = name
260        self.filename = filename
261        self.docstring = None
262        self.args = args # XXX
263        self.argcount = getArgCount(args)
264        self.klass = klass
265        if optimized:
266            self.flags = CO_OPTIMIZED | CO_NEWLOCALS
267        else:
268            self.flags = 0
269        self.consts = []
270        self.names = []
271        # Free variables found by the symbol table scan, including
272        # variables used only in nested scopes, are included here.
273        self.freevars = []
274        self.cellvars = []
275        # The closure list is used to track the order of cell
276        # variables and free variables in the resulting code object.
277        # The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both
278        # kinds of variables.
279        self.closure = []
280        self.varnames = list(args) or []
281        for i in range(len(self.varnames)):
282            var = self.varnames[i]
283            if isinstance(var, TupleArg):
284                self.varnames[i] = var.getName()
285        self.stage = RAW
286
287    def setDocstring(self, doc):
288        self.docstring = doc
289
290    def setFlag(self, flag):
291        self.flags = self.flags | flag
292        if flag == CO_VARARGS:
293            self.argcount = self.argcount - 1
294
295    def checkFlag(self, flag):
296        if self.flags & flag:
297            return 1
298
299    def setFreeVars(self, names):
300        self.freevars = list(names)
301
302    def setCellVars(self, names):
303        self.cellvars = names
304
305    def getCode(self):
306        """Get a Python code object"""
307        assert self.stage == RAW
308        self.computeStackDepth()
309        self.flattenGraph()
310        assert self.stage == FLAT
311        self.convertArgs()
312        assert self.stage == CONV
313        self.makeByteCode()
314        assert self.stage == DONE
315        return self.newCodeObject()
316
317    def dump(self, io=None):
318        if io:
319            save = sys.stdout
320            sys.stdout = io
321        pc = 0
322        for t in self.insts:
323            opname = t[0]
324            if opname == "SET_LINENO":
325                print
326            if len(t) == 1:
327                print "\t", "%3d" % pc, opname
328                pc = pc + 1
329            else:
330                print "\t", "%3d" % pc, opname, t[1]
331                pc = pc + 3
332        if io:
333            sys.stdout = save
334
335    def computeStackDepth(self):
336        """Compute the max stack depth.
337
338        Approach is to compute the stack effect of each basic block.
339        Then find the path through the code with the largest total
340        effect.
341        """
342        depth = {}
343        exit = None
344        for b in self.getBlocks():
345            depth[b] = findDepth(b.getInstructions())
346
347        seen = {}
348
349        def max_depth(b, d):
350            if b in seen:
351                return d
352            seen[b] = 1
353            d = d + depth[b]
354            children = b.get_children()
355            if children:
356                return max([max_depth(c, d) for c in children])
357            else:
358                if not b.label == "exit":
359                    return max_depth(self.exit, d)
360                else:
361                    return d
362
363        self.stacksize = max_depth(self.entry, 0)
364
365    def flattenGraph(self):
366        """Arrange the blocks in order and resolve jumps"""
367        assert self.stage == RAW
368        self.insts = insts = []
369        pc = 0
370        begin = {}
371        end = {}
372        for b in self.getBlocksInOrder():
373            begin[b] = pc
374            for inst in b.getInstructions():
375                insts.append(inst)
376                if len(inst) == 1:
377                    pc = pc + 1
378                elif inst[0] != "SET_LINENO":
379                    # arg takes 2 bytes
380                    pc = pc + 3
381            end[b] = pc
382        pc = 0
383        for i in range(len(insts)):
384            inst = insts[i]
385            if len(inst) == 1:
386                pc = pc + 1
387            elif inst[0] != "SET_LINENO":
388                pc = pc + 3
389            opname = inst[0]
390            if opname in self.hasjrel:
391                oparg = inst[1]
392                offset = begin[oparg] - pc
393                insts[i] = opname, offset
394            elif opname in self.hasjabs:
395                insts[i] = opname, begin[inst[1]]
396        self.stage = FLAT
397
398    hasjrel = set()
399    for i in dis.hasjrel:
400        hasjrel.add(dis.opname[i])
401    hasjabs = set()
402    for i in dis.hasjabs:
403        hasjabs.add(dis.opname[i])
404
405    def convertArgs(self):
406        """Convert arguments from symbolic to concrete form"""
407        assert self.stage == FLAT
408        self.consts.insert(0, self.docstring)
409        self.sort_cellvars()
410        for i in range(len(self.insts)):
411            t = self.insts[i]
412            if len(t) == 2:
413                opname, oparg = t
414                conv = self._converters.get(opname, None)
415                if conv:
416                    self.insts[i] = opname, conv(self, oparg)
417        self.stage = CONV
418
419    def sort_cellvars(self):
420        """Sort cellvars in the order of varnames and prune from freevars.
421        """
422        cells = {}
423        for name in self.cellvars:
424            cells[name] = 1
425        self.cellvars = [name for name in self.varnames
426                         if name in cells]
427        for name in self.cellvars:
428            del cells[name]
429        self.cellvars = self.cellvars + cells.keys()
430        self.closure = self.cellvars + self.freevars
431
432    def _lookupName(self, name, list):
433        """Return index of name in list, appending if necessary
434
435        This routine uses a list instead of a dictionary, because a
436        dictionary can't store two different keys if the keys have the
437        same value but different types, e.g. 2 and 2L.  The compiler
438        must treat these two separately, so it does an explicit type
439        comparison before comparing the values.
440        """
441        t = type(name)
442        for i in range(len(list)):
443            if t == type(list[i]) and list[i] == name:
444                return i
445        end = len(list)
446        list.append(name)
447        return end
448
449    _converters = {}
450    def _convert_LOAD_CONST(self, arg):
451        if hasattr(arg, 'getCode'):
452            arg = arg.getCode()
453        return self._lookupName(arg, self.consts)
454
455    def _convert_LOAD_FAST(self, arg):
456        self._lookupName(arg, self.names)
457        return self._lookupName(arg, self.varnames)
458    _convert_STORE_FAST = _convert_LOAD_FAST
459    _convert_DELETE_FAST = _convert_LOAD_FAST
460
461    def _convert_LOAD_NAME(self, arg):
462        if self.klass is None:
463            self._lookupName(arg, self.varnames)
464        return self._lookupName(arg, self.names)
465
466    def _convert_NAME(self, arg):
467        if self.klass is None:
468            self._lookupName(arg, self.varnames)
469        return self._lookupName(arg, self.names)
470    _convert_STORE_NAME = _convert_NAME
471    _convert_DELETE_NAME = _convert_NAME
472    _convert_IMPORT_NAME = _convert_NAME
473    _convert_IMPORT_FROM = _convert_NAME
474    _convert_STORE_ATTR = _convert_NAME
475    _convert_LOAD_ATTR = _convert_NAME
476    _convert_DELETE_ATTR = _convert_NAME
477    _convert_LOAD_GLOBAL = _convert_NAME
478    _convert_STORE_GLOBAL = _convert_NAME
479    _convert_DELETE_GLOBAL = _convert_NAME
480
481    def _convert_DEREF(self, arg):
482        self._lookupName(arg, self.names)
483        self._lookupName(arg, self.varnames)
484        return self._lookupName(arg, self.closure)
485    _convert_LOAD_DEREF = _convert_DEREF
486    _convert_STORE_DEREF = _convert_DEREF
487
488    def _convert_LOAD_CLOSURE(self, arg):
489        self._lookupName(arg, self.varnames)
490        return self._lookupName(arg, self.closure)
491
492    _cmp = list(dis.cmp_op)
493    def _convert_COMPARE_OP(self, arg):
494        return self._cmp.index(arg)
495
496    # similarly for other opcodes...
497
498    for name, obj in locals().items():
499        if name[:9] == "_convert_":
500            opname = name[9:]
501            _converters[opname] = obj
502    del name, obj, opname
503
504    def makeByteCode(self):
505        assert self.stage == CONV
506        self.lnotab = lnotab = LineAddrTable()
507        for t in self.insts:
508            opname = t[0]
509            if len(t) == 1:
510                lnotab.addCode(self.opnum[opname])
511            else:
512                oparg = t[1]
513                if opname == "SET_LINENO":
514                    lnotab.nextLine(oparg)
515                    continue
516                hi, lo = twobyte(oparg)
517                try:
518                    lnotab.addCode(self.opnum[opname], lo, hi)
519                except ValueError:
520                    print opname, oparg
521                    print self.opnum[opname], lo, hi
522                    raise
523        self.stage = DONE
524
525    opnum = {}
526    for num in range(len(dis.opname)):
527        opnum[dis.opname[num]] = num
528    del num
529
530    def newCodeObject(self):
531        assert self.stage == DONE
532        if (self.flags & CO_NEWLOCALS) == 0:
533            nlocals = 0
534        else:
535            nlocals = len(self.varnames)
536        argcount = self.argcount
537        if self.flags & CO_VARKEYWORDS:
538            argcount = argcount - 1
539        return types.CodeType(argcount, nlocals, self.stacksize, self.flags,
540                        self.lnotab.getCode(), self.getConsts(),
541                        tuple(self.names), tuple(self.varnames),
542                        self.filename, self.name, self.lnotab.firstline,
543                        self.lnotab.getTable(), tuple(self.freevars),
544                        tuple(self.cellvars))
545
546    def getConsts(self):
547        """Return a tuple for the const slot of the code object
548
549        Must convert references to code (MAKE_FUNCTION) to code
550        objects recursively.
551        """
552        l = []
553        for elt in self.consts:
554            if isinstance(elt, PyFlowGraph):
555                elt = elt.getCode()
556            l.append(elt)
557        return tuple(l)
558
def isJump(opname):
    """Return 1 if *opname* names a JUMP* opcode, otherwise None."""
    if not opname.startswith('JUMP'):
        return None
    return 1
562
class TupleArg:
    """Marker for a parameter written as a (possibly nested) tuple.

    *count* builds the synthetic parameter name returned by getName().
    *names* is the tuple of names the parameter unpacks into.
    """

    def __init__(self, count, names):
        self.count, self.names = count, names

    def __repr__(self):
        fields = (self.count, self.names)
        return "TupleArg(%s, %s)" % fields

    def getName(self):
        """Return the synthetic name ``".<count>"`` used in varnames."""
        return ".%d" % self.count
572
def getArgCount(args):
    """Return the positional-argument count for the parameter list *args*.

    For each TupleArg, the number of names in its flattened ``names`` is
    subtracted from len(args).  Presumably the unpacked names also appear
    in *args* and must not be counted twice; verify against the caller.
    """
    total = len(args)
    for arg in args:
        if isinstance(arg, TupleArg):
            total -= len(misc.flatten(arg.names))
    return total
581
def twobyte(val):
    """Split an int argument into (high, low) bytes, i.e. divmod(val, 256).

    No range check is made: a value above 65535 yields a high part above
    255.
    """
    assert isinstance(val, int)
    # For Python ints, >> 8 is floor division by 256 and & 0xFF is the
    # matching modulus, so this equals divmod(val, 256) for every int.
    return val >> 8, val & 0xFF
586
class LineAddrTable:
    """Accumulate bytecode together with its line-number table (lnotab).

    The lnotab format is documented in compile.c.  Summary:

    The first nextLine() call only records ``firstline``.  Each later call
    appends (address delta, line delta) byte pairs, measured from the
    previously recorded line.  A delta above 255 is split across several
    pairs.  A call that would move to an earlier line is ignored, because
    the table holds unsigned bytes.
    """

    def __init__(self):
        self.code = []        # one single-character string per code byte
        self.codeOffset = 0   # number of code bytes added so far
        self.firstline = 0    # line passed to the first nextLine() call
        self.lastline = 0     # last line recorded in the table
        self.lastoff = 0      # codeOffset at which lastline was recorded
        self.lnotab = []      # flat list of delta bytes

    def addCode(self, *args):
        """Append each int in *args* as one byte of code."""
        append = self.code.append
        for byte in args:
            append(chr(byte))
        self.codeOffset += len(args)

    def nextLine(self, lineno):
        """Note that the code emitted from now on belongs to line *lineno*."""
        if self.firstline == 0:
            self.firstline = lineno
            self.lastline = lineno
            return

        addr_delta = self.codeOffset - self.lastoff
        line_delta = lineno - self.lastline
        if line_delta < 0:
            # lnotab assumes line numbers never decrease as the bytecode
            # address grows, but SET_LINENO emission does not guarantee it.
            # For example, in
            #     a = (1,
            #          b)
            # "b" is loaded before the assignment to "a" happens.  (The C
            # compiler only emits a line number for the assignment.)
            # Such entries are dropped.
            return

        emit = self.lnotab.extend
        while addr_delta > 255:
            emit((255, 0))
            addr_delta -= 255
        while line_delta > 255:
            emit((addr_delta, 255))
            line_delta -= 255
            addr_delta = 0
        if addr_delta > 0 or line_delta > 0:
            emit((addr_delta, line_delta))
        self.lastline = lineno
        self.lastoff = self.codeOffset

    def getCode(self):
        """Return the accumulated bytecode as a string."""
        return "".join(self.code)

    def getTable(self):
        """Return the encoded lnotab as a string."""
        return "".join([chr(delta) for delta in self.lnotab])
652
653class StackDepthTracker:
654    # XXX 1. need to keep track of stack depth on jumps
655    # XXX 2. at least partly as a result, this code is broken
656
657    def findDepth(self, insts, debug=0):
658        depth = 0
659        maxDepth = 0
660        for i in insts:
661            opname = i[0]
662            if debug:
663                print i,
664            delta = self.effect.get(opname, None)
665            if delta is not None:
666                depth = depth + delta
667            else:
668                # now check patterns
669                for pat, pat_delta in self.patterns:
670                    if opname[:len(pat)] == pat:
671                        delta = pat_delta
672                        depth = depth + delta
673                        break
674                # if we still haven't found a match
675                if delta is None:
676                    meth = getattr(self, opname, None)
677                    if meth is not None:
678                        depth = depth + meth(i[1])
679            if depth > maxDepth:
680                maxDepth = depth
681            if debug:
682                print depth, maxDepth
683        return maxDepth
684
685    effect = {
686        'POP_TOP': -1,
687        'DUP_TOP': 1,
688        'LIST_APPEND': -1,
689        'SET_ADD': -1,
690        'MAP_ADD': -2,
691        'SLICE+1': -1,
692        'SLICE+2': -1,
693        'SLICE+3': -2,
694        'STORE_SLICE+0': -1,
695        'STORE_SLICE+1': -2,
696        'STORE_SLICE+2': -2,
697        'STORE_SLICE+3': -3,
698        'DELETE_SLICE+0': -1,
699        'DELETE_SLICE+1': -2,
700        'DELETE_SLICE+2': -2,
701        'DELETE_SLICE+3': -3,
702        'STORE_SUBSCR': -3,
703        'DELETE_SUBSCR': -2,
704        # PRINT_EXPR?
705        'PRINT_ITEM': -1,
706        'RETURN_VALUE': -1,
707        'YIELD_VALUE': -1,
708        'EXEC_STMT': -3,
709        'BUILD_CLASS': -2,
710        'STORE_NAME': -1,
711        'STORE_ATTR': -2,
712        'DELETE_ATTR': -1,
713        'STORE_GLOBAL': -1,
714        'BUILD_MAP': 1,
715        'COMPARE_OP': -1,
716        'STORE_FAST': -1,
717        'IMPORT_STAR': -1,
718        'IMPORT_NAME': -1,
719        'IMPORT_FROM': 1,
720        'LOAD_ATTR': 0, # unlike other loads
721        # close enough...
722        'SETUP_EXCEPT': 3,
723        'SETUP_FINALLY': 3,
724        'FOR_ITER': 1,
725        'WITH_CLEANUP': -1,
726        }
727    # use pattern match
728    patterns = [
729        ('BINARY_', -1),
730        ('LOAD_', 1),
731        ]
732
733    def UNPACK_SEQUENCE(self, count):
734        return count-1
735    def BUILD_TUPLE(self, count):
736        return -count+1
737    def BUILD_LIST(self, count):
738        return -count+1
739    def BUILD_SET(self, count):
740        return -count+1
741    def CALL_FUNCTION(self, argc):
742        hi, lo = divmod(argc, 256)
743        return -(lo + hi * 2)
744    def CALL_FUNCTION_VAR(self, argc):
745        return self.CALL_FUNCTION(argc)-1
746    def CALL_FUNCTION_KW(self, argc):
747        return self.CALL_FUNCTION(argc)-1
748    def CALL_FUNCTION_VAR_KW(self, argc):
749        return self.CALL_FUNCTION(argc)-2
750    def MAKE_FUNCTION(self, argc):
751        return -argc
752    def MAKE_CLOSURE(self, argc):
753        # XXX need to account for free variables too!
754        return -argc
755    def BUILD_SLICE(self, argc):
756        if argc == 2:
757            return -1
758        elif argc == 3:
759            return -2
760    def DUP_TOPX(self, argc):
761        return argc
762
763findDepth = StackDepthTracker().findDepth
764