1# Copyright (c) 2006-2009 Mitch Garnaat http://garnaat.org/
2# Copyright (c) 2010 Chris Moyer http://coredumped.org/
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the
6# "Software"), to deal in the Software without restriction, including
7# without limitation the rights to use, copy, modify, merge, publish, dis-
8# tribute, sublicense, and/or sell copies of the Software, and to permit
9# persons to whom the Software is furnished to do so, subject to the fol-
10# lowing conditions:
11#
12# The above copyright notice and this permission notice shall be included
13# in all copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
16# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
17# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
18# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
19# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22"""
23High-level abstraction of an EC2 server
24"""
25
26import boto.ec2
27from boto.mashups.iobject import IObject
28from boto.pyami.config import BotoConfigPath, Config
29from boto.sdb.db.model import Model
30from boto.sdb.db.property import StringProperty, IntegerProperty, BooleanProperty, CalculatedProperty
31from boto.manage import propget
32from boto.ec2.zone import Zone
33from boto.ec2.keypair import KeyPair
34import os, time
35from contextlib import closing
36from boto.exception import EC2ResponseError
37from boto.compat import six, StringIO
38
# Instance types offered as choices when prompting for an instance type
# (see CommandLineGetter.get_instance_type).
InstanceTypes = [
    'm1.small',
    'm1.large',
    'm1.xlarge',
    'c1.medium',
    'c1.xlarge',
    'm2.2xlarge',
    'm2.4xlarge',
]
42
class Bundler(object):
    """
    Bundle a running Server's volume into an S3-backed image and register it.

    The EC2 bundling commands are run on the remote server over SSH; the
    uploaded manifest is then registered through the server's EC2 connection.
    """

    def __init__(self, server, uname='root'):
        """
        :param server: the Server to bundle.
        :param uname: remote login name; remote commands are prefixed with
            ``sudo`` when this is not ``root``.
        """
        from boto.manage.cmdshell import SSHClient
        self.server = server
        self.uname = uname
        self.ssh_client = SSHClient(server, uname=uname)

    def _redact_secrets(self, command):
        """
        Return ``command`` with the AWS secret access key masked out.

        ``upload_bundle`` embeds the secret key on the command line, so any
        copy of the command that is printed must go through this first.
        """
        secret = getattr(self.server.ec2, 'aws_secret_access_key', None)
        if secret:
            command = command.replace(secret, '********')
        return command

    def copy_x509(self, key_file, cert_file):
        """
        Upload the X.509 private key and certificate into ``/mnt`` on the
        server and record their remote paths for ``bundle_image``.
        """
        print('\tcopying cert and pk over to /mnt directory on server')
        self.ssh_client.open_sftp()
        self.remote_key_file = '/mnt/%s' % os.path.basename(key_file)
        self.ssh_client.put_file(key_file, self.remote_key_file)
        self.remote_cert_file = '/mnt/%s' % os.path.basename(cert_file)
        self.ssh_client.put_file(cert_file, self.remote_cert_file)
        print('...complete!')

    def bundle_image(self, prefix, size, ssh_key):
        """
        Return the ``ec2-bundle-vol`` command line for this server.

        ``copy_x509`` must have been called first (it sets the remote key and
        cert paths).  ``ssh_key`` is unused and kept for interface
        compatibility.

        :param prefix: image file name prefix.
        :param size: image size in MB.
        """
        command = ""
        if self.uname != 'root':
            command = "sudo "
        command += 'ec2-bundle-vol '
        command += '-c %s -k %s ' % (self.remote_cert_file, self.remote_key_file)
        command += '-u %s ' % self.server._reservation.owner_id
        command += '-p %s ' % prefix
        command += '-s %d ' % size
        command += '-d /mnt '
        # These two instance types are bundled as i386; everything else x86_64.
        if self.server.instance_type in ('m1.small', 'c1.medium'):
            command += '-r i386'
        else:
            command += '-r x86_64'
        return command

    def upload_bundle(self, bucket, prefix, ssh_key):
        """
        Return the ``ec2-upload-bundle`` command line for ``prefix``.

        The returned string contains the AWS secret access key; never print
        it without passing it through ``_redact_secrets``.  ``ssh_key`` is
        unused and kept for interface compatibility.
        """
        command = ""
        if self.uname != 'root':
            command = "sudo "
        command += 'ec2-upload-bundle '
        command += '-m /mnt/%s.manifest.xml ' % prefix
        command += '-b %s ' % bucket
        command += '-a %s ' % self.server.ec2.aws_access_key_id
        command += '-s %s ' % self.server.ec2.aws_secret_access_key
        return command

    def bundle(self, bucket=None, prefix=None, key_file=None, cert_file=None,
               size=None, ssh_key=None, fp=None, clear_history=True):
        """
        Bundle, upload and register an image of the server.

        Any of ``bucket``, ``prefix``, ``key_file``, ``cert_file`` and
        ``size`` that are not supplied are prompted for interactively.
        While bundling, the boto config and ``~/.ssh/authorized_keys`` are
        moved aside into ``/mnt`` and restored afterwards.

        :param fp: optional writable buffer (must support ``getvalue``) the
            remote command is assembled into; a new StringIO is used if None.
        :param clear_history: run ``history -c`` before bundling.
        :return: the id of the newly registered image.
        """
        iobject = IObject()
        if not bucket:
            bucket = iobject.get_string('Name of S3 bucket')
        if not prefix:
            prefix = iobject.get_string('Prefix for AMI file')
        if not key_file:
            key_file = iobject.get_filename('Path to RSA private key file')
        if not cert_file:
            cert_file = iobject.get_filename('Path to RSA public cert file')
        if not size:
            size = iobject.get_int('Size (in MB) of bundled image')
        if not ssh_key:
            ssh_key = self.server.get_ssh_key_file()
        self.copy_x509(key_file, cert_file)
        if not fp:
            fp = StringIO()
        fp.write('sudo mv %s /mnt/boto.cfg; ' % BotoConfigPath)
        fp.write('mv ~/.ssh/authorized_keys /mnt/authorized_keys; ')
        if clear_history:
            fp.write('history -c; ')
        fp.write(self.bundle_image(prefix, size, ssh_key))
        fp.write('; ')
        fp.write(self.upload_bundle(bucket, prefix, ssh_key))
        fp.write('; ')
        fp.write('sudo mv /mnt/boto.cfg %s; ' % BotoConfigPath)
        fp.write('mv /mnt/authorized_keys ~/.ssh/authorized_keys')
        command = fp.getvalue()
        print('running the following command on the remote server:')
        # The upload step carries the AWS secret key; never echo it.
        print(self._redact_secrets(command))
        t = self.ssh_client.run(command)
        print('\t%s' % t[0])
        print('\t%s' % t[1])
        print('...complete!')
        print('registering image...')
        self.image_id = self.server.ec2.register_image(name=prefix, image_location='%s/%s.manifest.xml' % (bucket, prefix))
        return self.image_id
127
class CommandLineGetter(object):
    """
    Collect the launch parameters used by ``Server.create``.

    Each ``get_*`` method fills one entry of the ``params`` dict in place,
    prompting interactively through ``propget`` when the entry is missing.
    ``get`` runs them in order; ``get_region`` must run first because it
    creates the EC2 connection (``self.ec2``) the later steps use.
    """

    def get_ami_list(self):
        """
        Return ``(location, image)`` pairs for images whose location contains
        ``'pyami'``.
        """
        # hack alert, need a better way to do this!
        return [(ami.location, ami) for ami in self.ec2.get_all_images()
                if 'pyami' in ami.location]

    def get_region(self, params):
        """
        Resolve ``params['region']`` to a region object and connect to it.

        A region name string is looked up with ``boto.ec2.get_region``; if
        the region is missing or unknown the user is prompted.  Sets
        ``self.ec2``.
        """
        region = params.get('region', None)
        # six.string_types is (basestring,) on Python 2 and (str,) on
        # Python 3; a bare ``basestring`` raises NameError on Python 3.
        if isinstance(region, six.string_types):
            region = boto.ec2.get_region(region)
            params['region'] = region
        if not region:
            prop = self.cls.find_property('region_name')
            params['region'] = propget.get(prop, choices=boto.ec2.regions)
        self.ec2 = params['region'].connect()

    def get_name(self, params):
        """Prompt for ``params['name']`` if it is missing or empty."""
        if not params.get('name', None):
            prop = self.cls.find_property('name')
            params['name'] = propget.get(prop)

    def get_description(self, params):
        """Prompt for ``params['description']`` if it is missing or empty."""
        if not params.get('description', None):
            prop = self.cls.find_property('description')
            params['description'] = propget.get(prop)

    def get_instance_type(self, params):
        """Prompt for ``params['instance_type']`` from ``InstanceTypes`` if missing."""
        if not params.get('instance_type', None):
            prop = StringProperty(name='instance_type', verbose_name='Instance Type',
                                  choices=InstanceTypes)
            params['instance_type'] = propget.get(prop)

    def get_quantity(self, params):
        """Prompt for ``params['quantity']`` if it is missing or zero."""
        if not params.get('quantity', None):
            prop = IntegerProperty(name='quantity', verbose_name='Number of Instances')
            params['quantity'] = propget.get(prop)

    def get_zone(self, params):
        """Prompt for ``params['zone']`` from the region's zones if missing."""
        if not params.get('zone', None):
            prop = StringProperty(name='zone', verbose_name='EC2 Availability Zone',
                                  choices=self.ec2.get_all_zones)
            params['zone'] = propget.get(prop)

    def get_ami_id(self, params):
        """
        Replace ``params['ami']`` (an image id) with the matching image object.

        The id already in ``params`` is tried first.  If it is missing, or
        does not resolve to exactly one image, the user is prompted until an
        id does.
        """
        ami = params.get('ami', None)
        while True:
            if not ami:
                prop = StringProperty(name='ami', verbose_name='AMI')
                ami = propget.get(prop)
            try:
                rs = self.ec2.get_all_images([ami])
                if len(rs) == 1:
                    params['ami'] = rs[0]
                    return
            except EC2ResponseError:
                pass
            # The id was not usable.  Clear it so the next pass prompts; re-reading
            # params['ami'] here would retry the same bad id forever.
            ami = None

    def get_group(self, params):
        """
        Resolve ``params['group']`` to a security group, prompting if missing.

        A group name string is replaced with the matching group object; an
        unmatched name is left in place as-is.
        """
        group = params.get('group', None)
        if isinstance(group, six.string_types):
            group_list = self.ec2.get_all_security_groups()
            for g in group_list:
                if g.name == group:
                    group = g
                    params['group'] = g
        if not group:
            prop = StringProperty(name='group', verbose_name='EC2 Security Group',
                                  choices=self.ec2.get_all_security_groups)
            params['group'] = propget.get(prop)

    def get_key(self, params):
        """
        Ensure ``params['keypair']`` holds a key pair name, prompting if missing.
        """
        keypair = params.get('keypair', None)
        if isinstance(keypair, six.string_types):
            key_list = self.ec2.get_all_key_pairs()
            for k in key_list:
                if k.name == keypair:
                    keypair = k.name
                    params['keypair'] = k.name
        if not keypair:
            prop = StringProperty(name='keypair', verbose_name='EC2 KeyPair',
                                  choices=self.ec2.get_all_key_pairs)
            params['keypair'] = propget.get(prop).name

    def get(self, cls, params):
        """
        Fill in every launch parameter for ``cls`` in ``params`` (in place).

        :param cls: the Model class whose properties drive the name,
            description and region prompts.
        """
        self.cls = cls
        # get_region also opens the EC2 connection (self.ec2); no need to
        # connect a second time here.
        self.get_region(params)
        self.get_name(params)
        self.get_description(params)
        self.get_instance_type(params)
        self.get_zone(params)
        self.get_quantity(params)
        self.get_ami_id(params)
        self.get_group(params)
        self.get_key(params)
