1#!/usr/bin/env python
2
3import argparse
4import code
5import sys
6import threading
7import time
8
9import six
10from six.moves.urllib.parse import urlparse
11
12import websocket
13
14try:
15    import readline
16except ImportError:
17    pass
18
19
def get_encoding():
    """Return the lower-cased encoding of ``sys.stdin``.

    Falls back to ``"utf-8"`` when stdin has no ``encoding`` attribute or
    when that attribute is empty or ``None``.
    """
    stdin_encoding = getattr(sys.stdin, "encoding", "")
    return stdin_encoding.lower() if stdin_encoding else "utf-8"
26
27
# Lower-cased encoding of the terminal's stdin (see get_encoding); RawInput
# uses it to re-encode byte-string input lines as UTF-8.
ENCODING = get_encoding()
# Opcodes of frames that carry application payload (text or binary), as
# opposed to control frames such as close/ping/pong.
OPCODE_DATA = (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY)
30
31
class VAction(argparse.Action):
    """argparse action that turns ``-v`` style flags into a verbosity level.

    * no value (bare ``-v``)          -> 1
    * an integer value (``-v 2``)     -> that integer
    * anything else                   -> number of ``v`` characters + 1
      (e.g. ``-vv`` reaches the action as the value ``"v"`` and yields 2)
    """

    def __call__(self, parser, args, values, option_string=None):
        level = "1" if values is None else values
        try:
            level = int(level)
        except ValueError:
            level = 1 + level.count("v")
        setattr(args, self.dest, level)
