1import unittest
2from modulegraph import modulegraph
3import pkg_resources
4import os
5import imp
6import sys
7import shutil
8import warnings
9from altgraph import Graph
10import textwrap
11import xml.etree.ElementTree as ET
12import pickle
13
try:
    # Probe for a builtin ``bytes`` name; merely referencing it raises
    # NameError on interpreters that do not provide it.
    bytes
except NameError:
    # Fallback for interpreters without ``bytes`` (presumably very old
    # Python 2 releases): treat byte strings as plain ``str``.
    bytes = str
18
19try:
20    from StringIO import StringIO
21except ImportError:
22    from io import StringIO
23
# Absolute path of the "testdata/nspkg" fixture directory that lives next
# to this test file.
TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata", "nspkg")
27
try:
    expectedFailure = unittest.expectedFailure
except AttributeError:
    # Fallback for unittest versions without ``expectedFailure``: run the
    # test, treat an AssertionError as the expected outcome, and report an
    # "unexpected pass" when the test body completes without failing.
    import functools

    def expectedFailure(function):
        """Decorate a TestCase method whose assertions are expected to fail."""
        @functools.wraps(function)
        def wrapper(self, *args, **kwds):
            # ``self`` is taken explicitly so that ``self.fail`` is resolvable
            # (it previously referenced an undefined name).
            try:
                function(self, *args, **kwds)
            except AssertionError:
                pass

            else:
                self.fail("unexpected pass")

        # Returning the wrapper is essential: without it the decorated test
        # method would be replaced by None.
        return wrapper
42
class TestDependencyInfo (unittest.TestCase):
    def test_pickling(self):
        # A DependencyInfo must survive a round trip through every pickle
        # protocol the running interpreter offers.
        original = modulegraph.DependencyInfo(function=True, conditional=False, tryexcept=True, fromlist=False)
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            payload = pickle.dumps(original, protocol)
            self.assertTrue(isinstance(payload, bytes))

            self.assertEqual(pickle.loads(payload), original)

    def test_merging(self):
        DI = modulegraph.DependencyInfo
        # Each entry is (left, right, expected value of left._merged(right)).
        cases = [
            (DI(function=True, conditional=False, tryexcept=True, fromlist=False),
             DI(function=False, conditional=True, tryexcept=True, fromlist=False),
             DI(function=True, conditional=True, tryexcept=True, fromlist=False)),

            (DI(function=True, conditional=False, tryexcept=True, fromlist=False),
             DI(function=False, conditional=True, tryexcept=False, fromlist=False),
             DI(function=True, conditional=True, tryexcept=True, fromlist=False)),

            (DI(function=True, conditional=False, tryexcept=True, fromlist=False),
             DI(function=False, conditional=False, tryexcept=False, fromlist=False),
             DI(function=False, conditional=False, tryexcept=False, fromlist=False)),

            (DI(function=True, conditional=False, tryexcept=True, fromlist=True),
             DI(function=False, conditional=False, tryexcept=False, fromlist=False),
             DI(function=False, conditional=False, tryexcept=False, fromlist=False)),

            (DI(function=True, conditional=False, tryexcept=True, fromlist=True),
             DI(function=False, conditional=False, tryexcept=False, fromlist=True),
             DI(function=False, conditional=False, tryexcept=False, fromlist=True)),
        ]
        for left, right, expected in cases:
            self.assertEqual(left._merged(right), expected)
