1# -*- coding: iso-8859-1 -*-
# Copyright (C) 2005 Martin v. Löwis
3# Licensed to PSF under a Contributor Agreement.
4from _msi import *
5import glob
6import os
7import re
8import string
9import sys
10
# Build flavour of the running interpreter, read from its version string.
Itanium = "Itanium" in sys.version
AMD64 = "AMD64" in sys.version
Win64 = Itanium or AMD64

# Bit layout of a column type value. Partially taken from Wine.
# Mask for the low byte, which holds the column's size.
datasizemask = 0x00ff
type_valid = 0x0100
type_localizable = 0x0200

# Two-bit field selecting the storage type of the column.
typemask = 0x0c00
type_long = 0x0000
type_short = 0x0400
type_binary = 0x0800
type_string = 0x0c00

type_nullable = 0x1000
type_key = 0x2000

# Union of every bit understood above; anything else is reported as unknown.
# XXX temporary, localizable?
knownbits = (datasizemask | type_valid | type_localizable |
             typemask | type_nullable | type_key)
31
class Table:
    """Schema of a single database table.

    Columns are collected with add_field() as (index, name, type) tuples,
    where type is a bit field built from the module's type_* constants.
    sql() renders them as a CREATE TABLE statement and create() executes
    that statement against a database.
    """
    def __init__(self, name):
        self.name = name
        self.fields = []

    def add_field(self, index, name, type):
        """Record a column; index is its 1-based position in the table."""
        self.fields.append((index,name,type))

    def sql(self):
        """Return the CREATE TABLE statement describing this table.

        The recorded fields are sorted in place and each one is written to
        slot index-1 of the column list, so the indices are expected to be
        exactly 1..len(fields). Unknown type bits are reported on stdout
        and otherwise ignored.
        """
        keys = []
        self.fields.sort()
        fields = [None]*len(self.fields)
        for index, name, type in self.fields:
            index -= 1
            unk = type & ~knownbits
            if unk:
                # Single-argument print(...) behaves the same as the print
                # statement and also parses on Python 3.
                print("%s.%s unknown bits %x" % (self.name, name, unk))
            size = type & datasizemask
            dtype = type & typemask
            if dtype == type_string:
                if size:
                    tname="CHAR(%d)" % size
                else:
                    tname="CHAR"
            elif dtype == type_short:
                assert size==2
                tname = "SHORT"
            elif dtype == type_long:
                assert size==4
                tname="LONG"
            elif dtype == type_binary:
                assert size==0
                tname="OBJECT"
            else:
                # Defensive only: the four branches above cover every value
                # that the two-bit typemask can produce.
                tname="unknown"
                print("%s.%s unknown integer type %d" % (self.name, name, size))
            if type & type_nullable:
                flags = ""
            else:
                flags = " NOT NULL"
            if type & type_localizable:
                flags += " LOCALIZABLE"
            fields[index] = "`%s` %s%s" % (name, tname, flags)
            if type & type_key:
                keys.append("`%s`" % name)
        fields = ", ".join(fields)
        keys = ", ".join(keys)
        return "CREATE TABLE %s (%s PRIMARY KEY %s)" % (self.name, fields, keys)

    def create(self, db):
        """Create this table in db by executing the statement from sql()."""
        v = db.OpenView(self.sql())
        v.Execute(None)
        v.Close()
81
82    def create(self, db):
83        v = db.OpenView(self.sql())
84        v.Execute(None)
85        v.Close()
86
class _Unspecified:
    """Sentinel marking an argument that the caller did not supply."""


def change_sequence(seq, action, seqno=_Unspecified, cond = _Unspecified):
    """Change the sequence number of an action in a sequence list.

    seq is a mutable list of (action, condition, sequence-number) tuples.
    The first entry whose action equals `action` is replaced in place;
    whichever of seqno and cond is left unspecified keeps the value from
    the existing entry (None is therefore a valid explicit value).

    Raises ValueError if no entry for `action` exists.
    """
    for i, entry in enumerate(seq):
        if entry[0] == action:
            if cond is _Unspecified:
                cond = entry[1]
            if seqno is _Unspecified:
                seqno = entry[2]
            seq[i] = (action, cond, seqno)
            return
    # Call form instead of the deprecated "raise E, msg" statement form.
    raise ValueError("Action not found in sequence")
