1"""Provides shared memory for direct access across processes.
2
3The API of this package is currently provisional. Refer to the
4documentation for details.
5"""
6
7
# Public API of this module, as exposed to ``from ... import *``.
__all__ = ["SharedMemory", "ShareableList"]
9
10
11from functools import partial
12import mmap
13import os
14import errno
15import struct
16import secrets
17import types
18
19if os.name == "nt":
20    import _winapi
21    _USE_POSIX = False
22else:
23    import _posixshmem
24    _USE_POSIX = True
25
26
27_O_CREX = os.O_CREAT | os.O_EXCL
28
29# FreeBSD (and perhaps other BSDs) limit names to 14 characters.
30_SHM_SAFE_NAME_LENGTH = 14
31
32# Shared memory block name prefix
33if _USE_POSIX:
34    _SHM_NAME_PREFIX = '/psm_'
35else:
36    _SHM_NAME_PREFIX = 'wnsm_'
37
38
def _make_filename():
    "Create a random filename for the shared memory object."
    # Each random byte becomes two hex characters, so use as many bytes as
    # fit in the room left after the prefix.
    room = _SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)
    token_bytes = room // 2
    assert token_bytes >= 2, '_SHM_NAME_PREFIX too long'
    candidate = f"{_SHM_NAME_PREFIX}{secrets.token_hex(token_bytes)}"
    assert len(candidate) <= _SHM_SAFE_NAME_LENGTH
    return candidate