74
75
class TestFunctions (unittest.TestCase):
    """Tests for the module-level helper functions of ``modulegraph``."""

    if not hasattr(unittest.TestCase, 'assertIsInstance'):
        def assertIsInstance(self, obj, types):
            self.assertTrue(isinstance(obj, types), '%r is not instance of %r'%(obj, types))

    def test_eval_str_tuple(self):
        # Well-formed tuples of string literals must evaluate exactly like
        # the builtin eval() would.
        for v in [
            '()',
            '("hello",)',
            '("hello", "world")',
            "('hello',)",
            "('hello', 'world')",
            "('hello', \"world\")",
            ]:

            self.assertEqual(modulegraph._eval_str_tuple(v), eval(v))

        # Empty input, a bare string, an unparenthesised tuple, a nested
        # tuple and mismatched quotes are all rejected with ValueError.
        self.assertRaises(ValueError, modulegraph._eval_str_tuple, "")
        self.assertRaises(ValueError, modulegraph._eval_str_tuple, "'a'")
        self.assertRaises(ValueError, modulegraph._eval_str_tuple, "'a', 'b'")
        self.assertRaises(ValueError, modulegraph._eval_str_tuple, "('a', ('b', 'c'))")
        self.assertRaises(ValueError, modulegraph._eval_str_tuple, "('a', ('b\", 'c'))")

    def test_namespace_package_path(self):
        # Minimal stand-in for a pkg_resources distribution: only the
        # location and the namespace_packages.txt metadata are modelled.
        class DS (object):
            def __init__(self, path, namespace_packages=None):
                self.location = path
                self._namespace_packages = namespace_packages

            def has_metadata(self, key):
                if key == 'namespace_packages.txt':
                    return self._namespace_packages is not None

                raise ValueError("invalid lookup key")

            def get_metadata(self, key):
                if key == 'namespace_packages.txt':
                    if self._namespace_packages is None:
                        raise ValueError("no file")

                    return self._namespace_packages

                raise ValueError("invalid lookup key")

        # Stand-in for pkg_resources.WorkingSet yielding a fixed set of
        # distributions.
        class WS (object):
            def __init__(self, path=None):
                pass

            def __iter__(self):
                yield DS("/pkg/pkg1")
                yield DS("/pkg/pkg2", "foo\n")
                yield DS("/pkg/pkg3", "bar.baz\n")
                yield DS("/pkg/pkg4", "foobar\nfoo\n")

        saved_ws = pkg_resources.WorkingSet
        try:
            pkg_resources.WorkingSet = WS

            self.assertEqual(modulegraph._namespace_package_path("sys", ["appdir/pkg"]), ["appdir/pkg"])
            self.assertEqual(modulegraph._namespace_package_path("foo", ["appdir/pkg"]), ["appdir/pkg", "/pkg/pkg2/foo", "/pkg/pkg4/foo"])
            self.assertEqual(modulegraph._namespace_package_path("bar.baz", ["appdir/pkg"]), ["appdir/pkg", "/pkg/pkg3/bar/baz"])

        finally:
            pkg_resources.WorkingSet = saved_ws

    def test_os_listdir(self):
        root = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), 'testdata')

        self.assertEqual(modulegraph.os_listdir('/etc/'), os.listdir('/etc'))
        self.assertRaises(IOError, modulegraph.os_listdir, '/etc/hosts/foobar')
        self.assertRaises(IOError, modulegraph.os_listdir, os.path.join(root, 'test.egg', 'bar'))

        self.assertEqual(list(sorted(modulegraph.os_listdir(os.path.join(root, 'test.egg', 'foo')))),
            [ 'bar', 'bar.txt', 'baz.txt' ])

    def test_code_to_file(self):
        # Python 3 spells the attribute __code__, Python 2 func_code.
        try:
            code = modulegraph._code_to_file.__code__
        except AttributeError:
            code = modulegraph._code_to_file.func_code

        data = modulegraph._code_to_file(code)
        self.assertTrue(hasattr(data, 'read'))

        content = data.read()
        self.assertIsInstance(content, bytes)
        data.close()

    def test_find_module(self):
        for path in ('syspath', 'syspath.zip', 'syspath.egg'):
            path = os.path.join(os.path.dirname(TESTDATA), path)
            # Delete mymodule.pyc, if present, before probing for the plain module.
            if os.path.exists(os.path.join(path, 'mymodule.pyc')):
                os.unlink(os.path.join(path, 'mymodule.pyc'))

            # Plain module
            info = modulegraph.find_module('mymodule', path=[path] + sys.path)

            fp = info[0]
            filename = info[1]
            description = info[2]

            self.assertTrue(hasattr(fp, 'read'))

            if path.endswith('.zip') or path.endswith('.egg'):
                # Zip importers may precompile
                if filename.endswith('.py'):
                    self.assertEqual(filename, os.path.join(path, 'mymodule.py'))
                    self.assertEqual(description, ('.py', 'rU', imp.PY_SOURCE))

                else:
                    self.assertEqual(filename, os.path.join(path, 'mymodule.pyc'))
                    self.assertEqual(description, ('.pyc', 'rb', imp.PY_COMPILED))

            else:
                self.assertEqual(filename, os.path.join(path, 'mymodule.py'))
                self.assertEqual(description, ('.py', 'rU', imp.PY_SOURCE))

            # Release the handle opened for the plain module.
            fp.close()

            # Compiled plain module, no source
            if path.endswith('.zip') or path.endswith('.egg'):
                self.assertRaises(ImportError, modulegraph.find_module, 'mymodule2', path=[path] + sys.path)

            else:
                info = modulegraph.find_module('mymodule2', path=[path] + sys.path)

                fp = info[0]
                filename = info[1]
                description = info[2]

                self.assertTrue(hasattr(fp, 'read'))
                self.assertEqual(filename, os.path.join(path, 'mymodule2.pyc'))
                self.assertEqual(description, ('.pyc', 'rb', imp.PY_COMPILED))

                fp.close()

            # Compiled plain module, with source