99
def add_data(db, table, values):
    """Insert the rows in values into table of database db.

    Each row must have exactly as many items as the table has columns.
    Items are stored according to their Python type: integers with
    SetInteger, strings with SetString, Binary objects as a stream read
    from the file named by their name attribute; None leaves the field
    unset.

    Raises TypeError for any other item type and MSIError when a row
    cannot be inserted. The view is closed on every exit path.
    """
    v = db.OpenView("SELECT * FROM `%s`" % table)
    try:
        count = v.GetColumnInfo(MSICOLINFO_NAMES).GetFieldCount()
        r = CreateRecord(count)
        for value in values:
            assert len(value) == count, value
            for i in range(count):
                field = value[i]
                # Record fields are numbered from 1.
                if isinstance(field, (int, long)):
                    r.SetInteger(i+1,field)
                elif isinstance(field, basestring):
                    r.SetString(i+1,field)
                elif field is None:
                    pass
                elif isinstance(field, Binary):
                    r.SetStream(i+1, field.name)
                else:
                    raise TypeError("Unsupported type %s" % field.__class__.__name__)
            try:
                v.Modify(MSIMODIFY_INSERT, r)
            except Exception:
                # The message reports the whole batch, not just the failing row.
                raise MSIError("Could not insert "+repr(values)+" into "+table)

            # The record is reused for the next row.
            r.ClearData()
    finally:
        # Previously the view stayed open whenever a row was rejected.
        v.Close()
125
126
def add_stream(db, name, path):
    """Store the contents of the file at path in the _Streams table as name.

    The view is closed even if reading the file or executing the insert
    fails.
    """
    # NOTE(review): name is interpolated into the statement unescaped, so a
    # name containing a single quote would break it; callers in this module
    # pass their own cabinet names.
    v = db.OpenView("INSERT INTO _Streams (Name, Data) VALUES ('%s', ?)" % name)
    try:
        r = CreateRecord(1)
        r.SetStream(1, path)
        v.Execute(r)
    finally:
        v.Close()
133
def init_database(name, schema,
                  ProductName, ProductCode, ProductVersion,
                  Manufacturer):
    """Create a new installer database at path name and return it.

    Any existing file of that name is removed first. The tables listed in
    schema.tables are created, the _Validation table is filled from
    schema._Validation_records, the summary information is written, and
    the basic product properties are added to the Property table. The
    database is committed before it is returned.
    """
    # Remove a stale database; an OSError (e.g. no such file) is ignored.
    try:
        os.unlink(name)
    except OSError:
        pass
    ProductCode = ProductCode.upper()

    # Create the database together with all of its tables.
    db = OpenDatabase(name, MSIDBOPEN_CREATE)
    for table in schema.tables:
        table.create(db)

    # Fill the validation table.
    add_data(db, "_Validation", schema._Validation_records)

    # The template property depends on the platform of this interpreter.
    if Itanium:
        template = "Intel64;1033"
    elif AMD64:
        template = "x64;1033"
    else:
        template = "Intel;1033"

    # Initialize the summary information, allowing at most 20 properties.
    summary = db.GetSummaryInformation(20)
    summary.SetProperty(PID_TITLE, "Installation Database")
    summary.SetProperty(PID_SUBJECT, ProductName)
    summary.SetProperty(PID_AUTHOR, Manufacturer)
    summary.SetProperty(PID_TEMPLATE, template)
    summary.SetProperty(PID_REVNUMBER, gen_uuid())
    # 2: long file names, compressed, original media
    summary.SetProperty(PID_WORDCOUNT, 2)
    summary.SetProperty(PID_PAGECOUNT, 200)
    summary.SetProperty(PID_APPNAME, "Python MSI Library")
    # XXX more properties
    summary.Persist()

    properties = [
        ("ProductName", ProductName),
        ("ProductCode", ProductCode),
        ("ProductVersion", ProductVersion),
        ("Manufacturer", Manufacturer),
        ("ProductLanguage", "1033"),
    ]
    add_data(db, "Property", properties)
    db.Commit()
    return db
174
def add_tables(db, module):
    """Load every table named in module.tables into db.

    For each table name, the rows are taken from the attribute of module
    that has the same name and are inserted with add_data().
    """
    for name in module.tables:
        rows = getattr(module, name)
        add_data(db, name, rows)
178
def make_id(str):
    """Return an identifier derived from the string str.

    Every character other than an ASCII letter, a digit, "." or "_" is
    replaced by "_". If the result is empty or starts with a digit or a
    ".", an underscore is prepended, so the returned value always matches
    [A-Za-z_][A-Za-z0-9_.]*.
    """
    identifier_chars = string.ascii_letters + string.digits + "._"
    str = "".join([c if c in identifier_chars else "_" for c in str])
    # Check for emptiness first: indexing str[0] on an empty name used to
    # raise IndexError.
    if not str or str[0] in (string.digits + "."):
        str = "_" + str
    # Internal sanity check of the invariant documented above.
    assert re.match("^[A-Za-z_][A-Za-z0-9_.]*$", str), "FILE"+str
    return str