42
43
def parse_args(argv=None):
    """Build the command-line parser for the dump tool and parse arguments.

    :param argv: list of argument strings to parse, without the program
        name. When ``None`` (the default), ``sys.argv[1:]`` is parsed,
        exactly as before this parameter existed.
    :returns: the resulting ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool")
    parser.add_argument("url", metavar="ws_url",
                        help="websocket url. ex. ws://echo.websocket.org/")
    parser.add_argument("-p", "--proxy",
                        help="proxy url. ex. http://127.0.0.1:8080")
    # VAction maps a bare -v to 1, "-v N" to N and -vv/-vvv to 2/3.
    parser.add_argument("-v", "--verbose", default=0, nargs='?', action=VAction,
                        dest="verbose",
                        help="set verbose mode. If set to 1, show opcode. "
                        "If set to 2, enable tracing of the websocket module")
    parser.add_argument("-n", "--nocert", action='store_true',
                        help="Ignore invalid SSL cert")
    parser.add_argument("-r", "--raw", action="store_true",
                        help="raw output")
    parser.add_argument("-s", "--subprotocols", nargs='*',
                        help="Set subprotocols")
    parser.add_argument("-o", "--origin",
                        help="Set origin")
    # float rather than int: the value is handed to time.sleep(), which
    # accepts fractional seconds; integer arguments still parse as before.
    parser.add_argument("--eof-wait", default=0, type=float,
                        help="wait time (seconds) after 'EOF' received.")
    parser.add_argument("-t", "--text",
                        help="Send initial text")
    parser.add_argument("--timings", action="store_true",
                        help="Print timings in seconds")
    parser.add_argument("--headers",
                        help="Set custom headers. Use ',' as separator")

    return parser.parse_args(argv)
72
73
class RawInput:
    """Mixin providing a ``raw_input`` that returns UTF-8 encoded bytes.

    Lines read as text are encoded to UTF-8. Lines read as byte strings
    (Python 2) are decoded from the terminal encoding ``ENCODING`` and
    re-encoded as UTF-8, unless that encoding is empty or already
    ``"utf-8"``, in which case they are returned unchanged.
    """

    def raw_input(self, prompt):
        # The builtin ``raw_input`` only exists on Python 2 and is only
        # looked up there; Python 3 uses ``input``.
        read_line = input if six.PY3 else raw_input  # noqa: F821
        line = read_line(prompt)

        if isinstance(line, six.text_type):
            return line.encode("utf-8")
        if ENCODING and ENCODING != "utf-8":
            return line.decode(ENCODING).encode("utf-8")
        return line
88
89
class InteractiveConsole(RawInput, code.InteractiveConsole):
    """Terminal console that prints received messages above a "> " prompt.

    RawInput is listed first, so its ``raw_input`` takes precedence over
    the one inherited from ``code.InteractiveConsole``.
    """

    def read(self):
        return self.raw_input("> ")

    def write(self, data):
        out = sys.stdout
        # ANSI: erase the current line (the pending prompt), move to the next.
        out.write("\033[2K\033[E")
        # Print the message in blue ("<" marks incoming), then reset colour.
        out.write("\033[34m< " + data + "\033[39m")
        # Redraw the prompt so the user can keep typing.
        out.write("\n> ")
        out.flush()
101
102
class NonInteractive(RawInput):
    """Plain console used with ``--raw``: no prompt text and no colours."""

    def read(self):
        return self.raw_input("")

    def write(self, data):
        stream = sys.stdout
        # One message per line, flushed immediately so output is not held
        # back by buffering.
        for piece in (data, "\n"):
            stream.write(piece)
        stream.flush()
112
113
def _connection_options(args):
    """Translate parsed CLI *args* into ``create_connection`` arguments.

    Returns an ``(options, sslopt)`` pair: ``options`` holds keyword
    arguments for ``websocket.create_connection`` and ``sslopt`` is the
    dict passed as its ``sslopt`` argument.
    """
    options = {}
    if args.proxy:
        p = urlparse(args.proxy)
        options["http_proxy_host"] = p.hostname
        options["http_proxy_port"] = p.port
    if args.origin:
        options["origin"] = args.origin
    if args.subprotocols:
        options["subprotocols"] = args.subprotocols
    if args.headers:
        # A real list, not a lazy ``map`` object: on Python 3 a map iterator
        # can be consumed only once. ``h.strip()`` also works for unicode
        # values on Python 2, where ``str.strip`` would not.
        options["header"] = [h.strip() for h in args.headers.split(",")]
    sslopt = {}
    if args.nocert:
        # Use the stdlib ssl module directly instead of relying on the
        # websocket package re-exporting it as ``websocket.ssl``.
        import ssl
        sslopt = {"cert_reqs": ssl.CERT_NONE, "check_hostname": False}
    return options, sslopt


def main():
    """Connect to the WebSocket URL, print incoming frames, send typed lines.

    A daemon thread prints received frames while the main thread reads
    lines from the console and sends each one. Returns on Ctrl+C, or on
    end of input after sleeping ``--eof-wait`` seconds.
    """
    start_time = time.time()
    args = parse_args()
    if args.verbose > 1:
        websocket.enableTrace(True)
    options, sslopt = _connection_options(args)
    ws = websocket.create_connection(args.url, sslopt=sslopt, **options)
    if args.raw:
        console = NonInteractive()
    else:
        console = InteractiveConsole()
        print("Press Ctrl+C to quit")

    def recv():
        """Read one frame and answer close/ping control frames.

        Returns ``(opcode, data)``. A WebSocket error while reading is
        reported as ``(OPCODE_CLOSE, None)`` so the reader loop stops.
        """
        try:
            frame = ws.recv_frame()
        except websocket.WebSocketException:
            return websocket.ABNF.OPCODE_CLOSE, None
        if not frame:
            raise websocket.WebSocketException("Not a valid frame %s" % frame)
        elif frame.opcode in OPCODE_DATA:
            return frame.opcode, frame.data
        elif frame.opcode == websocket.ABNF.OPCODE_CLOSE:
            ws.send_close()
            return frame.opcode, None
        elif frame.opcode == websocket.ABNF.OPCODE_PING:
            ws.pong(frame.data)
            return frame.opcode, frame.data

        return frame.opcode, frame.data

    def recv_ws():
        """Print every received frame until a close frame (or error) arrives."""
        while True:
            opcode, data = recv()
            msg = None
            if six.PY3 and isinstance(data, bytes):
                if opcode == websocket.ABNF.OPCODE_TEXT:
                    # Replace undecodable bytes instead of letting a
                    # UnicodeDecodeError kill this reader thread.
                    data = str(data, "utf-8", "replace")
                else:
                    # Binary/control payloads: the console writers concatenate
                    # with str, so bytes would raise TypeError. Show the repr,
                    # which is what verbose mode's "%s" already printed.
                    data = repr(data)
            if not args.verbose and opcode in OPCODE_DATA:
                msg = data
            elif args.verbose:
                msg = "%s: %s" % (websocket.ABNF.OPCODE_MAP.get(opcode), data)

            if msg is not None:
                if args.timings:
                    msg = str(time.time() - start_time) + ": " + msg
                console.write(msg)

            if opcode == websocket.ABNF.OPCODE_CLOSE:
                break

    # Daemon thread: it must not keep the process alive once main() returns.
    thread = threading.Thread(target=recv_ws)
    thread.daemon = True
    thread.start()

    if args.text:
        ws.send(args.text)

    while True:
        try:
            message = console.read()
            ws.send(message)
        except KeyboardInterrupt:
            return
        except EOFError:
            # Give late responses a chance to be printed before exiting.
            time.sleep(args.eof_wait)
            return
194
195
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Report the failure on stderr and exit non-zero; previously the
        # message went to stdout and the script exited with status 0,
        # reporting failure as success to callers.
        sys.stderr.write("%s\n" % e)
        sys.exit(1)
201