#            info = modulegraph.find_module('mymodule3', path=[path] + sys.path)
#
#            fp = info[0]
#            filename = info[1]
#            description = info[2]
#
#            self.assertTrue(hasattr(fp, 'read'))
#
#            if sys.version_info[:2] >= (3,2):
#                self.assertEqual(filename, os.path.join(path, '__pycache__', 'mymodule3.cpython-32.pyc'))
#            else:
#                self.assertEqual(filename, os.path.join(path, 'mymodule3.pyc'))
#            self.assertEqual(description, ('.pyc', 'rb', imp.PY_COMPILED))


            # Package
            info = modulegraph.find_module('mypkg', path=[path] + sys.path)
            fp = info[0]
            filename = info[1]
            description = info[2]

            self.assertEqual(fp, None)
            self.assertEqual(filename, os.path.join(path, 'mypkg'))
            self.assertEqual(description, ('', '', imp.PKG_DIRECTORY))

            # Extension
            if path.endswith('.zip'):
                self.assertRaises(ImportError, modulegraph.find_module, 'myext', path=[path] + sys.path)

            else:
                info = modulegraph.find_module('myext', path=[path] + sys.path)
                fp = info[0]
                filename = info[1]
                description = info[2]

                if sys.platform == 'win32':
                    ext = '.pyd'
                else:
                    # This is a lie (extension suffixes differ between
                    # platforms), but is good enough for now
                    ext = '.so'

                self.assertEqual(filename, os.path.join(path, 'myext' + ext))
                self.assertEqual(description, (ext, 'rb', imp.C_EXTENSION))
                self.assertEqual(fp, None)

    def test_moduleInfoForPath(self):
        self.assertEqual(modulegraph.moduleInfoForPath("/somewhere/else/file.txt"), None)

        info = modulegraph.moduleInfoForPath("/somewhere/else/file.py")
        self.assertEqual(info[0], "file")
        if sys.version_info[:2] >= (3,4):
            self.assertEqual(info[1], "r")
        else:
            self.assertEqual(info[1], "U")
        self.assertEqual(info[2], imp.PY_SOURCE)

        info = modulegraph.moduleInfoForPath("/somewhere/else/file.pyc")
        self.assertEqual(info[0], "file")
        self.assertEqual(info[1], "rb")
        self.assertEqual(info[2], imp.PY_COMPILED)

        # Python 3.3+ reports plain 'linux' (older releases 'linux2'), so
        # match Linux by prefix instead of silently skipping the check.
        if sys.platform == 'darwin' or sys.platform.startswith('linux'):
            info = modulegraph.moduleInfoForPath("/somewhere/else/file.so")
            self.assertEqual(info[0], "file")
            self.assertEqual(info[1], "rb")
            self.assertEqual(info[2], imp.C_EXTENSION)

        elif sys.platform in ('win32',):
            info = modulegraph.moduleInfoForPath("/somewhere/else/file.pyd")
            self.assertEqual(info[0], "file")
            self.assertEqual(info[1], "rb")
            self.assertEqual(info[2], imp.C_EXTENSION)

    if sys.version_info[:2] > (2,5):
        # Presumably defined through exec() so that the 'with' statement
        # does not break compilation on Python 2.5.  exec() takes
        # (globals, locals): the module globals make 'modulegraph' and
        # 'warnings' resolvable inside the function, and binding the def
        # into the class-body locals() turns it into a method of this
        # TestCase.  With the two arguments swapped the def landed in the
        # module namespace and unittest never ran it.
        exec(textwrap.dedent('''\
            def test_deprecated(self):
                saved_add = modulegraph.addPackagePath
                saved_replace = modulegraph.replacePackage
                try:
                    called = []

                    def log_add(*args, **kwds):
                        called.append(('add', args, kwds))
                    def log_replace(*args, **kwds):
                        called.append(('replace', args, kwds))

                    modulegraph.addPackagePath = log_add
                    modulegraph.replacePackage = log_replace

                    with warnings.catch_warnings(record=True) as w:
                        warnings.simplefilter("always")
                        modulegraph.ReplacePackage('a', 'b')
                        modulegraph.AddPackagePath('c', 'd')

                    self.assertEqual(len(w), 2)
                    self.assertTrue(w[-1].category is DeprecationWarning)
                    self.assertTrue(w[-2].category is DeprecationWarning)

                    self.assertEqual(called, [
                        ('replace', ('a', 'b'), {}),
                        ('add', ('c', 'd'), {}),
                    ])

                finally:
                    modulegraph.addPackagePath = saved_add
                    modulegraph.replacePackage = saved_replace
            '''), globals(), locals())

    def test_addPackage(self):
        # Work on an empty map and restore the real one afterwards.
        saved = modulegraph._packagePathMap
        self.assertIsInstance(saved, dict)
        try:
            modulegraph._packagePathMap = {}

            modulegraph.addPackagePath('foo', 'a')
            self.assertEqual(modulegraph._packagePathMap, { 'foo': ['a'] })

            modulegraph.addPackagePath('foo', 'b')
            self.assertEqual(modulegraph._packagePathMap, { 'foo': ['a', 'b'] })

            modulegraph.addPackagePath('bar', 'b')
            self.assertEqual(modulegraph._packagePathMap, { 'foo': ['a', 'b'], 'bar': ['b'] })

        finally:
            modulegraph._packagePathMap = saved


    def test_replacePackage(self):
        # Work on an empty map and restore the real one afterwards.
        saved = modulegraph._replacePackageMap
        self.assertIsInstance(saved, dict)
        try:
            modulegraph._replacePackageMap = {}

            modulegraph.replacePackage("a", "b")
            self.assertEqual(modulegraph._replacePackageMap, {"a": "b"})
            modulegraph.replacePackage("a", "c")
            self.assertEqual(modulegraph._replacePackageMap, {"a": "c"})
            modulegraph.replacePackage("b", "c")
            self.assertEqual(modulegraph._replacePackageMap, {"a": "c", 'b': 'c'})

        finally:
            modulegraph._replacePackageMap = saved