186
def gen_uuid():
    """Return a new upper-case UUID from UuidCreate(), wrapped in braces."""
    uuid = UuidCreate().upper()
    return "{" + uuid + "}"
189
class CAB:
    """Collects the files that go into one cabinet.

    Files are registered with append(); commit() builds the cabinet with
    FCICreate, records it in the Media table and stores it in the database
    as a stream named after the cabinet.
    """
    def __init__(self, name):
        self.name = name
        self.files = []          # (full path, logical id) pairs, in order
        self.filenames = set()   # logical ids handed out by gen_id()
        self.index = 0           # number of files appended so far

    def gen_id(self, file):
        """Return a logical id for file that gen_id() has not returned before.

        The id is make_id(file); on a clash ".1", ".2", ... is appended
        until the id is unused.
        """
        logical = _logical = make_id(file)
        pos = 1
        while logical in self.filenames:
            logical = "%s.%d" % (_logical, pos)
            pos += 1
        self.filenames.add(logical)
        return logical

    def append(self, full, file, logical):
        """Register the file at path full and return (sequence, logical id).

        If logical is false, an id is generated from file. Directories are
        skipped, in which case None is returned instead of a tuple.
        """
        if os.path.isdir(full):
            return
        if not logical:
            logical = self.gen_id(file)
        self.index += 1
        self.files.append((full, logical))
        return self.index, logical

    def commit(self, db):
        """Build the cabinet, add it to db as a stream and commit db.

        The temporary cabinet file is removed whether or not the build and
        the database updates succeed.
        """
        from tempfile import mktemp
        # NOTE(review): mktemp() only reserves a name, so another process
        # could create the file first. It is kept because FCICreate is
        # handed a path to create itself; confirm it accepts an existing
        # file before switching to mkstemp().
        filename = mktemp()
        try:
            FCICreate(filename, self.files)
            add_data(db, "Media",
                    [(1, self.index, None, "#"+self.name, None, None)])
            add_stream(db, self.name, filename)
        finally:
            # The file may not exist if FCICreate itself failed; do not let
            # the cleanup hide that error.
            if os.path.exists(filename):
                os.unlink(filename)
        db.Commit()