47
48
class SharedMemory:
    """Creates a new shared memory block or attaches to an existing
    shared memory block.

    Every shared memory block is assigned a unique name.  This enables
    one process to create a shared memory block with a particular name
    so that a different process can attach to that same shared memory
    block using that same name.

    As a resource for sharing data across processes, shared memory blocks
    may outlive the original process that created them.  When one process
    no longer needs access to a shared memory block that might still be
    needed by other processes, the close() method should be called.
    When a shared memory block is no longer needed by any process, the
    unlink() method should be called to ensure proper cleanup.

    Arguments:
      name   -- name of the block.  May be None only when create=True, in
                which case a fresh random name is chosen.
      create -- if true, create a new block (raising FileExistsError when an
                explicitly given name is already in use); otherwise attach
                to an existing block.
      size   -- number of bytes to allocate when creating; must be > 0 in
                that case.  When attaching, the size is taken from the
                existing block instead.
      track  -- keyword-only.  On POSIX, when true (the default) the block
                name is registered with the multiprocessing resource
                tracker on construction and unregistered by unlink().  Pass
                False when some other party is responsible for the block's
                cleanup.  Ignored on Windows."""

    # Defaults; enables close() and unlink() to run without errors.
    _name = None
    _fd = -1
    _mmap = None
    _buf = None
    _flags = os.O_RDWR
    _mode = 0o600
    _prepend_leading_slash = _USE_POSIX
    _track = True

    def __init__(self, name=None, create=False, size=0, *, track=True):
        if not size >= 0:
            raise ValueError("'size' must be a positive integer")
        if create:
            self._flags = _O_CREX | os.O_RDWR
            if size == 0:
                raise ValueError("'size' must be a positive number different from zero")
        if name is None and not self._flags & os.O_EXCL:
            raise ValueError("'name' can only be None if create=True")
        self._track = track

        if _USE_POSIX:

            # POSIX Shared Memory

            if name is None:
                # Keep drawing random names until one is not already taken.
                while True:
                    name = _make_filename()
                    try:
                        self._fd = _posixshmem.shm_open(
                            name,
                            self._flags,
                            mode=self._mode
                        )
                    except FileExistsError:
                        continue
                    self._name = name
                    break
            else:
                name = "/" + name if self._prepend_leading_slash else name
                self._fd = _posixshmem.shm_open(
                    name,
                    self._flags,
                    mode=self._mode
                )
                self._name = name
            try:
                if create and size:
                    os.ftruncate(self._fd, size)
                stats = os.fstat(self._fd)
                size = stats.st_size
                self._mmap = mmap.mmap(self._fd, size)
            except (OSError, ValueError):
                # mmap raises ValueError for an empty (size 0) block.
                # Only remove the block if this call created it: a failed
                # *attach* must not destroy a block other processes may
                # still be using.  The name is not registered with the
                # resource tracker yet, so it is not unregistered either.
                if create:
                    try:
                        _posixshmem.shm_unlink(self._name)
                    except OSError:
                        # Best effort; let the original error propagate.
                        pass
                self.close()
                raise

            if self._track:
                from .resource_tracker import register
                register(self._name, "shared_memory")

        else:

            # Windows Named Shared Memory

            if create:
                while True:
                    temp_name = _make_filename() if name is None else name
                    # Create and reserve shared memory block with this name
                    # until it can be attached to by mmap.
                    h_map = _winapi.CreateFileMapping(
                        _winapi.INVALID_HANDLE_VALUE,
                        _winapi.NULL,
                        _winapi.PAGE_READWRITE,
                        (size >> 32) & 0xFFFFFFFF,
                        size & 0xFFFFFFFF,
                        temp_name
                    )
                    try:
                        last_error_code = _winapi.GetLastError()
                        if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
                            if name is not None:
                                raise FileExistsError(
                                    errno.EEXIST,
                                    os.strerror(errno.EEXIST),
                                    name,
                                    _winapi.ERROR_ALREADY_EXISTS
                                )
                            else:
                                # Random name collided; try another one.
                                continue
                        self._mmap = mmap.mmap(-1, size, tagname=temp_name)
                    finally:
                        _winapi.CloseHandle(h_map)
                    self._name = temp_name
                    break

            else:
                self._name = name
                # Dynamically determine the existing named shared memory
                # block's size which is likely a multiple of mmap.PAGESIZE.
                h_map = _winapi.OpenFileMapping(
                    _winapi.FILE_MAP_READ,
                    False,
                    name
                )
                try:
                    p_buf = _winapi.MapViewOfFile(
                        h_map,
                        _winapi.FILE_MAP_READ,
                        0,
                        0,
                        0
                    )
                finally:
                    _winapi.CloseHandle(h_map)
                try:
                    size = _winapi.VirtualQuerySize(p_buf)
                finally:
                    # The view was only needed to measure the block; unmap
                    # it so it is not leaked.  Older _winapi builds lack
                    # UnmapViewOfFile, in which case it stays mapped.
                    unmap_view = getattr(_winapi, "UnmapViewOfFile", None)
                    if unmap_view is not None:
                        unmap_view(p_buf)
                self._mmap = mmap.mmap(-1, size, tagname=name)

        self._size = size
        self._buf = memoryview(self._mmap)

    def __del__(self):
        # Finalizers must not raise; closing is best effort here.
        try:
            self.close()
        except OSError:
            pass

    def __reduce__(self):
        # Unpickling attaches (create=False) to the same named block.
        return (
            self.__class__,
            (
                self.name,
                False,
                self.size,
            ),
        )

    def __repr__(self):
        return f'{self.__class__.__name__}({self.name!r}, size={self.size})'

    @property
    def buf(self):
        "A memoryview of contents of the shared memory block."
        return self._buf

    @property
    def name(self):
        "Unique name that identifies the shared memory block."
        reported_name = self._name
        # Hide the leading slash that __init__ prepends on POSIX; guard
        # against a name that was never set (failed construction).
        if _USE_POSIX and self._prepend_leading_slash and reported_name:
            if reported_name.startswith("/"):
                reported_name = reported_name[1:]
        return reported_name

    @property
    def size(self):
        "Size in bytes."
        return self._size

    def close(self):
        """Closes access to the shared memory from this instance but does
        not destroy the shared memory block."""
        if self._buf is not None:
            self._buf.release()
            self._buf = None
        if self._mmap is not None:
            self._mmap.close()
            self._mmap = None
        if _USE_POSIX and self._fd >= 0:
            os.close(self._fd)
            self._fd = -1

    def unlink(self):
        """Requests that the underlying shared memory block be destroyed.

        In order to ensure proper cleanup of resources, unlink should be
        called once (and only once) across all processes which have access
        to the shared memory block.  On POSIX the name is also unregistered
        from the resource tracker when this instance was created with
        track=True; on Windows this method does nothing."""
        if _USE_POSIX and self._name:
            _posixshmem.shm_unlink(self._name)
            if self._track:
                from .resource_tracker import unregister
                unregister(self._name, "shared_memory")
