1#!/usr/bin/env python3
2import os
3import socket
4import sys
5import select
6from select import EPOLLIN, EPOLLPRI, EPOLLERR
7import time
8from collections import namedtuple
9import argparse
10
# Default time budget, in seconds, for MsgParser.readCmd().
TIMEOUT = 1.0

# Command names of the three messages control() expects to read from the
# server right after connecting (compared against parsed bytearrays).
VERSION_HEADER = bytearray(b'MesaOverlayControlVersion')
DEVICE_NAME_HEADER = bytearray(b'DeviceName')
MESA_VERSION_HEADER = bytearray(b'MesaVersion')

# Used when --socket is not given; the leading NUL byte places the name in
# the Linux abstract socket namespace.
DEFAULT_SERVER_ADDRESS = "\0mesa_overlay"
18
class Connection:
    '''
    Client end of a Unix-domain stream socket, polled through epoll.
    '''

    def __init__(self, path):
        '''
        Connect to `path` and register the socket with a new epoll object.

        `path` is handed directly to socket.connect(); a leading NUL byte
        selects the Linux abstract namespace. If the connection fails the
        error is printed and the process exits with status 1.
        '''
        # Create a Unix Domain socket and connect
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.connect(path)
        except socket.error as msg:
            # don't leak the descriptor on the way out
            sock.close()
            print(msg)
            sys.exit(1)

        self.sock = sock

        # initialize poll interface and register socket
        epoll = select.epoll()
        epoll.register(sock, EPOLLIN | EPOLLPRI | EPOLLERR)
        self.epoll = epoll

    def recv(self, timeout):
        '''
        Wait up to `timeout` seconds (float) for data and read one chunk
        of at most 4096 bytes.

        returns:
            - None on error or disconnection
            - bytes() (empty) on timeout
            - otherwise the bytes that were read
        '''
        for fd, event in self.epoll.poll(timeout):
            if fd != self.sock.fileno():
                continue

            # check for socket error
            if event & EPOLLERR:
                return None

            # EPOLLIN or EPOLLPRI (or a hang-up): read the message
            try:
                msg = self.sock.recv(4096)
            except OSError:
                # e.g. ECONNRESET: report it like any other error
                return None

            # an empty read means the peer closed the connection
            if not msg:
                return None

            return msg

        return bytes()

    def send(self, msg):
        '''
        Send the whole of `msg` (bytes-like).

        socket.send() may transmit only part of the buffer, so sendall()
        is used to keep writing until everything has been sent.
        '''
        self.sock.sendall(msg)

    def close(self):
        '''Close the epoll object and the socket (safe to call twice).'''
        self.epoll.close()
        self.sock.close()