224
# Logical directory names already handed out, shared by all Directory objects.
_directories = set()
class Directory:
    """A directory of the installation, with a current component for its files."""
    def __init__(self, db, cab, basedir, physical, _logical, default, componentflags=None):
        """Create a new directory in the Directory table. There is a current component
        at each point in time for the directory, which is either explicitly created
        through start_component, or implicitly when files are added for the first
        time. Files are added into the current component, and into the cab file.
        To create a directory, a base directory object needs to be specified (can be
        None), the path to the physical directory, and a logical directory name.
        Default specifies the DefaultDir slot in the directory table. componentflags
        specifies the default flags that new components get."""
        # Make the logical name unique by appending 1, 2, ... on a clash.
        index = 1
        _logical = make_id(_logical)
        logical = _logical
        while logical in _directories:
            logical = "%s%d" % (_logical, index)
            index += 1
        _directories.add(logical)
        self.db = db
        self.cab = cab
        self.basedir = basedir
        self.physical = physical
        self.logical = logical
        self.component = None
        self.short_names = set()   # short names already used in this directory
        self.ids = set()           # logical file ids added through add_file
        self.keyfiles = {}         # key file name -> logical id reserved for it
        self.componentflags = componentflags
        if basedir:
            self.absolute = os.path.join(basedir.absolute, physical)
            blogical = basedir.logical
        else:
            self.absolute = physical
            blogical = None
        add_data(db, "Directory", [(logical, blogical, default)])

    def start_component(self, component = None, feature = None, flags = None, keyfile = None, uuid=None):
        """Add an entry to the Component table, and make this component the current for this
        directory. If no component name is given, the directory name is used. If no feature
        is given, the current feature is used. If no flags are given, the directory's default
        flags are used. If no keyfile is given, the KeyPath is left null in the Component
        table."""
        if flags is None:
            flags = self.componentflags
        if uuid is None:
            uuid = gen_uuid()
        else:
            uuid = uuid.upper()
        if component is None:
            component = self.logical
        self.component = component
        if Win64:
            # flags can still be None here (no explicit flags and no
            # directory default); treat that as 0 so the OR cannot fail.
            # 256 is presumably msidbComponentAttributes64bit -- confirm.
            flags = (flags or 0) | 256
        if keyfile:
            # gen_id() takes the file name only; passing self.absolute as an
            # extra argument raised TypeError for every key file.
            keyid = self.cab.gen_id(keyfile)
            self.keyfiles[keyfile] = keyid
        else:
            keyid = None
        add_data(self.db, "Component",
                        [(component, uuid, self.logical, flags, None, keyid)])
        if feature is None:
            feature = current_feature
        add_data(self.db, "FeatureComponents",
                        [(feature.id, component)])

    def make_short(self, file):
        """Return a short (8.3 style) name for file, unique in this directory.

        A name that already fits (at most 8 characters, at most one dot
        with at most 3 characters after it, no characters that had to be
        altered) is used upper-cased. Otherwise, or if that name is taken,
        a PREFIX~N[.EXT] name is generated with the smallest free N. The
        result is remembered in self.short_names.
        """
        oldfile = file
        file = file.replace('+', '_')
        # Backslash written as an explicit escape; the set of dropped
        # characters is unchanged.
        file = ''.join(c for c in file if not c in ' "/\\[]:;=,')
        parts = file.split(".")
        if len(parts) > 1:
            prefix = "".join(parts[:-1]).upper()
            suffix = parts[-1].upper()
            if not prefix:
                # Names such as ".hidden": use the extension as the base name.
                prefix = suffix
                suffix = None
        else:
            prefix = file.upper()
            suffix = None
        if len(parts) < 3 and len(prefix) <= 8 and file == oldfile and (
                                                not suffix or len(suffix) <= 3):
            if suffix:
                file = prefix+"."+suffix
            else:
                file = prefix
        else:
            file = None
        if file is None or file in self.short_names:
            prefix = prefix[:6]
            if suffix:
                suffix = suffix[:3]
            pos = 1
            while True:
                if suffix:
                    file = "%s~%d.%s" % (prefix, pos, suffix)
                else:
                    file = "%s~%d" % (prefix, pos)
                if file not in self.short_names: break
                pos += 1
                assert pos < 10000
                # Each extra digit in the counter costs one prefix character.
                if pos in (10, 100, 1000):
                    prefix = prefix[:-1]
        self.short_names.add(file)
        assert not re.search(r'[\?|><:/*"+,;=\[\]]', file) # restrictions on short names
        return file

    def add_file(self, file, src=None, version=None, language=None):
        """Add a file to the current component of the directory, starting a new one
        if there is no current component. By default, the file name in the source
        and the file table will be identical. If the src file is specified, it is
        interpreted relative to the current directory. Optionally, a version and a
        language can be specified for the entry in the File table."""
        if not self.component:
            self.start_component(self.logical, current_feature, 0)
        if not src:
            # Allow relative paths for file if src is not specified
            src = file
            file = os.path.basename(file)
        absolute = os.path.join(self.absolute, src)
        # Restrictions on long names. The double quote belongs inside the
        # character class; outside it, the check only matched a forbidden
        # character directly followed by a quote.
        assert not re.search(r'[\?|><:/*"]', file)
        if file in self.keyfiles:
            logical = self.keyfiles[file]
        else:
            logical = None
        # NOTE: cab.append() returns None for a directory, so adding one
        # fails on this unpacking.
        sequence, logical = self.cab.append(absolute, file, logical)
        assert logical not in self.ids
        self.ids.add(logical)
        short = self.make_short(file)
        full = "%s|%s" % (short, file)
        filesize = os.stat(absolute).st_size
        # constants.msidbFileAttributesVital
        # Compressed omitted, since it is the database default
        # could add r/o, system, hidden
        attributes = 512
        add_data(self.db, "File",
                        [(logical, self.component, full, filesize, version,
                         language, attributes, sequence)])
        #if not version:
        #    # Add hash if the file is not versioned
        #    filehash = FileHash(absolute, 0)
        #    add_data(self.db, "MsiFileHash",
        #             [(logical, 0, filehash.IntegerData(1),
        #               filehash.IntegerData(2), filehash.IntegerData(3),
        #               filehash.IntegerData(4))])
        # Automatically remove .pyc/.pyo files on uninstall (2)
        # XXX: adding so many RemoveFile entries makes installer unbelievably
        # slow. So instead, we have to use wildcard remove entries
        if file.endswith(".py"):
            add_data(self.db, "RemoveFile",
                      [(logical+"c", self.component, "%sC|%sc" % (short, file),
                        self.logical, 2),
                       (logical+"o", self.component, "%sO|%so" % (short, file),
                        self.logical, 2)])
        return logical

    def glob(self, pattern, exclude = None):
        """Add a list of files to the current component as specified in the
        glob pattern. Individual files can be excluded in the exclude list.
        Returns all files matching the pattern, including excluded ones."""
        files = glob.glob1(self.absolute, pattern)
        for f in files:
            if exclude and f in exclude: continue
            self.add_file(f)
        return files

    def remove_pyc(self):
        "Remove .pyc/.pyo files on uninstall"
        add_data(self.db, "RemoveFile",
                 [(self.component+"c", self.component, "*.pyc", self.logical, 2),
                  (self.component+"o", self.component, "*.pyo", self.logical, 2)])
