# Copyright (c) 2012 Amazon.com, Inc. or its affiliates.  All Rights Reserved
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
import os
import math
import threading
import hashlib
import time
import logging
import binascii

from boto.compat import Queue
# ``Empty`` is raised by Queue.get(timeout=...) in TransferThread.run below.
try:
    from queue import Empty  # Python 3
except ImportError:
    from Queue import Empty  # Python 2

from boto.glacier.utils import DEFAULT_PART_SIZE, minimum_part_size, \
                               chunk_hashes, tree_hash, bytes_to_hex
from boto.glacier.exceptions import UploadArchiveError, \
                                    DownloadArchiveError, \
                                    TreeHashDoesNotMatchError


_END_SENTINEL = object()
log = logging.getLogger('boto.glacier.concurrent')


class ConcurrentTransferer(object):
    def __init__(self, part_size=DEFAULT_PART_SIZE, num_threads=10):
        self._part_size = part_size
        self._num_threads = num_threads
        self._threads = []

    def _calculate_required_part_size(self, total_size):
        min_part_size_required = minimum_part_size(total_size)
        if self._part_size >= min_part_size_required:
            part_size = self._part_size
        else:
            part_size = min_part_size_required
            log.debug("The part size specified (%s) is smaller than "
                      "the minimum required part size.  Using a part "
                      "size of: %s", self._part_size, part_size)
        total_parts = int(math.ceil(total_size / float(part_size)))
        return total_parts, part_size

    def _shutdown_threads(self):
        log.debug("Shutting down threads.")
        for thread in self._threads:
            thread.should_continue = False
        for thread in self._threads:
            thread.join()
        log.debug("Threads have exited.")

    def _add_work_items_to_queue(self, total_parts, worker_queue, part_size):
        log.debug("Adding work items to queue.")
        for i in range(total_parts):
            worker_queue.put((i, part_size))
        for _ in range(self._num_threads):
            worker_queue.put(_END_SENTINEL)


class ConcurrentUploader(ConcurrentTransferer):
    """Concurrently upload an archive to Amazon Glacier.

    This class uses a thread pool to concurrently upload an archive
    to Glacier using the multipart upload API.

    The thread pool is completely managed by this class and is
    transparent to users of the class.

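    Example usage (a minimal sketch; assumes AWS credentials are available
    through the usual boto configuration and that a vault named 'my-vault'
    already exists)::

        from boto.glacier.layer1 import Layer1
        from boto.glacier.concurrent import ConcurrentUploader

        api = Layer1()
        uploader = ConcurrentUploader(api, 'my-vault', num_threads=4)
        archive_id = uploader.upload('backup.tar.gz',
                                     description='nightly backup')
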
    """
    def __init__(self, api, vault_name, part_size=DEFAULT_PART_SIZE,
                 num_threads=10):
        """
        :type api: :class:`boto.glacier.layer1.Layer1`
        :param api: A layer1 Glacier connection object.

        :type vault_name: str
        :param vault_name: The name of the vault.

        :type part_size: int
        :param part_size: The size, in bytes, of the parts to use when
            uploading the archive.  The part size must be one megabyte
            multiplied by a power of two (for example 1 MB, 2 MB, 4 MB).

        :type num_threads: int
        :param num_threads: The number of threads to spawn for the thread
            pool.  This controls how many parts are uploaded concurrently.

        """
        super(ConcurrentUploader, self).__init__(part_size, num_threads)
        self._api = api
        self._vault_name = vault_name

    def upload(self, filename, description=None):
        """Concurrently create an archive.

        The part_size value specified when the class was constructed
        will be used *unless* it is smaller than the minimum part size
        required for a file of the given size.  In that case, the
        minimum required part size is used instead.

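        For example, with the default 4 MB part size, an archive larger
        than about 40 GB would exceed Glacier's 10,000 part limit per
        upload, so a larger part size (8 MB for a 50 GB archive, for
        instance) is selected automatically.
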
        :type filename: str
        :param filename: The name of the file to upload.

        :type description: str
        :param description: The description of the archive.

        :rtype: str
        :return: The archive id of the newly created archive.

        """
        total_size = os.stat(filename).st_size
        total_parts, part_size = self._calculate_required_part_size(total_size)
        hash_chunks = [None] * total_parts
        worker_queue = Queue()
        result_queue = Queue()
        response = self._api.initiate_multipart_upload(self._vault_name,
                                                       part_size,
                                                       description)
        upload_id = response['UploadId']
        # The basic idea is to add the chunks (the offsets, not the actual
        # contents) to a work queue, start up a thread pool, let them crank
        # through the items in the work queue, and then place their results
        # in a result queue which we use to complete the multipart upload.
        self._add_work_items_to_queue(total_parts, worker_queue, part_size)
        self._start_upload_threads(result_queue, upload_id,
                                   worker_queue, filename)
        try:
            self._wait_for_upload_threads(hash_chunks, result_queue,
                                          total_parts)
        except UploadArchiveError as e:
            log.debug("An error occurred while uploading an archive, "
                      "aborting multipart upload.")
            self._api.abort_multipart_upload(self._vault_name, upload_id)
            raise e
        log.debug("Completing upload.")
        response = self._api.complete_multipart_upload(
            self._vault_name, upload_id, bytes_to_hex(tree_hash(hash_chunks)),
            total_size)
        log.debug("Upload finished.")
        return response['ArchiveId']

    def _wait_for_upload_threads(self, hash_chunks, result_queue, total_parts):
        for _ in range(total_parts):
            result = result_queue.get()
            if isinstance(result, Exception):
                log.debug("An error was found in the result queue, "
                          "terminating threads: %s", result)
                self._shutdown_threads()
                raise UploadArchiveError("An error occurred while uploading "
                                         "an archive: %s" % result)
            # Each unit of work returns the tree hash for the given part
            # number, which we use at the end to compute the tree hash of
            # the entire archive.
            part_number, tree_sha256 = result
            hash_chunks[part_number] = tree_sha256
        self._shutdown_threads()

    def _start_upload_threads(self, result_queue, upload_id, worker_queue,
                              filename):
        log.debug("Starting threads.")
        for _ in range(self._num_threads):
            thread = UploadWorkerThread(self._api, self._vault_name, filename,
                                        upload_id, worker_queue, result_queue)
            time.sleep(0.2)
            thread.start()
            self._threads.append(thread)


class TransferThread(threading.Thread):
    def __init__(self, worker_queue, result_queue):
        super(TransferThread, self).__init__()
        self._worker_queue = worker_queue
        self._result_queue = result_queue
        # This value can be set externally by other objects
        # to indicate that the thread should be shut down.
        self.should_continue = True

    def run(self):
        while self.should_continue:
            try:
                work = self._worker_queue.get(timeout=1)
            except Empty:
                continue
            if work is _END_SENTINEL:
                self._cleanup()
                return
            result = self._process_chunk(work)
            self._result_queue.put(result)
        self._cleanup()

    def _process_chunk(self, work):
        # Subclasses implement the actual transfer of a single part here.
        pass

    def _cleanup(self):
        # Subclasses can release any resources (such as open file handles)
        # here before the thread exits.
        pass


class UploadWorkerThread(TransferThread):
    def __init__(self, api, vault_name, filename, upload_id,
                 worker_queue, result_queue, num_retries=5,
                 time_between_retries=5,
                 retry_exceptions=Exception):
        super(UploadWorkerThread, self).__init__(worker_queue, result_queue)
        self._api = api
        self._vault_name = vault_name
        self._filename = filename
        self._fileobj = open(filename, 'rb')
        self._upload_id = upload_id
        self._num_retries = num_retries
        self._time_between_retries = time_between_retries
        self._retry_exceptions = retry_exceptions

    def _process_chunk(self, work):
        result = None
        # Attempt the upload up to num_retries + 1 times.  If every attempt
        # fails, the last exception is returned as the result, which the
        # uploader later surfaces as an UploadArchiveError.
        for i in range(self._num_retries + 1):
            try:
                result = self._upload_chunk(work)
                break
            except self._retry_exceptions as e:
                log.error("Exception caught uploading part number %s for "
                          "vault %s, attempt: (%s / %s), filename: %s, "
                          "exception: %s, msg: %s",
                          work[0], self._vault_name, i + 1,
                          self._num_retries + 1, self._filename,
                          e.__class__, e)
                time.sleep(self._time_between_retries)
                result = e
        return result

    def _upload_chunk(self, work):
        part_number, part_size = work
        start_byte = part_number * part_size
        self._fileobj.seek(start_byte)
        contents = self._fileobj.read(part_size)
        # Glacier requires both the SHA256 of the part's payload (the linear
        # hash) and the SHA256 tree hash of the part.
        linear_hash = hashlib.sha256(contents).hexdigest()
        tree_hash_bytes = tree_hash(chunk_hashes(contents))
        byte_range = (start_byte, start_byte + len(contents) - 1)
        log.debug("Uploading chunk %s of size %s", part_number, part_size)
        response = self._api.upload_part(self._vault_name, self._upload_id,
                                         linear_hash,
                                         bytes_to_hex(tree_hash_bytes),
                                         byte_range, contents)
        # Reading the response allows the connection to be reused.
        response.read()
        return (part_number, tree_hash_bytes)

    def _cleanup(self):
        self._fileobj.close()


class ConcurrentDownloader(ConcurrentTransferer):
    """
    Concurrently download an archive from Amazon Glacier.

    This class uses a thread pool to concurrently download an archive
    from Glacier.

    The thread pool is completely managed by this class and is
    transparent to users of the class.

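    Example usage (a minimal sketch; assumes a vault named 'my-vault' and
    a completed archive-retrieval job whose id is JOB_ID)::

        from boto.glacier.layer2 import Layer2
        from boto.glacier.concurrent import ConcurrentDownloader

        vault = Layer2().get_vault('my-vault')
        job = vault.get_job('JOB_ID')  # an archive-retrieval job
        downloader = ConcurrentDownloader(job, num_threads=4)
        downloader.download('restored-archive.tar')
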
    """
    def __init__(self, job, part_size=DEFAULT_PART_SIZE,
                 num_threads=10):
        """
        :param job: A layer2 archive-retrieval job object
            (:class:`boto.glacier.job.Job`).

        :param part_size: The size, in bytes, of the parts to use when
            downloading the archive.  The part size must be one megabyte
            multiplied by a power of two.

        """
        super(ConcurrentDownloader, self).__init__(part_size, num_threads)
        self._job = job

    def download(self, filename):
        """
        Concurrently download an archive.

        :param filename: The name of the file to download the archive to.
        :type filename: str

        """
        total_size = self._job.archive_size
        total_parts, part_size = self._calculate_required_part_size(total_size)
        worker_queue = Queue()
        result_queue = Queue()
        self._add_work_items_to_queue(total_parts, worker_queue, part_size)
        self._start_download_threads(result_queue, worker_queue)
        try:
            self._wait_for_download_threads(filename, result_queue, total_parts)
        except DownloadArchiveError as e:
            log.debug("An error occurred while downloading an archive: %s", e)
            raise e
        log.debug("Download completed.")

    def _wait_for_download_threads(self, filename, result_queue, total_parts):
        """
        Wait until the result queue has received all of the downloaded
        parts, which indicates that every part download has completed,
        and write each part into the output file at its correct offset.

        :param filename: The name of the file the downloaded archive is
            written to.
        :param result_queue: The queue from which completed part results
            are read.
        :param total_parts: The total number of parts expected.
        """
        hash_chunks = [None] * total_parts
        with open(filename, "wb") as f:
            for _ in range(total_parts):
                result = result_queue.get()
                if isinstance(result, Exception):
                    log.debug("An error was found in the result queue, "
                              "terminating threads: %s", result)
                    self._shutdown_threads()
                    raise DownloadArchiveError(
                        "An error occurred while downloading "
                        "an archive: %s" % result)
                part_number, part_size, actual_hash, data = result
                hash_chunks[part_number] = actual_hash
                start_byte = part_number * part_size
                f.seek(start_byte)
                f.write(data)
                f.flush()
        final_hash = bytes_to_hex(tree_hash(hash_chunks))
        log.debug("Verifying final tree hash of archive, expecting: %s, "
                  "actual: %s", self._job.sha256_treehash, final_hash)
        if self._job.sha256_treehash != final_hash:
            self._shutdown_threads()
            raise TreeHashDoesNotMatchError(
                "Tree hash for entire archive does not match, "
                "expected: %s, got: %s" % (self._job.sha256_treehash,
                                           final_hash))
        self._shutdown_threads()

    def _start_download_threads(self, result_queue, worker_queue):
        log.debug("Starting threads.")
        for _ in range(self._num_threads):
            thread = DownloadWorkerThread(self._job, worker_queue,
                                          result_queue)
            time.sleep(0.2)
            thread.start()
            self._threads.append(thread)


class DownloadWorkerThread(TransferThread):
    def __init__(self, job,
                 worker_queue, result_queue,
                 num_retries=5,
                 time_between_retries=5,
                 retry_exceptions=Exception):
        """
        An individual worker thread that downloads parts of the archive
        from Glacier.  The parts to download are read from the worker
        queue, and each downloaded part is placed on the result queue.

        :param job: Glacier job object.
        :param worker_queue: A queue of (part_number, part_size) tuples
            describing the parts to download.
        :param result_queue: A queue onto which the thread places
            (part_number, part_size, tree_hash, data) tuples, one for
            each downloaded part.

        """
        super(DownloadWorkerThread, self).__init__(worker_queue, result_queue)
        self._job = job
        self._num_retries = num_retries
        self._time_between_retries = time_between_retries
        self._retry_exceptions = retry_exceptions

    def _process_chunk(self, work):
        """
        Attempt to download a part of the archive from Glacier,
        retrying on failure.  Returns either the downloaded part or
        the last exception raised.

        :param work: A (part_number, part_size) tuple.
        """
        result = None
        for _ in range(self._num_retries):
            try:
                result = self._download_chunk(work)
                break
            except self._retry_exceptions as e:
                log.error("Exception caught downloading part number %s for "
                          "job %s: %s", work[0], self._job, e)
                time.sleep(self._time_between_retries)
                result = e
        return result

    def _download_chunk(self, work):
        """
        Download a single part of the archive from Glacier and verify
        its tree hash.  Returns a (part_number, part_size, tree_hash,
        data) tuple.

        :param work: A (part_number, part_size) tuple.
        """
        part_number, part_size = work
        start_byte = part_number * part_size
        byte_range = (start_byte, start_byte + part_size - 1)
        log.debug("Downloading chunk %s of size %s", part_number, part_size)
        response = self._job.get_output(byte_range)
        data = response.read()
        # Verify the downloaded part against the tree hash Glacier returns
        # for the requested byte range.
        actual_hash = bytes_to_hex(tree_hash(chunk_hashes(data)))
        if response['TreeHash'] != actual_hash:
            raise TreeHashDoesNotMatchError(
                "Tree hash for part number %s does not match, "
                "expected: %s, got: %s" % (part_number, response['TreeHash'],
                                           actual_hash))
        return (part_number, part_size, binascii.unhexlify(actual_hash), data)