354
class TestNode (unittest.TestCase):
    """Tests for ``modulegraph.Node`` and its subclasses."""

    if not hasattr(unittest.TestCase, 'assertIsInstance'):
        def assertIsInstance(self, obj, types):
            self.assertTrue(isinstance(obj, types), '%r is not instance of %r'%(obj, types))

    # Keys the interpreter itself may place in a class ``__dict__``; they are
    # ignored when checking which attributes a class body defines.
    _IMPLICIT_CLASS_KEYS = (
        '__doc__',
        '__module__',
        '__qualname__',           # New in Python 3.3
        '__dict__',               # New in Python 3.4
        '__firstlineno__',        # New in Python 3.13
        '__static_attributes__',  # New in Python 3.13
    )

    def testBasicAttributes(self):
        n = modulegraph.Node("foobar.xyz")
        self.assertIsInstance(n.debug, int)
        self.assertEqual(n.identifier, n.graphident)
        self.assertEqual(n.identifier, 'foobar.xyz')
        self.assertEqual(n.filename, None)
        self.assertEqual(n.packagepath, None)
        self.assertEqual(n.code, None)
        self.assertEqual(n.globalnames, set())
        self.assertEqual(n.starimports, set())

    def testMapping(self):
        n = modulegraph.Node("foobar.xyz")
        self.assertEqual(n._namespace, {})

        self.assertFalse('foo' in n)
        self.assertRaises(KeyError, n.__getitem__, 'foo')
        self.assertEqual(n.get('foo'), None)
        self.assertEqual(n.get('foo', 'a'), 'a')
        n['foo'] = 42
        self.assertEqual(n['foo'], 42)
        self.assertTrue('foo' in n)
        self.assertEqual(n._namespace, {'foo':42})

    def testOrder(self):
        n1 = modulegraph.Node("n1")
        n2 = modulegraph.Node("n2")

        self.assertTrue(n1 < n2)
        self.assertFalse(n2 < n1)
        self.assertTrue(n1 <= n1)
        self.assertFalse(n1 == n2)
        self.assertTrue(n1 == n1)
        self.assertTrue(n1 != n2)
        self.assertFalse(n1 != n1)
        self.assertTrue(n2 > n1)
        self.assertFalse(n1 > n2)
        self.assertTrue(n1 >= n1)
        self.assertTrue(n2 >= n1)

    def testHashing(self):
        # Distinct Node objects with the same name must hit the same dict slot.
        n1a = modulegraph.Node('n1')
        n1b = modulegraph.Node('n1')
        n2 = modulegraph.Node('n2')

        d = {}
        d[n1a] = 'n1'
        d[n2] = 'n2'
        self.assertEqual(d[n1b], 'n1')
        self.assertEqual(d[n2], 'n2')

    def test_infoTuple(self):
        n = modulegraph.Node('n1')
        self.assertEqual(n.infoTuple(), ('n1',))

    def _defined_attributes(self, klass):
        """Return a copy of ``klass.__dict__`` minus interpreter-supplied keys."""
        d = dict(klass.__dict__)
        for key in self._IMPLICIT_CLASS_KEYS:
            d.pop(key, None)
        return d

    def assertNoMethods(self, klass):
        """Assert that the body of *klass* defines no attributes of its own."""
        self.assertEqual(self._defined_attributes(klass), {})

    def assertHasExactMethods(self, klass, *methods):
        """Assert that the body of *klass* defines exactly the names in *methods*."""
        d = self._defined_attributes(klass)

        for nm in methods:
            self.assertTrue(nm in d, "%s doesn't have attribute %r"%(klass, nm))
            del d[nm]

        self.assertEqual(d, {})


    if not hasattr(unittest.TestCase, 'assertIsSubclass'):
        def assertIsSubclass(self, cls1, cls2, message=None):
            self.assertTrue(issubclass(cls1, cls2),
                    message or "%r is not a subclass of %r"%(cls1, cls2))

    def test_subclasses(self):
        self.assertIsSubclass(modulegraph.AliasNode, modulegraph.Node)
        self.assertIsSubclass(modulegraph.Script, modulegraph.Node)
        self.assertIsSubclass(modulegraph.BadModule, modulegraph.Node)
        self.assertIsSubclass(modulegraph.ExcludedModule, modulegraph.BadModule)
        self.assertIsSubclass(modulegraph.MissingModule, modulegraph.BadModule)
        self.assertIsSubclass(modulegraph.BaseModule, modulegraph.Node)
        self.assertIsSubclass(modulegraph.BuiltinModule, modulegraph.BaseModule)
        self.assertIsSubclass(modulegraph.SourceModule, modulegraph.BaseModule)
        self.assertIsSubclass(modulegraph.CompiledModule, modulegraph.BaseModule)
        self.assertIsSubclass(modulegraph.Package, modulegraph.BaseModule)
        self.assertIsSubclass(modulegraph.Extension, modulegraph.BaseModule)

        # These classes have no new functionality, check that no code
        # got added:
        self.assertNoMethods(modulegraph.BadModule)
        self.assertNoMethods(modulegraph.ExcludedModule)
        self.assertNoMethods(modulegraph.MissingModule)
        self.assertNoMethods(modulegraph.BuiltinModule)
        self.assertNoMethods(modulegraph.SourceModule)
        self.assertNoMethods(modulegraph.CompiledModule)
        self.assertNoMethods(modulegraph.Package)
        self.assertNoMethods(modulegraph.Extension)

        # AliasNode is basically a clone of an existing node.  This check
        # previously inspected Script by mistake (Script is checked below).
        self.assertHasExactMethods(modulegraph.AliasNode, '__init__', 'infoTuple')
        n1 = modulegraph.Node('n1')
        n1.packagepath = ['a', 'b']

        a1 = modulegraph.AliasNode('a1', n1)
        self.assertEqual(a1.graphident, 'a1')
        self.assertEqual(a1.identifier, 'n1')
        self.assertTrue(a1.packagepath is n1.packagepath)
        self.assertTrue(a1._namespace is n1._namespace)
        self.assertTrue(a1.globalnames is n1.globalnames)
        self.assertTrue(a1.starimports is n1.starimports)

        v = a1.infoTuple()
        self.assertEqual(v, ('a1', 'n1'))

        # Scripts have a filename
        self.assertHasExactMethods(modulegraph.Script, '__init__', 'infoTuple')
        s1 = modulegraph.Script('do_import')
        self.assertEqual(s1.graphident, 'do_import')
        self.assertEqual(s1.identifier, 'do_import')
        self.assertEqual(s1.filename, 'do_import')

        v = s1.infoTuple()
        self.assertEqual(v, ('do_import',))

        # BaseModule adds some attributes and a custom infotuple
        self.assertHasExactMethods(modulegraph.BaseModule, '__init__', 'infoTuple')
        m1 = modulegraph.BaseModule('foo')
        self.assertEqual(m1.graphident, 'foo')
        self.assertEqual(m1.identifier, 'foo')
        self.assertEqual(m1.filename, None)
        self.assertEqual(m1.packagepath, None)

        m1 = modulegraph.BaseModule('foo', 'bar',  ['a'])
        self.assertEqual(m1.graphident, 'foo')
        self.assertEqual(m1.identifier, 'foo')
        self.assertEqual(m1.filename, 'bar')
        self.assertEqual(m1.packagepath, ['a'])
