"""Unit tests for the memoryview object.

XXX We need more tests! Some tests are in test_bytes.
"""
5
6import unittest
7import sys
8import gc
9import weakref
10import array
11from test import test_support
12import io
13import copy
14import pickle
15
16
class AbstractMemoryTests:
    """Shared memoryview test cases, meant to be mixed into a TestCase.

    This class is not a TestCase on its own.  A concrete test class combines
    it (via a configuration subclass) with unittest.TestCase and an
    indirection mixin, and must supply:

    * ``ro_type`` / ``rw_type`` -- callables building a read-only / writable
      source object from a byte string; either may be None (or otherwise
      falsy) to skip the tests that need it;
    * ``getitem_type`` -- callable applied to a one-byte string to produce
      the value that indexing the view is expected to return;
    * ``itemsize`` and ``format`` -- the expected view attributes;
    * ``_view(obj)`` -- builds the memoryview under test from a source;
    * ``_check_contents(tp, obj, contents)`` -- asserts that the part of
      ``obj`` covered by the view equals ``tp(contents)``.
    """
    # Raw data every source object is built from.  Mixins whose view is a
    # slice override this with padded data so the view still sees b"abcdef".
    source_bytes = b"abcdef"

    @property
    def _source(self):
        """The byte string used to construct source objects."""
        return self.source_bytes

    @property
    def _types(self):
        """The configured source types (ro_type, rw_type), minus falsy ones."""
        return filter(None, [self.ro_type, self.rw_type])

    def check_getitem_with_type(self, tp):
        """Check indexing of a view over ``tp(self._source)``.

        Covers positive/negative indices, out-of-range indices (IndexError)
        and non-integer keys (TypeError), then checks that dropping the view
        restores the source's original reference count.
        """
        item = self.getitem_type
        b = tp(self._source)
        # Baseline taken before the view exists; the view must not leak a
        # reference to the source once it is released.
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        self.assertEqual(m[0], item(b"a"))
        self.assertIsInstance(m[0], bytes)
        self.assertEqual(m[5], item(b"f"))
        self.assertEqual(m[-1], item(b"f"))
        self.assertEqual(m[-6], item(b"a"))
        # Bounds checking
        self.assertRaises(IndexError, lambda: m[6])
        self.assertRaises(IndexError, lambda: m[-7])
        self.assertRaises(IndexError, lambda: m[sys.maxsize])
        self.assertRaises(IndexError, lambda: m[-sys.maxsize])
        # Type checking
        self.assertRaises(TypeError, lambda: m[None])
        self.assertRaises(TypeError, lambda: m[0.0])
        self.assertRaises(TypeError, lambda: m["a"])
        # Rebinding (rather than del) also clears the closure cell the
        # lambdas above share, so the view really is released here.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    # Indexing behaviour for every configured source type.
    def test_getitem(self):
        for tp in self._types:
            self.check_getitem_with_type(tp)

    # Iterating a view yields the same items as indexing it.
    def test_iter(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            self.assertEqual(list(m), [m[i] for i in range(len(m))])

    # repr() of a view returns a str (its exact text is not checked).
    def test_repr(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            self.assertIsInstance(m.__repr__(), str)

    # Item assignment through a view on a read-only source raises TypeError
    # and leaves the source's reference count unchanged afterwards.
    def test_setitem_readonly(self):
        if not self.ro_type:
            self.skipTest("no read-only type to test")
        b = self.ro_type(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        def setitem(value):
            m[0] = value
        self.assertRaises(TypeError, setitem, b"a")
        self.assertRaises(TypeError, setitem, 65)
        self.assertRaises(TypeError, setitem, memoryview(b"a"))
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    # Item and slice assignment through a view on a writable source: the
    # writes land in the source, bad keys/sizes are rejected, and the
    # source's reference count is restored once the view is dropped.
    def test_setitem_writable(self):
        if not self.rw_type:
            self.skipTest("no writable type to test")
        tp = self.rw_type
        b = self.rw_type(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        m[0] = tp(b"0")
        self._check_contents(tp, b, b"0bcdef")
        m[1:3] = tp(b"12")
        self._check_contents(tp, b, b"012def")
        # An empty slice accepts an empty value and changes nothing.
        m[1:1] = tp(b"")
        self._check_contents(tp, b, b"012def")
        m[:] = tp(b"abcdef")
        self._check_contents(tp, b, b"abcdef")

        # Overlapping copies of a view into itself (both directions) must
        # behave as if the right-hand side were copied first.
        m[0:3] = m[2:5]
        self._check_contents(tp, b, b"cdedef")
        m[:] = tp(b"abcdef")
        m[2:5] = m[0:3]
        self._check_contents(tp, b, b"ababcf")

        def setitem(key, value):
            m[key] = tp(value)
        # Bounds checking
        self.assertRaises(IndexError, setitem, 6, b"a")
        self.assertRaises(IndexError, setitem, -7, b"a")
        self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
        self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
        # Wrong index/slice types
        self.assertRaises(TypeError, setitem, 0.0, b"a")
        self.assertRaises(TypeError, setitem, (0,), b"a")
        self.assertRaises(TypeError, setitem, "a", b"a")
        # Trying to resize the memory object: the value length must match
        # the target length exactly.
        self.assertRaises(ValueError, setitem, 0, b"")
        self.assertRaises(ValueError, setitem, 0, b"ab")
        self.assertRaises(ValueError, setitem, slice(1,1), b"a")
        self.assertRaises(ValueError, setitem, slice(0,2), b"a")

        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    # Deleting items or slices of a view raises TypeError for every type.
    def test_delitem(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            with self.assertRaises(TypeError):
                del m[1]
            with self.assertRaises(TypeError):
                del m[1:4]

    # tobytes() returns a bytes object equal to the concatenated items.
    def test_tobytes(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            b = m.tobytes()
            # This calls self.getitem_type() on each separate byte of b"abcdef"
            expected = b"".join(
                self.getitem_type(c) for c in b"abcdef")
            self.assertEqual(b, expected)
            self.assertIsInstance(b, bytes)

    # tolist() returns the byte values as integers.
    def test_tolist(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            l = m.tolist()
            # Python 2: map() returns a list here.
            self.assertEqual(l, map(ord, b"abcdef"))

    def test_compare(self):
        # memoryviews can compare for equality with other objects
        # having the buffer interface.
        for tp in self._types:
            m = self._view(tp(self._source))
            # Compare against every configured type, including the other
            # (read-only vs. writable) one, and against shorter/different
            # contents.
            for tp_comp in self._types:
                self.assertTrue(m == tp_comp(b"abcdef"))
                self.assertFalse(m != tp_comp(b"abcdef"))
                self.assertFalse(m == tp_comp(b"abcde"))
                self.assertTrue(m != tp_comp(b"abcde"))
                self.assertFalse(m == tp_comp(b"abcde1"))
                self.assertTrue(m != tp_comp(b"abcde1"))
            self.assertTrue(m == m)
            self.assertTrue(m == m[:])
            self.assertTrue(m[0:6] == m[:])
            self.assertFalse(m[0:5] == m)

            # Comparison with objects which don't support the buffer API
            self.assertFalse(m == u"abcdef")
            self.assertTrue(m != u"abcdef")
            self.assertFalse(u"abcdef" == m)
            self.assertTrue(u"abcdef" != m)

            # Unordered comparisons are unimplemented, and therefore give
            # arbitrary results (they raise a TypeError in py3k), so they
            # are deliberately not tested here.

    def check_attributes_with_type(self, tp):
        """Check the view's format/itemsize/shape attributes for ``tp``.

        Returns the view so callers can also check ``readonly``.
        """
        m = self._view(tp(self._source))
        self.assertEqual(m.format, self.format)
        self.assertIsInstance(m.format, str)
        self.assertEqual(m.itemsize, self.itemsize)
        # Only one-dimensional, contiguous, six-item views are exercised.
        self.assertEqual(m.ndim, 1)
        self.assertEqual(m.shape, (6,))
        self.assertEqual(len(m), 6)
        self.assertEqual(m.strides, (self.itemsize,))
        self.assertEqual(m.suboffsets, None)
        return m

    # A view over the read-only type reports readonly == True.
    def test_attributes_readonly(self):
        if not self.ro_type:
            self.skipTest("no read-only type to test")
        m = self.check_attributes_with_type(self.ro_type)
        self.assertEqual(m.readonly, True)

    # A view over the writable type reports readonly == False.
    def test_attributes_writable(self):
        if not self.rw_type:
            self.skipTest("no writable type to test")
        m = self.check_attributes_with_type(self.rw_type)
        self.assertEqual(m.readonly, False)

    # Disabled: test_getbuffer decodes the view with unicode(m, "utf-8") to
    # exercise PyObject_GetBuffer(), but unicode uses the old buffer API in
    # 2.x, so it does not test what it claims to.

    #def test_getbuffer(self):
        ## Test PyObject_GetBuffer() on a memoryview object.
        #for tp in self._types:
            #b = tp(self._source)
            #oldrefcount = sys.getrefcount(b)
            #m = self._view(b)
            #oldviewrefcount = sys.getrefcount(m)
            #s = unicode(m, "utf-8")
            #self._check_contents(tp, b, s.encode("utf-8"))
            #self.assertEqual(sys.getrefcount(m), oldviewrefcount)
            #m = None
            #self.assertEqual(sys.getrefcount(b), oldrefcount)

    # A reference cycle that passes through a memoryview must be collectable
    # by the cyclic GC: after collection the weakly-referenced member of the
    # cycle must be gone.
    def test_gc(self):
        for tp in self._types:
            if not isinstance(tp, type):
                # If tp is a factory rather than a plain type, skip
                continue

            class MySource(tp):
                pass
            class MyObject:
                pass

            # Create a reference cycle through a memoryview object
            b = MySource(tp(b'abc'))
            m = self._view(b)
            o = MyObject()
            b.m = m
            b.o = o
            wr = weakref.ref(o)
            b = m = o = None
            # The cycle must be broken
            gc.collect()
            self.assertTrue(wr() is None, wr())

    def test_writable_readonly(self):
        # Issue #10451: memoryview incorrectly exposes a readonly
        # buffer as writable causing a segfault if using mmap
        tp = self.ro_type
        if tp is None:
            self.skipTest("no read-only type to test")
        b = tp(self._source)
        m = self._view(b)
        i = io.BytesIO(b'ZZZZ')
        # readinto() needs a writable buffer, so it must refuse this view.
        self.assertRaises(TypeError, i.readinto, m)
246
# Variations on source objects for the buffer: bytes-like objects, and
# (currently disabled, see below) arrays with itemsize > 1.
# NOTE: support for multi-dimensional objects is unimplemented.
250
class BaseBytesMemoryTests(AbstractMemoryTests):
    """Source configuration: bytes (read-only) and bytearray (writable).

    Items are single unsigned bytes, and indexing a view is expected to
    produce plain ``bytes`` values.
    """
    format = 'B'
    itemsize = 1
    getitem_type = bytes
    rw_type = bytearray
    ro_type = bytes
257
258# Disabled: array.array() does not support the new buffer API in 2.x
259
260#class BaseArrayMemoryTests(AbstractMemoryTests):
261    #ro_type = None
262    #rw_type = lambda self, b: array.array('i', map(ord, b))
263    #getitem_type = lambda self, b: array.array('i', map(ord, b)).tostring()
264    #itemsize = array.array('i').itemsize
265    #format = 'i'
266
267    #def test_getbuffer(self):
268        ## XXX Test should be adapted for non-byte buffers
269        #pass
270
271    #def test_tolist(self):
272        ## XXX NotImplementedError: tolist() only supports byte views
273        #pass
274
275
276# Variations on indirection levels: memoryview, slice of memoryview,
277# slice of slice of memoryview.
278# This is important to test allocation subtleties.
279
class BaseMemoryviewTests:
    """Indirection mixin: the view under test is taken directly on the source."""

    def _view(self, obj):
        # No slicing: the view spans the whole source object.
        view = memoryview(obj)
        return view

    def _check_contents(self, tp, obj, contents):
        # The entire source must equal ``contents`` converted to ``tp``.
        expected = tp(contents)
        self.assertEqual(obj, expected)
286
class BaseMemorySliceTests:
    """Indirection mixin: the view under test is one slice of a memoryview.

    The source is padded with a sentinel byte on each side ("X" and "Y"),
    and the view is the slice that excludes both sentinels.
    """
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        # Drop the leading "X" and trailing "Y" sentinels.
        return memoryview(obj)[1:7]

    def _check_contents(self, tp, obj, contents):
        # Only the interior of the source (between the sentinels) is compared.
        interior = obj[1:7]
        self.assertEqual(interior, tp(contents))

    def test_refs(self):
        # Taking (and discarding) a slice must leave the parent view's
        # reference count unchanged.
        for source_type in self._types:
            base_view = memoryview(source_type(self._source))
            refs_before = sys.getrefcount(base_view)
            base_view[1:2]
            self.assertEqual(sys.getrefcount(base_view), refs_before)
303
class BaseMemorySliceSliceTests:
    """Indirection mixin: the view under test is a slice of a slice.

    The source is padded with a sentinel byte on each side ("X" and "Y");
    the interior is reached in two slicing steps.
    """
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        # First slice drops the trailing "Y", the second the leading "X".
        without_tail = memoryview(obj)[:7]
        return without_tail[1:]

    def _check_contents(self, tp, obj, contents):
        # Only the interior of the source (between the sentinels) is compared.
        interior = obj[1:7]
        self.assertEqual(interior, tp(contents))
313
314
315# Concrete test classes
316
class BytesMemoryviewTest(unittest.TestCase, BaseMemoryviewTests,
                          BaseBytesMemoryTests):
    """Shared bytes/bytearray tests on an unsliced memoryview, plus ctor tests."""

    def test_constructor(self):
        # The source may be passed positionally or as ``object=``; any other
        # argument shape is a TypeError.
        for source_type in self._types:
            source = source_type(self._source)
            self.assertTrue(memoryview(source))
            self.assertTrue(memoryview(object=source))
            with self.assertRaises(TypeError):
                memoryview()
            with self.assertRaises(TypeError):
                memoryview(source, source)
            with self.assertRaises(TypeError):
                memoryview(argument=source)
            with self.assertRaises(TypeError):
                memoryview(source, argument=True)
329
330#class ArrayMemoryviewTest(unittest.TestCase,
331    #BaseMemoryviewTests, BaseArrayMemoryTests):
332
333    #def test_array_assign(self):
334        ## Issue #4569: segfault when mutating a memoryview with itemsize != 1
335        #a = array.array('i', range(10))
336        #m = memoryview(a)
337        #new_a = array.array('i', range(9, -1, -1))
338        #m[:] = new_a
339        #self.assertEqual(a, new_a)
340
341
class BytesMemorySliceTest(unittest.TestCase, BaseMemorySliceTests,
                           BaseBytesMemoryTests):
    """Shared bytes/bytearray tests on a single slice of a memoryview."""
345
346#class ArrayMemorySliceTest(unittest.TestCase,
347    #BaseMemorySliceTests, BaseArrayMemoryTests):
348    #pass
349
class BytesMemorySliceSliceTest(unittest.TestCase, BaseMemorySliceSliceTests,
                                BaseBytesMemoryTests):
    """Shared bytes/bytearray tests on a slice of a slice of a memoryview."""
353
354#class ArrayMemorySliceSliceTest(unittest.TestCase,
355    #BaseMemorySliceSliceTests, BaseArrayMemoryTests):
356    #pass
357
358
class OtherTest(unittest.TestCase):
    """memoryview behaviours not tied to a particular source/indirection."""

    def test_copy(self):
        # Shallow-copying a memoryview must fail with TypeError.
        view = memoryview(b'abc')
        self.assertRaises(TypeError, copy.copy, view)

    # Disabled; see issue #22995
    ## def test_pickle(self):
    ##     m = memoryview(b'abc')
    ##     for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    ##         with self.assertRaises(TypeError):
    ##             pickle.dumps(m, proto)
371
372
373def test_main():
374    test_support.run_unittest(__name__)
375
376if __name__ == "__main__":
377    test_main()
378