1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#  Project                     ___| | | |  _ \| |
5#                             / __| | | | |_) | |
6#                            | (__| |_| |  _ <| |___
7#                             \___|\___/|_| \_\_____|
8#
9# Copyright (C) 2017 - 2020, Daniel Stenberg, <daniel@haxx.se>, et al.
10#
11# This software is licensed as described in the file COPYING, which
12# you should have received as part of this distribution. The terms
13# are also available at https://curl.haxx.se/docs/copyright.html.
14#
15# You may opt to use, copy, modify, merge, publish, distribute and/or sell
16# copies of the Software, and permit persons to whom the Software is
17# furnished to do so, under the terms of the COPYING file.
18#
19# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20# KIND, either express or implied.
21#
22"""Server for testing SMB"""
23
24from __future__ import (absolute_import, division, print_function)
25# NOTE: the impacket configuration is not unicode_literals compatible!
26import argparse
27import os
28import sys
29import logging
30import tempfile
31if sys.version_info.major >= 3:
32    import configparser
33else:
34    import ConfigParser as configparser
35
36# Import our curl test data helper
37import curl_test_data
38
39# impacket needs to be installed in the Python environment
40try:
41    import impacket
42except ImportError:
43    sys.stderr.write('Python package impacket needs to be installed!\n')
44    sys.stderr.write('Use pip or your package manager to install it.\n')
45    sys.exit(1)
46from impacket import smbserver as imp_smbserver
47from impacket import smb as imp_smb
48from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_SUCCESS,
49                                STATUS_NO_SUCH_FILE)
50
# Module-level logger. setup_logging() attaches handlers to the root logger.
log = logging.getLogger(__name__)
# Placeholder "path" values configured for the SERVER and TESTS shares.
# create_and_x compares the resolved share path against these to decide
# whether the requested file has to be generated on the fly.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"
# The only file name get_server_path accepts, and the template for the
# contents written in response ({pid} is filled in with the server's PID).
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"
56
57
def smbserver(options):
    """Start up a TCP SMB server that serves forever

    Writes the PID file when options.pidfile is set, builds the in-memory
    impacket configuration (a "global" section plus the SERVER and TESTS
    shares) and then runs the server on (options.host, options.port).

    Raises ScriptException when options.srcdir is unset or not a directory.
    """
    if options.pidfile:
        # see tests/server/util.c function write_pidfile
        reported_pid = os.getpid()
        if os.name == "nt":
            reported_pid += 65536
        with open(options.pidfile, "w") as pid_file:
            pid_file.write(str(reported_pid))

    # The mini config for the server, in the order it is written:
    #  - "global": general server identity settings
    #  - "SERVER": a share which allows us to test that the server is running
    #  - "TESTS":  a share whose files are autogenerated from the test input
    config_layout = (
        ("global", (("server_name", "SERVICE"),
                    ("server_os", "UNIX"),
                    ("server_domain", "WORKGROUP"),
                    ("log_file", ""),
                    ("credentials_file", ""))),
        ("SERVER", (("comment", "server function"),
                    ("read only", "yes"),
                    ("share type", "0"),
                    ("path", SERVER_MAGIC))),
        ("TESTS", (("comment", "tests"),
                   ("read only", "yes"),
                   ("share type", "0"),
                   ("path", TESTS_MAGIC))),
    )

    smb_config = configparser.ConfigParser()
    for section, settings in config_layout:
        smb_config.add_section(section)
        for key, value in settings:
            smb_config.set(section, key, value)

    if not (options.srcdir and os.path.isdir(options.srcdir)):
        raise ScriptException("--srcdir is mandatory")

    smb_server = TestSmbServer(
        (options.host, options.port),
        config_parser=smb_config,
        test_data_directory=os.path.join(options.srcdir, "data"))
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()
    smb_server.serve_forever()
    return 0