394
class Binary:
    """Marks a table value as the name of a file to be stored as a stream."""
    def __init__(self, fname):
        # Only the file name is kept; add_data() passes it to SetStream.
        self.name = fname

    def __repr__(self):
        template = 'msilib.Binary(os.path.join(dirname,"%s"))'
        return template % self.name
400
class Feature:
    """A row of the Feature table, remembered by its id."""
    def __init__(self, db, id, title, desc, display, level = 1,
                 parent=None, directory = None, attributes=0):
        """Add the feature to db; parent, if given, is another Feature
        object whose id is stored as the parent reference."""
        self.id = id
        parent_id = parent.id if parent else parent
        row = (id, parent_id, title, desc, display,
               level, directory, attributes)
        add_data(db, "Feature", [row])

    def set_current(self):
        """Make this feature the module-wide current_feature."""
        global current_feature
        current_feature = self
413
class Control:
    """A control on a dialog, identified by the dialog object and its name."""
    def __init__(self, dlg, name):
        self.name = name
        self.dlg = dlg

    def event(self, event, argument, condition = "1", ordering = None):
        """Add a ControlEvent row for this control."""
        row = (self.dlg.name, self.name, event, argument,
               condition, ordering)
        add_data(self.dlg.db, "ControlEvent", [row])

    def mapping(self, event, attribute):
        """Add an EventMapping row for this control."""
        row = (self.dlg.name, self.name, event, attribute)
        add_data(self.dlg.db, "EventMapping", [row])

    def condition(self, action, condition):
        """Add a ControlCondition row for this control."""
        row = (self.dlg.name, self.name, action, condition)
        add_data(self.dlg.db, "ControlCondition", [row])
431
class RadioButtonGroup(Control):
    """A radio button group control whose buttons set a single property."""
    def __init__(self, dlg, name, property):
        self.dlg, self.name = dlg, name
        self.property = property
        # Order number given to the next button added.
        self.index = 1

    def add(self, name, x, y, w, h, text, value = None):
        """Add a RadioButton row to the group; value defaults to name."""
        button_value = name if value is None else value
        row = (self.property, self.index, button_value,
               x, y, w, h, text, None)
        add_data(self.dlg.db, "RadioButton", [row])
        self.index += 1
446
class Dialog:
    """A row of the Dialog table, with helpers for adding its controls."""
    def __init__(self, db, name, x, y, w, h, attr, title, first, default, cancel):
        self.db = db
        self.name = name
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        row = (name, x, y, w, h, attr, title, first, default, cancel)
        add_data(db, "Dialog", [row])

    def control(self, name, type, x, y, w, h, attr, prop, text, next, help):
        """Add a Control row of the given type and return a Control object."""
        row = (self.name, name, type, x, y, w, h, attr, prop, text, next, help)
        add_data(self.db, "Control", [row])
        return Control(self, name)

    def text(self, name, x, y, w, h, attr, text):
        """Add a "Text" control."""
        return self.control(name, "Text", x, y, w, h, attr,
                            None, text, None, None)

    def bitmap(self, name, x, y, w, h, text):
        """Add a "Bitmap" control with attributes 1."""
        return self.control(name, "Bitmap", x, y, w, h, 1,
                            None, text, None, None)

    def line(self, name, x, y, w, h):
        """Add a "Line" control with attributes 1."""
        return self.control(name, "Line", x, y, w, h, 1,
                            None, None, None, None)

    def pushbutton(self, name, x, y, w, h, attr, text, next):
        """Add a "PushButton" control."""
        return self.control(name, "PushButton", x, y, w, h, attr,
                            None, text, next, None)

    def radiogroup(self, name, x, y, w, h, attr, prop, text, next):
        """Add a "RadioButtonGroup" control and return a RadioButtonGroup
        object bound to the property prop."""
        row = (self.name, name, "RadioButtonGroup",
               x, y, w, h, attr, prop, text, next, None)
        add_data(self.db, "Control", [row])
        return RadioButtonGroup(self, name, prop)

    def checkbox(self, name, x, y, w, h, attr, prop, text, next):
        """Add a "CheckBox" control."""
        return self.control(name, "CheckBox", x, y, w, h, attr,
                            prop, text, next, None)
480