67
class MsgParser:
    '''
    Incremental parser for ":<cmd>=<param>;" messages read from a
    Connection-like object (anything with a recv(timeout) method that
    returns bytes, bytes() on timeout, or None on error/disconnection).

    Parser state lives on the instance, so a message split across several
    recv() chunks is still assembled. Bytes left over in a chunk after
    readCmd() has collected enough commands are kept for the next call.
    '''
    MSGBEGIN = bytes(':', 'utf-8')[0]
    MSGEND = bytes(';', 'utf-8')[0]
    MSGSEP = bytes('=', 'utf-8')[0]

    def __init__(self, conn):
        # index of the next unread byte in self.buffer
        self.bufferpos = 0
        self.reading_cmd = False
        self.reading_param = False
        # current chunk from conn.recv(), or None once fully consumed
        self.buffer = None
        # bytes of the message being assembled; these grow on demand so
        # long commands/params cannot overflow a fixed-size buffer
        self.cmd = bytearray()
        self.param = bytearray()

        self.conn = conn

    def readCmd(self, ncmds, timeout=TIMEOUT):
        '''
        Read up to `ncmds` complete messages, spending at most about
        `timeout` seconds in total.

        returns:
            - None on error or disconnection (commands parsed during this
              call are discarded)
            - otherwise a list of (cmd, param) bytearray tuples in arrival
              order; it has fewer than `ncmds` entries (possibly none) if
              the timeout expired first
        '''
        parsed = []
        remaining = timeout

        while remaining > 0 and ncmds > 0:
            start = time.monotonic()

            if self.buffer is None:
                self.buffer = self.conn.recv(remaining)
                self.bufferpos = 0

                # disconnected or error
                if self.buffer is None:
                    return None

            buf = self.buffer
            while self.bufferpos < len(buf) and ncmds > 0:
                c = buf[self.bufferpos]
                self.bufferpos += 1
                if c == self.MSGBEGIN:
                    # a new message starts; drop any unfinished one
                    self.cmd.clear()
                    self.param.clear()
                    self.reading_cmd = True
                    self.reading_param = False
                elif c == self.MSGEND:
                    if not self.reading_cmd:
                        continue
                    self.reading_cmd = False
                    self.reading_param = False
                    # copy, since self.cmd/self.param are reused
                    parsed.append((bytearray(self.cmd), bytearray(self.param)))
                    ncmds -= 1
                elif c == self.MSGSEP:
                    if self.reading_cmd:
                        self.reading_param = True
                elif self.reading_param:
                    self.param.append(c)
                elif self.reading_cmd:
                    self.cmd.append(c)

            # Only drop the chunk once every byte has been consumed; if we
            # stopped early because enough commands were parsed, the rest
            # is kept for the next readCmd() call.
            if self.bufferpos >= len(buf):
                self.buffer = None

            # check if we have time for another iteration
            elapsed = time.monotonic() - start
            remaining = max(0, remaining - elapsed)

        return parsed
151
def control(args):
    '''
    Connect to the overlay control socket, check the server's initial
    messages and send the requested command.

    `args` is the argparse namespace built in __main__:
        - socket: abstract-namespace socket name, or None for the default
        - info:   print the protocol version, device name and Mesa version
        - cmd:    'start-capture', 'stop-capture' or None

    Exits with status 1 if the connection is lost while reading the
    initial messages or if they are missing or malformed.
    '''
    if args.socket:
        address = '\0' + args.socket
    else:
        address = DEFAULT_SERVER_ADDRESS

    conn = Connection(address)
    msgparser = MsgParser(conn)

    version = None
    name = None
    mesa_version = None

    # The version, device name and Mesa version are read as the first
    # three messages after connecting.
    msgs = msgparser.readCmd(3)
    if msgs is None:
        # readCmd() reports socket errors / disconnection as None
        print('ERROR: connection lost while reading protocol header')
        sys.exit(1)

    for cmd, param in msgs:
        if cmd == VERSION_HEADER:
            try:
                version = int(param)
            except ValueError:
                # non-numeric version: rejected by the check below
                version = None
        elif cmd == DEVICE_NAME_HEADER:
            name = param.decode('utf-8', errors='replace')
        elif cmd == MESA_VERSION_HEADER:
            mesa_version = param.decode('utf-8', errors='replace')

    if version != 1 or name is None or mesa_version is None:
        print('ERROR: invalid protocol')
        sys.exit(1)

    if args.info:
        info = "Protocol Version: {}\n"
        info += "Device Name: {}\n"
        info += "Mesa Version: {}"
        print(info.format(version, name, mesa_version))

    if args.cmd == 'start-capture':
        conn.send(bytearray(':capture=1;', 'utf-8'))
    elif args.cmd == 'stop-capture':
        conn.send(bytearray(':capture=0;', 'utf-8'))
191
if __name__ == '__main__':
    # Command-line entry point: parse options and hand them to control().
    parser = argparse.ArgumentParser(description='MESA_overlay control client')
    parser.add_argument('--info', action='store_true', help='Print info from socket')
    # control() prefixes the value with a NUL byte, so it is a name in the
    # Linux abstract socket namespace rather than a filesystem path.
    parser.add_argument('--socket', '-s', type=str,
                        help='Name of the abstract-namespace control socket '
                             '(default: mesa_overlay)')

    commands = parser.add_subparsers(help='commands to run', dest='cmd')
    commands.add_parser('start-capture')
    commands.add_parser('stop-capture')

    args = parser.parse_args()

    control(args)
204