106
107
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    Opens on a share whose configured path is SERVER_MAGIC or TESTS_MAGIC
    are answered from temporary files generated on the fly; all other opens
    are passed to impacket's own NT_CREATE_ANDX handler.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        """
        address:             (host, port) tuple handed to impacket's SMBSERVER
        config_parser:       ConfigParser object describing the shares
        test_data_directory: directory handed to curl_test_data.TestData
        """
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)

        # Set up a test data object so we can get test data later.
        self.ctd = curl_test_data.TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Returns ([response command], None, error code) for requests handled
        here, or whatever impacket's original handler returns when the
        share path is not one of the magic paths.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbException
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
                                                                    smb_server,
                                                                    smb_command,
                                                                    recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command[
                                                         "Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                # Only TESTS_MAGIC can reach this point, see the check above.
                fid, full_path = self.get_test_path(requested_file)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Simple way to generate a fid: one more than the largest fid
            # already in use.  (Indexing dict.keys() does not work on
            # Python 3, where it is a view and not a list.)
            opened_files = conn_data["OpenedFiles"]
            if opened_files:
                fakefid = max(opened_files) + 1
            else:
                fakefid = 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # NOTE(review): 'path' is the magic share path at this point,
            # not the generated temporary file -- confirm this is intended.
            if os.path.isdir(path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                "", full_path, level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # The generated file will never be handed to the client, so
                # do not leave it (or its descriptor) behind.
                self._discard_temp_file(fid, full_path)
                raise SmbException(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info[
                "LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info[
                "LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info[
                "ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info[
                "AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection
            opened_files[fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbException as s:
            log.debug("[SMB] SmbException hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Return the path the request refers to.

        With a non-zero root_fid this is the "FileName" stored for that
        opened file, otherwise it is the "path" of the connected share
        identified by tid.

        Raises SmbException if tid is not a connected share or the share
        has no "path" entry.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
                               "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        if "path" not in conn_shares[tid]:
            raise SmbException(STATUS_ACCESS_DENIED,
                               "Connection share had no path")

        return conn_shares[tid]["path"]

    def get_server_path(self, requested_filename):
        """
        Create the temporary file backing a request on the SERVER share.

        Returns (file descriptor, file name) of a temporary file holding
        the response.  Raises SmbException(STATUS_NO_SUCH_FILE) for any
        file name other than VERIFIED_REQ.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        # os.write() needs bytes, so the default has to be bytes as well.
        contents = b""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            pid = os.getpid()
            # see tests/server/util.c function write_pidfile
            if os.name == "nt":
                pid += 65536
            contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write contents (bytes) to the descriptor fid and rewind it."""
        # Write the contents to file descriptor
        os.write(fid, contents)
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Create the temporary file backing a request on the TESTS share.

        The contents come from self.ctd.get_test_data(requested_filename).
        Returns (file descriptor, file name).  Any failure is logged and
        reported as SmbException(STATUS_NO_SUCH_FILE); the temporary file
        is removed again in that case.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            test_data = self.ctd.get_test_data(requested_filename)
            self.write_to_fid(fid, test_data.encode('utf-8'))
        except Exception:
            log.exception("Failed to make test file")
            self._discard_temp_file(fid, filename)
            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")

        return fid, filename

    @staticmethod
    def _discard_temp_file(fid, filename):
        """Best-effort close and removal of a file created by mkstemp()."""
        # Close first: an open file cannot be removed on every platform.
        for cleanup, target in ((os.close, fid), (os.remove, filename)):
            try:
                cleanup(target)
            except OSError:
                log.debug("[SMB] Could not clean up %r", target,
                          exc_info=True)
305
306
class SmbException(Exception):
    """Exception which carries the SMB status code to send to the client."""

    def __init__(self, error_code, error_message):
        # error_message becomes the regular exception message; the status
        # code is kept separately for whoever catches the exception.
        Exception.__init__(self, error_message)
        self.error_code = error_code
311
312
class ScriptRC(object):
    """Enum for script return codes"""
    # Consecutive exit codes starting at zero.
    SUCCESS, FAILURE, EXCEPTION = range(3)
318
319
class ScriptException(Exception):
    """Error raised by the script itself, e.g. for unusable options."""
322
323
def get_options(argv=None):
    """
    Parse the command line options of the SMB test server.

    argv: optional list of argument strings to parse.  When None (the
          default) argparse falls back to the process command line.

    Returns the argparse namespace with the attributes port, host,
    verbose, pidfile, logfile, srcdir, id and ipv4.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--host", action="store", default="127.0.0.1",
                        help="host to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    parser.add_argument("--id", action="store", help="server ID")
    # A store_true flag is a boolean, so its default is False rather than 0.
    parser.add_argument("--ipv4", action="store_true", default=False,
                        help="IPv4 flag")

    return parser.parse_args(argv)
343
344
def setup_logging(options):
    """
    Configure the root logger from the command line options.

    options.logfile: when set, log records are written to this file
                     (the file is opened in "w" mode).
    options.verbose: when true the root logger level is DEBUG, otherwise
                     INFO.

    A stdout handler is added in verbose mode and whenever no logfile was
    specified.
    """
    line_format = logging.Formatter(
        "%(asctime)s %(levelname)-5.5s %(message)s")
    root = logging.getLogger()

    # Write out to a logfile
    if options.logfile:
        file_handler = logging.FileHandler(options.logfile, mode="w")
        file_handler.setFormatter(line_format)
        file_handler.setLevel(logging.DEBUG)
        root.addHandler(file_handler)

    root.setLevel(logging.DEBUG if options.verbose else logging.INFO)

    # stdout is used when there is no logfile, and additionally in
    # verbose mode.
    if options.verbose or not options.logfile:
        console = logging.StreamHandler(sys.stdout)
        console.setFormatter(line_format)
        console.setLevel(logging.DEBUG)
        root.addHandler(console)
376
377
if __name__ == '__main__':
    # Parse the command line, then configure logging from the result.
    options = get_options()
    setup_logging(options)

    # Assume the worst; the return code is only replaced when the server
    # function returns normally.
    rc = ScriptRC.EXCEPTION
    try:
        rc = smbserver(options)
    except Exception as err:
        log.exception(err)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
394