243
244
245_encoding = "utf8"
246
247class ShareableList:
248    """Pattern for a mutable list-like object shareable via a shared
249    memory block.  It differs from the built-in list type in that these
250    lists can not change their overall length (i.e. no append, insert,
251    etc.)
252
253    Because values are packed into a memoryview as bytes, the struct
254    packing format for any storable value must require no more than 8
255    characters to describe its format."""
256
257    # The shared memory area is organized as follows:
258    # - 8 bytes: number of items (N) as a 64-bit integer
259    # - (N + 1) * 8 bytes: offsets of each element from the start of the
260    #                      data area
261    # - K bytes: the data area storing item values (with encoding and size
262    #            depending on their respective types)
263    # - N * 8 bytes: `struct` format string for each element
264    # - N bytes: index into _back_transforms_mapping for each element
265    #            (for reconstructing the corresponding Python value)
266    _types_mapping = {
267        int: "q",
268        float: "d",
269        bool: "xxxxxxx?",
270        str: "%ds",
271        bytes: "%ds",
272        None.__class__: "xxxxxx?x",
273    }
274    _alignment = 8
275    _back_transforms_mapping = {
276        0: lambda value: value,                   # int, float, bool
277        1: lambda value: value.rstrip(b'\x00').decode(_encoding),  # str
278        2: lambda value: value.rstrip(b'\x00'),   # bytes
279        3: lambda _value: None,                   # None
280    }
281
282    @staticmethod
283    def _extract_recreation_code(value):
284        """Used in concert with _back_transforms_mapping to convert values
285        into the appropriate Python objects when retrieving them from
286        the list as well as when storing them."""
287        if not isinstance(value, (str, bytes, None.__class__)):
288            return 0
289        elif isinstance(value, str):
290            return 1
291        elif isinstance(value, bytes):
292            return 2
293        else:
294            return 3  # NoneType
295
296    def __init__(self, sequence=None, *, name=None):
297        if name is None or sequence is not None:
298            sequence = sequence or ()
299            _formats = [
300                self._types_mapping[type(item)]
301                    if not isinstance(item, (str, bytes))
302                    else self._types_mapping[type(item)] % (
303                        self._alignment * (len(item) // self._alignment + 1),
304                    )
305                for item in sequence
306            ]
307            self._list_len = len(_formats)
308            assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
309            offset = 0
310            # The offsets of each list element into the shared memory's
311            # data area (0 meaning the start of the data area, not the start
312            # of the shared memory area).
313            self._allocated_offsets = [0]
314            for fmt in _formats:
315                offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1])
316                self._allocated_offsets.append(offset)
317            _recreation_codes = [
318                self._extract_recreation_code(item) for item in sequence
319            ]
320            requested_size = struct.calcsize(
321                "q" + self._format_size_metainfo +
322                "".join(_formats) +
323                self._format_packing_metainfo +
324                self._format_back_transform_codes
325            )
326
327            self.shm = SharedMemory(name, create=True, size=requested_size)
328        else:
329            self.shm = SharedMemory(name)
330
331        if sequence is not None:
332            _enc = _encoding
333            struct.pack_into(
334                "q" + self._format_size_metainfo,
335                self.shm.buf,
336                0,
337                self._list_len,
338                *(self._allocated_offsets)
339            )
340            struct.pack_into(
341                "".join(_formats),
342                self.shm.buf,
343                self._offset_data_start,
344                *(v.encode(_enc) if isinstance(v, str) else v for v in sequence)
345            )
346            struct.pack_into(
347                self._format_packing_metainfo,
348                self.shm.buf,
349                self._offset_packing_formats,
350                *(v.encode(_enc) for v in _formats)
351            )
352            struct.pack_into(
353                self._format_back_transform_codes,
354                self.shm.buf,
355                self._offset_back_transform_codes,
356                *(_recreation_codes)
357            )
358
359        else:
360            self._list_len = len(self)  # Obtains size from offset 0 in buffer.
361            self._allocated_offsets = list(
362                struct.unpack_from(
363                    self._format_size_metainfo,
364                    self.shm.buf,
365                    1 * 8
366                )
367            )
368
369    def _get_packing_format(self, position):
370        "Gets the packing format for a single value stored in the list."
371        position = position if position >= 0 else position + self._list_len
372        if (position >= self._list_len) or (self._list_len < 0):
373            raise IndexError("Requested position out of range.")
374
375        v = struct.unpack_from(
376            "8s",
377            self.shm.buf,
378            self._offset_packing_formats + position * 8
379        )[0]
380        fmt = v.rstrip(b'\x00')
381        fmt_as_str = fmt.decode(_encoding)
382
383        return fmt_as_str
384
385    def _get_back_transform(self, position):
386        "Gets the back transformation function for a single value."
387
388        if (position >= self._list_len) or (self._list_len < 0):
389            raise IndexError("Requested position out of range.")
390
391        transform_code = struct.unpack_from(
392            "b",
393            self.shm.buf,
394            self._offset_back_transform_codes + position
395        )[0]
396        transform_function = self._back_transforms_mapping[transform_code]
397
398        return transform_function
399
400    def _set_packing_format_and_transform(self, position, fmt_as_str, value):
401        """Sets the packing format and back transformation code for a
402        single value in the list at the specified position."""
403
404        if (position >= self._list_len) or (self._list_len < 0):
405            raise IndexError("Requested position out of range.")
406
407        struct.pack_into(
408            "8s",
409            self.shm.buf,
410            self._offset_packing_formats + position * 8,
411            fmt_as_str.encode(_encoding)
412        )
413
414        transform_code = self._extract_recreation_code(value)
415        struct.pack_into(
416            "b",
417            self.shm.buf,
418            self._offset_back_transform_codes + position,
419            transform_code
420        )
421
422    def __getitem__(self, position):
423        position = position if position >= 0 else position + self._list_len
424        try:
425            offset = self._offset_data_start + self._allocated_offsets[position]
426            (v,) = struct.unpack_from(
427                self._get_packing_format(position),
428                self.shm.buf,
429                offset
430            )
431        except IndexError:
432            raise IndexError("index out of range")
433
434        back_transform = self._get_back_transform(position)
435        v = back_transform(v)
436
437        return v
438
439    def __setitem__(self, position, value):
440        position = position if position >= 0 else position + self._list_len
441        try:
442            item_offset = self._allocated_offsets[position]
443            offset = self._offset_data_start + item_offset
444            current_format = self._get_packing_format(position)
445        except IndexError:
446            raise IndexError("assignment index out of range")
447
448        if not isinstance(value, (str, bytes)):
449            new_format = self._types_mapping[type(value)]
450            encoded_value = value
451        else:
452            allocated_length = self._allocated_offsets[position + 1] - item_offset
453
454            encoded_value = (value.encode(_encoding)
455                             if isinstance(value, str) else value)
456            if len(encoded_value) > allocated_length:
457                raise ValueError("bytes/str item exceeds available storage")
458            if current_format[-1] == "s":
459                new_format = current_format
460            else:
461                new_format = self._types_mapping[str] % (
462                    allocated_length,
463                )
464
465        self._set_packing_format_and_transform(
466            position,
467            new_format,
468            value
469        )
470        struct.pack_into(new_format, self.shm.buf, offset, encoded_value)
471
472    def __reduce__(self):
473        return partial(self.__class__, name=self.shm.name), ()
474
475    def __len__(self):
476        return struct.unpack_from("q", self.shm.buf, 0)[0]
477
478    def __repr__(self):
479        return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'
480
481    @property
482    def format(self):
483        "The struct packing format used by all currently stored items."
484        return "".join(
485            self._get_packing_format(i) for i in range(self._list_len)
486        )
487
488    @property
489    def _format_size_metainfo(self):
490        "The struct packing format used for the items' storage offsets."
491        return "q" * (self._list_len + 1)
492
493    @property
494    def _format_packing_metainfo(self):
495        "The struct packing format used for the items' packing formats."
496        return "8s" * self._list_len
497
498    @property
499    def _format_back_transform_codes(self):
500        "The struct packing format used for the items' back transforms."
501        return "b" * self._list_len
502
503    @property
504    def _offset_data_start(self):
505        # - 8 bytes for the list length
506        # - (N + 1) * 8 bytes for the element offsets
507        return (self._list_len + 2) * 8
508
509    @property
510    def _offset_packing_formats(self):
511        return self._offset_data_start + self._allocated_offsets[-1]
512
513    @property
514    def _offset_back_transform_codes(self):
515        return self._offset_packing_formats + self._list_len * 8
516
517    def count(self, value):
518        "L.count(value) -> integer -- return number of occurrences of value."
519
520        return sum(value == entry for entry in self)
521
522    def index(self, value):
523        """L.index(value) -> integer -- return first index of value.
524        Raises ValueError if the value is not present."""
525
526        for position, entry in enumerate(self):
527            if value == entry:
528                return position
529        else:
530            raise ValueError(f"{value!r} not in this container")
531
532    __class_getitem__ = classmethod(types.GenericAlias)
533