512
513class TestModuleGraph (unittest.TestCase):
514    # Test for class modulegraph.modulegraph.ModuleGraph
515    if not hasattr(unittest.TestCase, 'assertIsInstance'):
516        def assertIsInstance(self, obj, types):
517            self.assertTrue(isinstance(obj, types), '%r is not instance of %r'%(obj, types))
518
519    def test_constructor(self):
520        o = modulegraph.ModuleGraph()
521        self.assertTrue(o.path is sys.path)
522        self.assertEqual(o.lazynodes, {})
523        self.assertEqual(o.replace_paths, ())
524        self.assertEqual(o.debug, 0)
525
526        # Stricter tests would be nice, but that requires
527        # better control over what's on sys.path
528        self.assertIsInstance(o.nspackages, dict)
529
530        g = Graph.Graph()
531        o = modulegraph.ModuleGraph(['a', 'b', 'c'], ['modA'], [
532                ('fromA', 'toB'), ('fromC', 'toD')],
533                {
534                    'modA': ['modB', 'modC'],
535                    'modC': ['modE', 'modF'],
536                }, g, 1)
537        self.assertEqual(o.path, ['a', 'b', 'c'])
538        self.assertEqual(o.lazynodes, {
539            'modA': None,
540            'modC': ['modE', 'modF'],
541        })
542        self.assertEqual(o.replace_paths, [('fromA', 'toB'), ('fromC', 'toD')])
543        self.assertEqual(o.nspackages, {})
544        self.assertTrue(o.graph is g)
545        self.assertEqual(o.debug, 1)
546
547    def test_calc_setuptools_nspackages(self):
548        stdlib = [ fn for fn in sys.path if fn.startswith(sys.prefix) and 'site-packages' not in fn ]
549        for subdir in [ nm for nm in os.listdir(TESTDATA) if nm != 'src' ]:
550            graph = modulegraph.ModuleGraph(path=[
551                    os.path.join(TESTDATA, subdir, "parent"),
552                    os.path.join(TESTDATA, subdir, "child"),
553                ] + stdlib)
554
555            pkgs = graph.nspackages
556            self.assertTrue('namedpkg' in pkgs)
557            self.assertEqual(set(pkgs['namedpkg']),
558                    set([
559                        os.path.join(TESTDATA, subdir, "parent", "namedpkg"),
560                        os.path.join(TESTDATA, subdir, "child", "namedpkg"),
561                    ]))
562            self.assertFalse(os.path.exists(os.path.join(TESTDATA, subdir, "parent", "namedpkg", "__init__.py")))
563            self.assertFalse(os.path.exists(os.path.join(TESTDATA, subdir, "child", "namedpkg", "__init__.py")))
564
565    def testImpliedReference(self):
566        graph = modulegraph.ModuleGraph()
567
568        record = []
569        def import_hook(*args):
570            record.append(('import_hook',) + args)
571            return [graph.createNode(modulegraph.Node, args[0])]
572
573        def _safe_import_hook(*args):
574            record.append(('_safe_import_hook',) + args)
575            return [graph.createNode(modulegraph.Node, args[0])]
576
577        graph.import_hook = import_hook
578        graph._safe_import_hook = _safe_import_hook
579
580        n1 = graph.createNode(modulegraph.Node, 'n1')
581        n2 = graph.createNode(modulegraph.Node, 'n2')
582
583        graph.implyNodeReference(n1, n2)
584        outs, ins = map(list, graph.get_edges(n1))
585        self.assertEqual(outs, [n2])
586        self.assertEqual(ins, [])
587
588        self.assertEqual(record, [])
589
590        graph.implyNodeReference(n2, "n3")
591        n3 = graph.findNode('n3')
592        outs, ins = map(list, graph.get_edges(n2))
593        self.assertEqual(outs, [n3])
594        self.assertEqual(ins, [n1])
595        self.assertEqual(record, [
596            ('_safe_import_hook', 'n3', n2, None)
597        ])
598
599
600
601    @expectedFailure
602    def test_findNode(self):
603        self.fail("findNode")
604
605    def test_run_script(self):
606        script = os.path.join(os.path.dirname(TESTDATA), 'script')
607
608        graph = modulegraph.ModuleGraph()
609        master = graph.createNode(modulegraph.Node, 'root')
610        m = graph.run_script(script, master)
611        self.assertEqual(list(graph.get_edges(master)[0])[0], m)
612        self.assertEqual(set(graph.get_edges(m)[0]), set([
613            graph.findNode('sys'),
614            graph.findNode('os'),
615        ]))
616
617    @expectedFailure
618    def test_import_hook(self):
619        self.fail("import_hook")
620
621    def test_determine_parent(self):
622        graph = modulegraph.ModuleGraph()
623        graph.import_hook('os.path', None)
624        graph.import_hook('idlelib', None)
625        graph.import_hook('xml.dom', None)
626
627        for node in graph.nodes():
628            if isinstance(node, modulegraph.Package):
629                break
630        else:
631            self.fail("No package located, should have at least 'os'")
632
633        self.assertIsInstance(node, modulegraph.Package)
634        parent = graph._determine_parent(node)
635        self.assertEqual(parent.identifier, node.identifier)
636        self.assertEqual(parent, graph.findNode(node.identifier))
637        self.assertTrue(isinstance(parent, modulegraph.Package))
638
639        # XXX: Might be a usecase for some odd code in determine_parent...
640        #node = modulegraph.Package('encodings')
641        #node.packagepath = parent.packagepath
642        #m = graph._determine_parent(node)
643        #self.assertTrue(m is parent)
644
645        m = graph.findNode('xml')
646        self.assertEqual(graph._determine_parent(m), m)
647
648        m = graph.findNode('xml.dom')
649        self.assertEqual(graph._determine_parent(m), graph.findNode('xml.dom'))
650
651
652    @expectedFailure
653    def test_find_head_package(self):
654        self.fail("find_head_package")
655
656    def test_load_tail(self):
657        # XXX: This test is dodgy!
658        graph = modulegraph.ModuleGraph()
659
660        record = []
661        def _import_module(partname, fqname, parent):
662            record.append((partname, fqname, parent))
663            if partname == 'raises' or '.raises.' in fqname:
664                return None
665            return modulegraph.Node(fqname)
666
667        graph._import_module = _import_module
668
669        record = []
670        root = modulegraph.Node('root')
671        m = graph._load_tail(root, '')
672        self.assertTrue(m is root)
673        self.assertEqual(record, [
674            ])
675
676        record = []
677        root = modulegraph.Node('root')
678        m = graph._load_tail(root, 'sub')
679        self.assertFalse(m is root)
680        self.assertEqual(record, [
681                ('sub', 'root.sub', root),
682            ])
683
684        record = []
685        root = modulegraph.Node('root')
686        m = graph._load_tail(root, 'sub.sub1')
687        self.assertFalse(m is root)
688        node = modulegraph.Node('root.sub')
689        self.assertEqual(record, [
690                ('sub', 'root.sub', root),
691                ('sub1', 'root.sub.sub1', node),
692            ])
693
694        record = []
695        root = modulegraph.Node('root')
696        m = graph._load_tail(root, 'sub.sub1.sub2')
697        self.assertFalse(m is root)
698        node = modulegraph.Node('root.sub')
699        node2 = modulegraph.Node('root.sub.sub1')
700        self.assertEqual(record, [
701                ('sub', 'root.sub', root),
702                ('sub1', 'root.sub.sub1', node),
703                ('sub2', 'root.sub.sub1.sub2', node2),
704            ])
705
706        n = graph._load_tail(root, 'raises')
707        self.assertIsInstance(n, modulegraph.MissingModule)
708        self.assertEqual(n.identifier, 'root.raises')
709
710        n = graph._load_tail(root, 'sub.raises')
711        self.assertIsInstance(n, modulegraph.MissingModule)
712        self.assertEqual(n.identifier, 'root.sub.raises')
713
714        n = graph._load_tail(root, 'sub.raises.sub')
715        self.assertIsInstance(n, modulegraph.MissingModule)
716        self.assertEqual(n.identifier, 'root.sub.raises.sub')
717
718
719
720    @expectedFailure
721    def test_ensure_fromlist(self):
722        # 1. basic 'from module import name, name'
723        # 2. 'from module import *'
724        # 3. from module import os
725        #    (where 'os' is not a name in 'module',
726        #     should create MissingModule node, and
727        #     should *not* refer to the global os)
728        self.fail("ensure_fromlist")
729
730    @expectedFailure
731    def test_find_all_submodules(self):
732        # 1. basic
733        # 2. no packagepath (basic module)
734        # 3. extensions, python modules
735        # 4. with/without zipfile
736        # 5. files that aren't python modules/extensions
737        self.fail("find_all_submodules")
738
739    @expectedFailure
740    def test_import_module(self):
741        self.fail("import_module")
742
743    @expectedFailure
744    def test_load_module(self):
745        self.fail("load_module")
746
747    @expectedFailure
748    def test_safe_import_hook(self):
749        self.fail("safe_import_hook")
750
751    @expectedFailure
752    def test_scan_code(self):
753        mod = modulegraph.Node('root')
754
755        graph = modulegraph.ModuleGraph()
756        code = compile('', '<test>', 'exec', 0, False)
757        graph.scan_code(code, mod)
758        self.assertEqual(list(graph.nodes()), [])
759
760        node_map = {}
761        def _safe_import(name, mod, fromlist, level):
762            if name in node_map:
763                node = node_map[name]
764            else:
765                node = modulegraph.Node(name)
766            node_map[name] = node
767            return [node]
768
769        graph = modulegraph.ModuleGraph()
770        graph._safe_import_hook = _safe_import
771
772        code = compile(textwrap.dedent('''\
773            import sys
774            import os.path
775
776            def testfunc():
777                import shutil
778            '''), '<test>', 'exec', 0, False)
779        graph.scan_code(code, mod)
780        modules = [node.identifier for node in graph.nodes()]
781        self.assertEqual(set(node_map), set(['sys', 'os.path', 'shutil']))
782
783
784        # from module import a, b, c
785        # from module import *
786        #  both:
787        #   -> with/without globals
788        #   -> with/without modules in globals (e.g,
789        #       from os import * adds dependency to os.path)
790        # from .module import a
791        # from ..module import a
792        #   -> check levels
793        # import name
794        # import a.b
795        #   -> should add dependency to a
796        # try to build case where commented out
797        # code would behave different than current code
798        # (Carbon.SomeMod contains 'import Sibling' seems
799        # to cause difference in real code)
800
801        self.fail("actual test needed")
802
803
804
805    @expectedFailure
806    def test_load_package(self):
807        self.fail("load_package")
808
809    def test_find_module(self):
810        record = []
811        def mock_finder(name, path):
812            record.append((name, path))
813            return saved_finder(name, path)
814
815        saved_finder = modulegraph.find_module
816        try:
817            modulegraph.find_module = mock_finder
818
819            graph = modulegraph.ModuleGraph()
820            m = graph._find_module('sys', None)
821            self.assertEqual(record, [])
822            self.assertEqual(m, (None, None, ("", "", imp.C_BUILTIN)))
823
824            modulegraph.find_module = saved_finder
825            xml = graph.import_hook("xml")[0]
826            self.assertEqual(xml.identifier, 'xml')
827            modulegraph.find_module = mock_finder
828
829            self.assertRaises(ImportError, graph._find_module, 'xml', None)
830
831            self.assertEqual(record, [])
832            m = graph._find_module('shutil', None)
833            self.assertEqual(record, [
834                ('shutil', graph.path),
835            ])
836            self.assertTrue(isinstance(m, tuple))
837            self.assertEqual(len(m), 3)
838            self.assertTrue(hasattr(m[0], 'read'))
839            self.assertIsInstance(m[0].read(), str)
840            srcfn = shutil.__file__
841            if srcfn.endswith('.pyc'):
842                srcfn = srcfn[:-1]
843            self.assertEqual(m[1], srcfn)
844            self.assertEqual(m[2], ('.py', 'rU', imp.PY_SOURCE))
845            m[0].close()
846
847            m2 = graph._find_module('shutil', None)
848            self.assertEqual(m[1:], m2[1:])
849            m2[0].close()
850
851
852            record[:] = []
853            m = graph._find_module('sax', xml.packagepath, xml)
854            self.assertEqual(m,
855                    (None, os.path.join(os.path.dirname(xml.filename), 'sax'),
856                    ('', '', imp.PKG_DIRECTORY)))
857            self.assertEqual(record, [
858                ('sax', xml.packagepath),
859            ])
860            if m[0] is not None: m[0].close()
861
862        finally:
863            modulegraph.find_module = saved_finder
864
865    @expectedFailure
866    def test_create_xref(self):
867        self.fail("create_xref")
868
869    @expectedFailure
870    def test_itergraphreport(self):
871        self.fail("itergraphreport")
872
873    def test_report(self):
874        graph = modulegraph.ModuleGraph()
875
876        saved_stdout = sys.stdout
877        try:
878            fp = sys.stdout = StringIO()
879            graph.report()
880            lines = fp.getvalue().splitlines()
881            fp.close()
882
883            self.assertEqual(len(lines), 3)
884            self.assertEqual(lines[0], '')
885            self.assertEqual(lines[1], 'Class           Name                      File')
886            self.assertEqual(lines[2], '-----           ----                      ----')
887
888            fp = sys.stdout = StringIO()
889            graph._safe_import_hook('os', None, ())
890            graph._safe_import_hook('sys', None, ())
891            graph._safe_import_hook('nomod', None, ())
892            graph.report()
893            lines = fp.getvalue().splitlines()
894            fp.close()
895
896            self.assertEqual(lines[0], '')
897            self.assertEqual(lines[1], 'Class           Name                      File')
898            self.assertEqual(lines[2], '-----           ----                      ----')
899            expected = []
900            for n in graph.flatten():
901                if n.filename:
902                    expected.append([type(n).__name__, n.identifier, n.filename])
903                else:
904                    expected.append([type(n).__name__, n.identifier])
905
906            expected.sort()
907            actual = [item.split() for item in lines[3:]]
908            actual.sort()
909            self.assertEqual(expected, actual)
910
911
912        finally:
913            sys.stdout = saved_stdout
914
915    def test_graphreport(self):
916
917        def my_iter(flatpackages="packages"):
918            yield "line1\n"
919            yield str(flatpackages) + "\n"
920            yield "line2\n"
921
922        graph = modulegraph.ModuleGraph()
923        graph.itergraphreport = my_iter
924
925        fp = StringIO()
926        graph.graphreport(fp)
927        self.assertEqual(fp.getvalue(), "line1\n()\nline2\n")
928
929        fp = StringIO()
930        graph.graphreport(fp, "deps")
931        self.assertEqual(fp.getvalue(), "line1\ndeps\nline2\n")
932
933        saved_stdout = sys.stdout
934        try:
935            sys.stdout = fp = StringIO()
936            graph.graphreport()
937            self.assertEqual(fp.getvalue(), "line1\n()\nline2\n")
938
939        finally:
940            sys.stdout = saved_stdout
941
942
943    def test_replace_paths_in_code(self):
944        graph = modulegraph.ModuleGraph(replace_paths=[
945                ('path1', 'path2'),
946                ('path3/path5', 'path4'),
947            ])
948
949        co = compile(textwrap.dedent("""
950        [x for x in range(4)]
951        """), "path4/index.py", 'exec', 0, 1)
952        co = graph._replace_paths_in_code(co)
953        self.assertEqual(co.co_filename, 'path4/index.py')
954
955        co = compile(textwrap.dedent("""
956        [x for x in range(4)]
957        (x for x in range(4))
958        """), "path1/index.py", 'exec', 0, 1)
959        self.assertEqual(co.co_filename, 'path1/index.py')
960        co = graph._replace_paths_in_code(co)
961        self.assertEqual(co.co_filename, 'path2/index.py')
962        for c in co.co_consts:
963            if isinstance(c, type(co)):
964                self.assertEqual(c.co_filename, 'path2/index.py')
965
966        co = compile(textwrap.dedent("""
967        [x for x in range(4)]
968        """), "path3/path4/index.py", 'exec', 0, 1)
969        co = graph._replace_paths_in_code(co)
970        self.assertEqual(co.co_filename, 'path3/path4/index.py')
971
972        co = compile(textwrap.dedent("""
973        [x for x in range(4)]
974        """), "path3/path5.py", 'exec', 0, 1)
975        co = graph._replace_paths_in_code(co)
976        self.assertEqual(co.co_filename, 'path3/path5.py')
977
978        co = compile(textwrap.dedent("""
979        [x for x in range(4)]
980        """), "path3/path5/index.py", 'exec', 0, 1)
981        co = graph._replace_paths_in_code(co)
982        self.assertEqual(co.co_filename, 'path4/index.py')
983
984    def test_createReference(self):
985        graph = modulegraph.ModuleGraph()
986        n1 = modulegraph.Node('n1')
987        n2 = modulegraph.Node('n2')
988        graph.addNode(n1)
989        graph.addNode(n2)
990
991        graph.createReference(n1, n2)
992        outs, ins = map(list, graph.get_edges(n1))
993        self.assertEqual(outs, [n2])
994        self.assertEqual(ins, [])
995        outs, ins = map(list, graph.get_edges(n2))
996        self.assertEqual(outs, [])
997        self.assertEqual(ins, [n1])
998
999        e = graph.graph.edge_by_node('n1', 'n2')
1000        self.assertIsInstance(e, int)
1001        self.assertEqual(graph.graph.edge_data(e), 'direct')
1002
1003    def test_create_xref(self):
1004        # XXX: This test is far from optimal, it just ensures
1005        # that all code is exercised to catch small bugs and
1006        # py3k issues without verifying that the code actually
1007        # works....
1008        graph = modulegraph.ModuleGraph()
1009        if __file__.endswith('.py'):
1010            graph.run_script(__file__)
1011        else:
1012            graph.run_script(__file__[:-1])
1013
1014        graph.import_hook('os')
1015        graph.import_hook('xml.etree')
1016        graph.import_hook('unittest')
1017
1018        fp = StringIO()
1019        graph.create_xref(out=fp)
1020
1021        data = fp.getvalue()
1022        r = ET.fromstring(data)
1023
1024    def test_itergraphreport(self):
1025        # XXX: This test is far from optimal, it just ensures
1026        # that all code is exercised to catch small bugs and
1027        # py3k issues without verifying that the code actually
1028        # works....
1029        graph = modulegraph.ModuleGraph()
1030        if __file__.endswith('.py'):
1031            graph.run_script(__file__)
1032        else:
1033            graph.run_script(__file__[:-1])
1034        graph.import_hook('os')
1035        graph.import_hook('xml.etree')
1036        graph.import_hook('unittest')
1037        graph.import_hook('distutils.command.build')
1038
1039        fp = StringIO()
1040        list(graph.itergraphreport())
1041
1042        # XXX: platpackages isn't implemented, and is undocumented hence
1043        # it is unclear what this is inteded to be...
1044        #list(graph.itergraphreport(flatpackages=...))
1045
1046
1047
1048
class CompatTests (unittest.TestCase):
    """Tests for modulegraph's Python 2/3 compatibility helpers."""

    def test_Bchr(self):
        # _Bchr(ord('A')) is expected to be the int itself on Python 3 and a
        # one-byte ``bytes`` value on Python 2.
        value = modulegraph._Bchr(ord('A'))
        if sys.version_info[0] != 2:
            self.assertIsInstance(value, int)
            self.assertEqual(value, ord('A'))
        else:
            self.assertIsInstance(value, bytes)
            self.assertEqual(value, b'A')
1058
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
1061