1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#  Project                     ___| | | |  _ \| |
5#                             / __| | | | |_) | |
6#                            | (__| |_| |  _ <| |___
7#                             \___|\___/|_| \_\_____|
8#
9# Copyright (C) 2017, Daniel Stenberg, <daniel@haxx.se>, et al.
10#
11# This software is licensed as described in the file COPYING, which
12# you should have received as part of this distribution. The terms
13# are also available at https://curl.haxx.se/docs/copyright.html.
14#
15# You may opt to use, copy, modify, merge, publish, distribute and/or sell
16# copies of the Software, and permit persons to whom the Software is
17# furnished to do so, under the terms of the COPYING file.
18#
19# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20# KIND, either express or implied.
21#
22"""Server for testing SMB"""
23
24from __future__ import (absolute_import, division, print_function)
25# unicode_literals)
26import argparse
27import ConfigParser
28import os
29import sys
30import logging
31import tempfile
32
33# Import our curl test data helper
34import curl_test_data
35
36# This saves us having to set up the PYTHONPATH explicitly
37deps_dir = os.path.join(os.path.dirname(__file__), "python_dependencies")
38sys.path.append(deps_dir)
39from impacket import smbserver as imp_smbserver
40from impacket import smb as imp_smb
41from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_SUCCESS,
42                                STATUS_NO_SUCH_FILE)
43
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Values used as the "path" of the SERVER and TESTS shares in the config
# built by smbserver().  They are not real directories: create_and_x()
# compares the share path against them to decide whether the requested
# file has to be generated on the fly.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"
# The only file name get_server_path() accepts on the SERVER share.
VERIFIED_REQ = "verifiedserver"
# Contents returned for VERIFIED_REQ; get_server_path() fills in {pid}
# with os.getpid().
VERIFIED_RSP = b"WE ROOLZ: {pid}\n"
49
50
def smbserver(options):
    """Configure and run the test SMB server on 127.0.0.1.

    Reads ``pidfile``, ``srcdir`` and ``port`` from ``options``.  The PID
    file (when requested) is written first, then an in-memory server
    config with the "global" section and the two magic shares is built,
    and finally the server is started.  ``serve_forever()`` blocks; 0 is
    returned only if it ever returns.

    Raises ScriptException when ``srcdir`` is missing or not a directory.
    """
    pidfile = options.pidfile
    if pidfile:
        with open(pidfile, "w") as pid_out:
            pid_out.write("{0}".format(os.getpid()))

    # Mini config for the server, expressed as (section, settings) pairs.
    #  - SERVER: share which lets us test that the server is running.
    #  - TESTS:  share whose files are autogenerated from the test input.
    layout = (
        ("global", (
            ("server_name", "SERVICE"),
            ("server_os", "UNIX"),
            ("server_domain", "WORKGROUP"),
            ("log_file", ""),
            ("credentials_file", ""),
        )),
        ("SERVER", (
            ("comment", "server function"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", SERVER_MAGIC),
        )),
        ("TESTS", (
            ("comment", "tests"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", TESTS_MAGIC),
        )),
    )

    smb_config = ConfigParser.ConfigParser()
    for section, settings in layout:
        smb_config.add_section(section)
        for key, value in settings:
            smb_config.set(section, key, value)

    srcdir = options.srcdir
    if not (srcdir and os.path.isdir(srcdir)):
        raise ScriptException("--srcdir is mandatory")

    # The test definitions live in the "data" directory below srcdir.
    smb_server = TestSmbServer(
        ("127.0.0.1", options.port),
        config_parser=smb_config,
        test_data_directory=os.path.join(srcdir, "data"))
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()
    smb_server.serve_forever()
    return 0
96
97
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    Files opened on shares whose path is SERVER_MAGIC or TESTS_MAGIC do not
    exist on disk beforehand; their contents are generated into a temporary
    file at open time.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        """
        Initialise the impacket server and install the create hook.

        address is passed to SMBSERVER unchanged (smbserver() passes a
        (host, port) tuple), config_parser is the server config object and
        test_data_directory is handed to curl_test_data.TestData.
        """
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)

        # Set up a test data object so we can get test data later.
        self.ctd = curl_test_data.TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Requests on shares other than the two magic ones are forwarded to
        the stock impacket handler and its result is returned unchanged.
        Otherwise returns ([response command], None, error_code); when an
        SmbException is raised during processing the response carries empty
        parameters/data and the exception's error code.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbException
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in (SERVER_MAGIC, TESTS_MAGIC):
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
                                                                    smb_server,
                                                                    smb_command,
                                                                    recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command[
                                                         "Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert path == TESTS_MAGIC
                fid, full_path = self.get_test_path(requested_file)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Simple way to generate a fid: one more than the highest fid
            # currently open on this connection.  max() is used because
            # the order of dict keys must not be relied upon.
            opened_files = conn_data["OpenedFiles"]
            fakefid = max(opened_files) + 1 if opened_files else 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # Inspect the generated file itself; `path` is only the magic
            # share name and says nothing about what was opened.
            if os.path.isdir(full_path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                "", full_path, level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # The open is being refused, so nothing will ever reference
                # the temporary file again: release it before bailing out.
                self._discard_temp_file(fid, full_path)
                raise SmbException(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info[
                "LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info[
                "LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info[
                "ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info[
                "AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection.  FileName records the
            # magic share path (not the temporary file name); this is the
            # value get_share_path() reads back for RootFid-relative opens.
            conn_data["OpenedFiles"][fakefid] = {}
            conn_data["OpenedFiles"][fakefid]["FileHandle"] = fid
            conn_data["OpenedFiles"][fakefid]["FileName"] = path
            conn_data["OpenedFiles"][fakefid]["DeleteOnClose"] = False

        except SmbException as s:
            log.debug("[SMB] SmbException hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Return the path the request is relative to.

        With a positive root_fid this is the FileName recorded for that open
        file; otherwise it is the "path" of the share connected as tid.

        Raises SmbException when tid is not a connected share or when the
        share has no "path" entry.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
                               "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        if "path" not in conn_shares[tid]:
            raise SmbException(STATUS_ACCESS_DENIED,
                               "Connection share had no path")

        return conn_shares[tid]["path"]

    def get_server_path(self, requested_filename):
        """
        Create the temporary file backing a file on the SERVER share.

        Only VERIFIED_REQ is supported; it yields VERIFIED_RSP with the
        server's PID filled in.  Returns (file descriptor, file name) with
        the descriptor positioned at the start of the file.

        Raises SmbException(STATUS_NO_SUCH_FILE) for any other name.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        contents = ""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            contents = VERIFIED_RSP.format(pid=os.getpid())

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """
        Write contents to the file descriptor fid, flush it to disk and
        rewind so that a subsequent read starts at the beginning.
        """
        # Write the contents to file descriptor
        os.write(fid, contents)
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Create the temporary file backing a file on the TESTS share.

        The contents come from self.ctd.get_test_data(requested_filename).
        Returns (file descriptor, file name) with the descriptor positioned
        at the start of the file.

        Raises SmbException(STATUS_NO_SUCH_FILE) when the data cannot be
        produced or written; the temporary file is removed in that case.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(requested_filename)
            self.write_to_fid(fid, contents)
            return fid, filename

        except Exception:
            log.exception("Failed to make test file")
            # The descriptor is never handed to the caller on this path, so
            # release it here instead of leaking it.
            self._discard_temp_file(fid, filename)
            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")

    def _discard_temp_file(self, fid, filename):
        """
        Best-effort cleanup of a temporary file created with mkstemp():
        close the descriptor and delete the file.  Failures are only logged
        so they never mask the error that triggered the cleanup.
        """
        try:
            os.close(fid)
        except OSError:
            log.debug("[SMB] Could not close descriptor %d", fid)

        try:
            os.remove(filename)
        except OSError:
            log.debug("[SMB] Could not remove '%s'", filename)
291
292
class SmbException(Exception):
    """Exception pairing an SMB status code with a human-readable message.

    The message becomes the exception's only argument; the status code is
    exposed as the ``error_code`` attribute.
    """

    def __init__(self, error_code, error_message):
        Exception.__init__(self, error_message)
        self.error_code = error_code
297
298
class ScriptRC(object):
    """Process exit codes used by this script.

    SUCCESS is 0, FAILURE is 1 and EXCEPTION is 2.
    """
    SUCCESS, FAILURE, EXCEPTION = range(3)
304
305
class ScriptException(Exception):
    """Error raised by the script itself, e.g. for unusable options."""
308
309
def get_options(argv=None):
    """Parse the command line options of the test server.

    argv is an optional list of argument strings.  When it is None (the
    default) argparse falls back to the process command line, so existing
    callers are unaffected.

    Returns the argparse namespace with the attributes port, verbose,
    pidfile, logfile, srcdir, id and ipv4.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    parser.add_argument("--id", action="store", help="server ID")
    # store_true yields a bool when the flag is given, so the default is a
    # bool as well rather than the integer 0.
    parser.add_argument("--ipv4", action="store_true", default=False,
                        help="IPv4 flag")

    return parser.parse_args(argv)
327
328
def setup_logging(options):
    """
    Configure the root logger from the command line options.

    A file handler (truncating the file) is attached when options.logfile
    is set.  A stdout handler is attached when there is no log file or when
    options.verbose is set.  The root level is DEBUG in verbose mode and
    INFO otherwise; both handlers accept everything from DEBUG upwards.
    """
    fmt = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")
    root = logging.getLogger()

    # Write out to a logfile
    if options.logfile:
        file_handler = logging.FileHandler(options.logfile, mode="w")
        file_handler.setFormatter(fmt)
        file_handler.setLevel(logging.DEBUG)
        root.addHandler(file_handler)

    root.setLevel(logging.DEBUG if options.verbose else logging.INFO)

    # Log to stdout when no logfile was specified, and additionally in
    # verbose mode.
    if options.verbose or not options.logfile:
        console = logging.StreamHandler(sys.stdout)
        console.setFormatter(fmt)
        console.setLevel(logging.DEBUG)
        root.addHandler(console)
360
361
if __name__ == "__main__":
    # Parse the command line, then configure logging from it.
    options = get_options()
    setup_logging(options)

    # Run the server.  Any exception escaping it is logged and reported
    # through the EXCEPTION return code.
    rc = ScriptRC.EXCEPTION
    try:
        rc = smbserver(options)
    except Exception as err:
        log.exception(err)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
378