228
class Server(Model):
    """
    SimpleDB-persisted record describing one EC2 instance.

    The properties of this object consist of real properties for data that
    is not already stored in EC2 somewhere (e.g. name, description) plus
    calculated properties for all of the properties that are already in
    EC2 (e.g. hostname, security groups, etc.).  The calculated properties
    are declared with ``use_method=True``; the ``_<name>`` methods below
    supply their values from the cached EC2 instance/reservation.
    """

    name = StringProperty(unique=True, verbose_name="Name")
    description = StringProperty(verbose_name="Description")
    region_name = StringProperty(verbose_name="EC2 Region Name")
    instance_id = StringProperty(verbose_name="EC2 Instance ID")
    elastic_ip = StringProperty(verbose_name="EC2 Elastic IP Address")
    production = BooleanProperty(verbose_name="Is This Server Production", default=False)
    ami_id = CalculatedProperty(verbose_name="AMI ID", calculated_type=str, use_method=True)
    zone = CalculatedProperty(verbose_name="Availability Zone Name", calculated_type=str, use_method=True)
    hostname = CalculatedProperty(verbose_name="Public DNS Name", calculated_type=str, use_method=True)
    private_hostname = CalculatedProperty(verbose_name="Private DNS Name", calculated_type=str, use_method=True)
    groups = CalculatedProperty(verbose_name="Security Groups", calculated_type=list, use_method=True)
    security_group = CalculatedProperty(verbose_name="Primary Security Group Name", calculated_type=str, use_method=True)
    key_name = CalculatedProperty(verbose_name="Key Name", calculated_type=str, use_method=True)
    instance_type = CalculatedProperty(verbose_name="Instance Type", calculated_type=str, use_method=True)
    status = CalculatedProperty(verbose_name="Current Status", calculated_type=str, use_method=True)
    launch_time = CalculatedProperty(verbose_name="Server Launch Time", calculated_type=str, use_method=True)
    console_output = CalculatedProperty(verbose_name="Console Output", calculated_type=open, use_method=True)

    packages = []
    plugins = []

    @classmethod
    def add_credentials(cls, cfg, aws_access_key_id, aws_secret_access_key):
        """
        Write AWS credentials and this class's SimpleDB domain into ``cfg``.

        Creates the ``Credentials`` and ``DB_Server`` sections if missing.
        ``create`` serializes ``cfg`` into the new instance's user data.
        """
        if not cfg.has_section('Credentials'):
            cfg.add_section('Credentials')
        cfg.set('Credentials', 'aws_access_key_id', aws_access_key_id)
        cfg.set('Credentials', 'aws_secret_access_key', aws_secret_access_key)
        if not cfg.has_section('DB_Server'):
            cfg.add_section('DB_Server')
        cfg.set('DB_Server', 'db_type', 'SimpleDB')
        cfg.set('DB_Server', 'db_name', cls._manager.domain.name)

    @classmethod
    def create(cls, config_file=None, logical_volume=None, cfg=None, **params):
        """
        Create a new instance based on the specified configuration file or the specified
        configuration and the passed in parameters.

        If the config_file argument is not None, the configuration is read from there.
        Otherwise, the cfg argument is used; one of the two must be supplied.

        The config file may include other config files with a #import reference. The included
        config files must reside in the same directory as the specified file.

        The logical_volume argument, if supplied, will be used to get the current physical
        volume ID and use that as an override of the value specified in the config file. This
        may be useful for debugging purposes when you want to debug with a production config
        file but a test Volume.

        The dictionary argument may be used to override any EC2 configuration values in the
        config file.

        :return: list of the new Server records, one per launched instance.
        :raises ValueError: if neither config_file nor cfg is given.
        """
        if config_file:
            cfg = Config(path=config_file)
        if cfg is None:
            raise ValueError('Server.create requires either config_file or cfg')
        if cfg.has_section('EC2'):
            # include any EC2 configuration values that aren't specified in params:
            for option in cfg.options('EC2'):
                if option not in params:
                    params[option] = cfg.get('EC2', option)
        getter = CommandLineGetter()
        getter.get(cls, params)
        region = params.get('region')
        # The getter has already connected to this region; reuse that connection.
        ec2 = getter.ec2
        cls.add_credentials(cfg, ec2.aws_access_key_id, ec2.aws_secret_access_key)
        ami = params.get('ami')
        kp = params.get('keypair')
        group = params.get('group')
        zone = params.get('zone')
        # deal with possibly passed in logical volume:
        if logical_volume is not None:
            if not cfg.has_section('EBS'):
                cfg.add_section('EBS')
            cfg.set('EBS', 'logical_volume_name', logical_volume.name)
        cfg_fp = StringIO()
        cfg.write(cfg_fp)
        # deal with the possibility that zone and/or keypair are objects rather than names:
        if isinstance(zone, Zone):
            zone = zone.name
        if isinstance(kp, KeyPair):
            kp = kp.name
        reservation = ami.run(min_count=1,
                              max_count=params.get('quantity', 1),
                              key_name=kp,
                              security_groups=[group],
                              instance_type=params.get('instance_type'),
                              placement=zone,
                              user_data=cfg_fp.getvalue())
        elastic_ip = params.get('elastic_ip')
        instances = reservation.instances
        if elastic_ip is not None and instances:
            instance = instances[0]
            print('Waiting for instance to start so we can set its elastic IP address...')
            # Sometimes we get a message from ec2 that says that the instance does not exist.
            # Hopefully the following delay will give ec2 enough time to get to a stable state:
            time.sleep(5)
            while instance.update() != 'running':
                time.sleep(1)
            instance.use_ip(elastic_ip)
            print('set the elastic IP of the first instance to %s' % elastic_ip)
        servers = []
        for i, instance in enumerate(instances):
            s = cls()
            s.ec2 = ec2
            # First instance gets the bare name, later ones get a numeric suffix
            # (name, name1, name2, ...).  The parentheses matter: without them
            # the conditional swallows the whole concatenation.
            s.name = params.get('name') + ('' if i == 0 else str(i))
            s.description = params.get('description')
            s.region_name = region.name
            s.instance_id = instance.id
            if elastic_ip and i == 0:
                s.elastic_ip = elastic_ip
            s.put()
            servers.append(s)
        return servers

    @classmethod
    def create_from_instance_id(cls, instance_id, name, description=''):
        """
        Search every EC2 region for ``instance_id`` and persist a Server for it.

        :return: the new Server, or None if no region returns the instance.
        """
        for region in boto.ec2.regions():
            ec2 = region.connect()
            try:
                rs = ec2.get_all_reservations([instance_id])
            except EC2ResponseError:
                # Not found in (or refused by) this region; try the next one.
                rs = []
            if len(rs) == 1:
                s = cls()
                s.ec2 = ec2
                s.name = name
                s.description = description
                s.region_name = region.name
                s.instance_id = instance_id
                s._reservation = rs[0]
                for instance in s._reservation.instances:
                    if instance.id == instance_id:
                        s._instance = instance
                s.put()
                return s
        return None

    @classmethod
    def create_from_current_instances(cls):
        """
        Persist a Server for every instance in every region that lacks one.

        New records are named after their instance id.

        :return: list of the Server records that were created.
        """
        servers = []
        regions = boto.ec2.regions()
        for region in regions:
            ec2 = region.connect()
            rs = ec2.get_all_reservations()
            for reservation in rs:
                for instance in reservation.instances:
                    try:
                        next(Server.find(instance_id=instance.id))
                        boto.log.info('Server for %s already exists' % instance.id)
                    except StopIteration:
                        s = cls()
                        s.ec2 = ec2
                        s.name = instance.id
                        s.region_name = region.name
                        s.instance_id = instance.id
                        s._reservation = reservation
                        s.put()
                        servers.append(s)
        return servers

    def __init__(self, id=None, **kw):
        super(Server, self).__init__(id, **kw)
        self.ssh_key_file = None
        self.ec2 = None
        self._cmdshell = None
        self._reservation = None
        self._instance = None
        self._setup_ec2()

    def _setup_ec2(self):
        """
        Connect to this server's region and cache its reservation/instance.

        Does nothing unless the record has an id and a region_name.  EC2
        lookup errors are ignored, leaving the cached instance unchanged.
        """
        if self.ec2 and self._instance and self._reservation:
            return
        if not (self.id and self.region_name):
            return
        for region in boto.ec2.regions():
            if region.name != self.region_name:
                continue
            self.ec2 = region.connect()
            if self.instance_id and not self._instance:
                try:
                    rs = self.ec2.get_all_reservations([self.instance_id])
                    if len(rs) >= 1:
                        for instance in rs[0].instances:
                            if instance.id == self.instance_id:
                                self._reservation = rs[0]
                                self._instance = instance
                except EC2ResponseError:
                    pass
            # Region names are unique; stop once the matching one is handled.
            break

    def _ami_id(self):
        """Image id the cached instance was launched from ('' if unknown)."""
        ami_id = ''
        if self._instance:
            ami_id = self._instance.image_id
        return ami_id

    def _status(self):
        """Refresh the cached instance and return its state ('' if none)."""
        status = ''
        if self._instance:
            self._instance.update()
            status = self._instance.state
        return status

    def _hostname(self):
        """Public DNS name of the cached instance ('' if none)."""
        hostname = ''
        if self._instance:
            hostname = self._instance.public_dns_name
        return hostname

    def _private_hostname(self):
        """Private DNS name of the cached instance ('' if none)."""
        hostname = ''
        if self._instance:
            hostname = self._instance.private_dns_name
        return hostname

    def _instance_type(self):
        """Instance type of the cached instance ('' if none)."""
        it = ''
        if self._instance:
            it = self._instance.instance_type
        return it

    def _launch_time(self):
        """Launch time of the cached instance ('' if none)."""
        lt = ''
        if self._instance:
            lt = self._instance.launch_time
        return lt

    def _console_output(self):
        """Console output fetched from EC2 for the cached instance ('' if none)."""
        co = ''
        if self._instance:
            co = self._instance.get_console_output()
        return co

    def _groups(self):
        """Security groups of the cached reservation ([] if none)."""
        gn = []
        if self._reservation:
            gn = self._reservation.groups
        return gn

    def _security_group(self):
        """Identifier of the first security group, or '' if there are none."""
        # NOTE(review): the property is labelled "Name" but this returns the
        # group's ``id`` attribute -- confirm which one is intended.
        groups = self._groups()
        if len(groups) >= 1:
            return groups[0].id
        return ""

    def _zone(self):
        """Placement (availability zone) of the cached instance, or None."""
        zone = None
        if self._instance:
            zone = self._instance.placement
        return zone

    def _key_name(self):
        """Key pair name of the cached instance, or None."""
        kn = None
        if self._instance:
            kn = self._instance.key_name
        return kn

    def put(self):
        """Save the record, then (re)establish the EC2 connection/cache."""
        super(Server, self).put()
        self._setup_ec2()

    def delete(self):
        """
        Delete this record.  The EC2 instance itself is not stopped or
        terminated.

        :raises ValueError: if the server is marked as production.
        """
        if self.production:
            raise ValueError("Can't delete a production server")
        super(Server, self).delete()

    def stop(self):
        """
        Stop the EC2 instance, if one is cached.

        :raises ValueError: if the server is marked as production.
        """
        if self.production:
            raise ValueError("Can't stop a production server")
        if self._instance:
            self._instance.stop()

    def terminate(self):
        """
        Terminate the EC2 instance, if one is cached.

        :raises ValueError: if the server is marked as production.
        """
        if self.production:
            raise ValueError("Can't terminate a production server")
        if self._instance:
            self._instance.terminate()

    def reboot(self):
        """Reboot the EC2 instance, if one is cached."""
        if self._instance:
            self._instance.reboot()

    def wait(self):
        """Poll every 5 seconds until the instance state is 'running'."""
        while self.status != 'running':
            time.sleep(5)

    def get_ssh_key_file(self):
        """
        Return the SSH private key path for this server.

        Uses ``~/.ssh/<key_name>.pem`` when it exists, otherwise prompts for
        a path.  The result is cached in ``self.ssh_key_file``.
        """
        if not self.ssh_key_file:
            ssh_dir = os.path.expanduser('~/.ssh')
            if os.path.isdir(ssh_dir):
                ssh_file = os.path.join(ssh_dir, '%s.pem' % self.key_name)
                if os.path.isfile(ssh_file):
                    self.ssh_key_file = ssh_file
            if not self.ssh_key_file:
                iobject = IObject()
                self.ssh_key_file = iobject.get_filename('Path to OpenSSH Key file')
        return self.ssh_key_file

    def get_cmdshell(self):
        """Return the cached command shell, starting one if needed."""
        if not self._cmdshell:
            from boto.manage import cmdshell
            self.get_ssh_key_file()
            self._cmdshell = cmdshell.start(self)
        return self._cmdshell

    def reset_cmdshell(self):
        """Forget the cached command shell."""
        self._cmdshell = None

    def run(self, command):
        """Run ``command`` on the server and return the shell's result."""
        try:
            with closing(self.get_cmdshell()) as cmd:
                status = cmd.run(command)
        finally:
            # closing() has just closed the cached shell; drop the reference
            # so the next call starts a fresh one instead of reusing it.
            self.reset_cmdshell()
        return status

    def get_bundler(self, uname='root'):
        """Return a Bundler for this server logged in as ``uname``."""
        self.get_ssh_key_file()
        return Bundler(self, uname)

    def get_ssh_client(self, uname='root', ssh_pwd=None):
        """Return an SSHClient connected to this server as ``uname``."""
        from boto.manage.cmdshell import SSHClient
        self.get_ssh_key_file()
        return SSHClient(self, uname=uname, ssh_pwd=ssh_pwd)

    def install(self, pkg):
        """Install ``pkg`` with apt-get on the server."""
        return self.run('apt-get -y install %s' % pkg)
554
555
556
557