1"""
2altgraph.ObjectGraph - Graph of objects with an identifier
3==========================================================
4
5A graph of objects that have a "graphident" attribute.
6graphident is the key for the object in the graph
7"""
8
9from altgraph import GraphError
10from altgraph.Graph import Graph
11from altgraph.GraphUtil import filter_stack
12
class ObjectGraph(object):
    """
    A graph of objects that have a "graphident" attribute.
    graphident is the key for the object in the graph.

    The ObjectGraph instance is itself the root node of the wrapped
    ``altgraph.Graph.Graph``: it registers itself (with ``None`` as its node
    data) on construction, and passing ``None`` as ``fromnode`` to
    ``createReference``/``removeReference`` means "the root".
    """

    def __init__(self, graph=None, debug=0):
        """
        Wrap ``graph`` (or a new empty ``Graph``) and register the root node.

        :param graph: existing ``Graph`` to use; a fresh one is created when
            ``None``.
        :param debug: verbosity threshold; ``msg``/``msgin``/``msgout`` only
            print when their level is <= ``debug``.
        """
        if graph is None:
            graph = Graph()
        # The root is its own identifier, so it can be looked up like any node.
        self.graphident = self
        self.graph = graph
        self.debug = debug
        # Nesting depth used to indent debug output (see msgin/msgout).
        self.indent = 0
        graph.add_node(self, None)

    def __repr__(self):
        return '<%s>' % (type(self).__name__,)

    def flatten(self, condition=None, start=None):
        """
        Iterate over the subgraph that is entirely reachable by condition
        starting from the given start node or the ObjectGraph root.

        :param condition: optional predicate forwarded to ``Graph.iterdata``.
        :param start: node object to start from; defaults to the root.
        :return: the iterator produced by ``Graph.iterdata``.
        """
        if start is None:
            start = self
        start = self.getRawIdent(start)
        return self.graph.iterdata(start=start, condition=condition)

    def nodes(self):
        """
        Yield the data object of every node the underlying graph iterates.

        Nodes whose data is ``None`` (such as the root, registered with
        ``None`` in ``__init__``) are skipped.
        """
        for ident in self.graph:
            node = self.graph.node_data(ident)
            if node is not None:
                # Yield the value already fetched instead of querying the
                # graph a second time for the same identifier.
                yield node

    def get_edges(self, node):
        """
        Return the neighbours of ``node`` as an ``(outgoing, incoming)`` pair.

        Both elements are lazy generators of node objects resolved through
        ``findNode``. Each neighbour is reported at most once per generator,
        even when several edges connect the same pair of nodes.
        """
        start = self.getRawIdent(node)
        _, _, outraw, incraw = self.graph.describe_node(start)

        def iter_edges(lst, n):
            # Element ``n`` of the describe_edge tuple is used as the
            # neighbour's identifier: index 3 for outgoing edges, 2 for
            # incoming ones (see the call sites below).
            seen = set()
            for tpl in (self.graph.describe_edge(e) for e in lst):
                ident = tpl[n]
                if ident not in seen:
                    yield self.findNode(ident)
                    seen.add(ident)

        return iter_edges(outraw, 3), iter_edges(incraw, 2)

    def edgeData(self, fromNode, toNode):
        """
        Return the data attached to the edge from ``fromNode`` to ``toNode``.

        The edge is located with ``Graph.edge_by_node``; if several edges
        connect the pair, whichever one that call returns is used.
        """
        start = self.getRawIdent(fromNode)
        stop = self.getRawIdent(toNode)
        edge = self.graph.edge_by_node(start, stop)
        return self.graph.edge_data(edge)

    def updateEdgeData(self, fromNode, toNode, edgeData):
        """
        Replace the data on the edge from ``fromNode`` to ``toNode``.

        The edge is located with ``Graph.edge_by_node``, as in ``edgeData``.
        """
        start = self.getRawIdent(fromNode)
        stop = self.getRawIdent(toNode)
        edge = self.graph.edge_by_node(start, stop)
        self.graph.update_edge_data(edge, edgeData)

    def filterStack(self, filters):
        """
        Filter the ObjectGraph in-place by removing all edges to nodes that
        do not match every filter in the given filter list.

        Removed nodes are hidden (not deleted). For every orphan pair
        reported by ``filter_stack``, an edge with data ``'orphan'`` is added
        to reconnect it to its last good ancestor.

        Returns a tuple containing the number of:
        (nodes_visited, nodes_removed, nodes_orphaned)
        """
        visited, removes, orphans = filter_stack(self.graph, self, filters)

        for last_good, tail in orphans:
            self.graph.add_edge(last_good, tail, edge_data='orphan')

        for node in removes:
            self.graph.hide_node(node)

        # One visited entry is excluded from the count (presumably the root,
        # which is where filter_stack starts) -- TODO confirm.
        return len(visited) - 1, len(removes), len(orphans)

    def removeNode(self, node):
        """
        Remove the given node from the graph if it exists.

        The node is hidden via ``Graph.hide_node``. When ``node`` cannot be
        resolved to an identifier, this is a no-op.
        """
        # NOTE(review): getIdent returns an object's graphident even if that
        # identifier is not (or no longer) in the graph; in that case the
        # outcome depends on Graph.hide_node (presumably it raises) -- confirm
        # against the "if it exists" contract above.
        ident = self.getIdent(node)
        if ident is not None:
            self.graph.hide_node(ident)

    def removeReference(self, fromnode, tonode):
        """
        Remove all edges from fromnode to tonode.

        ``fromnode`` of ``None`` means the root. Edges are hidden one at a
        time until ``edge_by_node`` reports none remain. Nothing happens when
        either endpoint cannot be resolved to an identifier.
        """
        if fromnode is None:
            fromnode = self
        fromident = self.getIdent(fromnode)
        toident = self.getIdent(tonode)
        if fromident is not None and toident is not None:
            while True:
                edge = self.graph.edge_by_node(fromident, toident)
                if edge is None:
                    break
                self.graph.hide_edge(edge)

    def getIdent(self, node):
        """
        Get the graph identifier for a node.

        ``node`` may be a node object (its ``graphident`` is returned
        directly) or a value that ``findNode`` can look up, in which case the
        found node's ``graphident`` is returned. Returns ``None`` if no node
        is found.
        """
        ident = self.getRawIdent(node)
        if ident is not None:
            return ident
        node = self.findNode(node)
        if node is None:
            return None
        return node.graphident

    def getRawIdent(self, node):
        """
        Get the identifier for a node object.

        Returns the root itself for the root; otherwise the object's
        ``graphident`` attribute, or ``None`` when it has none (for example
        when a plain identifier rather than a node object is passed).
        """
        if node is self:
            return node
        return getattr(node, 'graphident', None)

    def __contains__(self, node):
        """
        Return True when ``findNode`` locates ``node`` in the graph.
        """
        # NOTE(review): the root is registered with None as its data, so
        # findNode(self) returns None and ``self in self`` is presumably
        # False -- confirm this is intended.
        return self.findNode(node) is not None

    def findNode(self, node):
        """
        Find the node on the graph.

        ``node`` may be a node object (looked up by its ``graphident``) or a
        raw identifier. Returns the node data stored in the graph, or ``None``
        when the graph's ``node_data`` raises ``KeyError``.
        """
        ident = self.getRawIdent(node)
        if ident is None:
            # Not a node object: treat the argument itself as the identifier.
            ident = node
        try:
            return self.graph.node_data(ident)
        except KeyError:
            return None

    def addNode(self, node):
        """
        Add a node to the graph referenced by the root.

        First tries to restore a previously hidden node with the same
        ``graphident``. If the graph raises ``GraphError`` (presumably
        because no such hidden node exists), ``node`` is added as a new node
        whose data is ``node`` itself.
        """
        self.msg(4, "addNode", node)

        try:
            self.graph.restore_node(node.graphident)
        except GraphError:
            self.graph.add_node(node.graphident, node)

    def createReference(self, fromnode, tonode, edge_data=None):
        """
        Create a reference from fromnode to tonode.

        ``fromnode`` of ``None`` means the root. The edge carries
        ``edge_data``. Nothing happens when either endpoint cannot be
        resolved to an identifier.
        """
        if fromnode is None:
            fromnode = self
        fromident, toident = self.getIdent(fromnode), self.getIdent(tonode)
        if fromident is None or toident is None:
            return
        self.msg(4, "createReference", fromnode, tonode, edge_data)
        self.graph.add_edge(fromident, toident, edge_data=edge_data)

    def createNode(self, cls, name, *args, **kw):
        """
        Add a node of type cls to the graph if it does not already exist
        by the given name.

        A new node is built as ``cls(name, *args, **kw)`` and passed to
        ``addNode``. Returns the existing or newly created node.
        """
        m = self.findNode(name)
        if m is None:
            m = cls(name, *args, **kw)
            self.addNode(m)
        return m

    def msg(self, level, s, *args):
        """
        Print a debug message with the given level.

        Nothing is printed when ``s`` is empty or ``level`` exceeds
        ``self.debug``. The message is indented by two spaces per ``indent``
        step and followed by the ``repr`` of each extra argument.
        """
        if s and level <= self.debug:
            print("%s%s %s" % ("  " * self.indent, s, ' '.join(map(repr, args))))

    def msgin(self, level, s, *args):
        """
        Print a debug message and indent.

        The indent is only increased when ``level`` is within ``self.debug``,
        which keeps it balanced with ``msgout`` at the same level.
        """
        if level <= self.debug:
            self.msg(level, s, *args)
            self.indent = self.indent + 1

    def msgout(self, level, s, *args):
        """
        Dedent and print a debug message.

        Counterpart of ``msgin``: the indent is only decreased when ``level``
        is within ``self.debug``.
        """
        if level <= self.debug:
            self.indent = self.indent - 1
            self.msg(level, s, *args)
203