1# -*- coding: utf-8 -*-
2# Copyright 2013 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Contains gsutil base unit test case class."""
16
17from __future__ import absolute_import
18
19import logging
20import os
21import sys
22import tempfile
23
24import boto
25from gslib import wildcard_iterator
26from gslib.boto_translation import BotoTranslation
27from gslib.cloud_api_delegator import CloudApiDelegator
28from gslib.command_runner import CommandRunner
29from gslib.cs_api_map import ApiMapConstants
30from gslib.cs_api_map import ApiSelector
31from gslib.tests.mock_logging_handler import MockLoggingHandler
32from gslib.tests.testcase import base
33import gslib.tests.util as util
34from gslib.tests.util import unittest
35from gslib.tests.util import WorkingDirectory
36from gslib.util import GsutilStreamHandler
37
38
class GsutilApiUnitTestClassMapFactory(object):
  """Class map factory used by gsutil unit tests.

  Every provider/API combination resolves to BotoTranslation, which lets
  GSMockBucketStorageUri talk to the mock XML service.
  """

  @classmethod
  def GetClassMap(cls):
    """Returns the unit-test class map.

    Returns:
      A new dict keyed by provider ('gs', 's3'); each value maps an
      ApiSelector to the class implementing it (always BotoTranslation).
    """
    return {
        'gs': {
            ApiSelector.XML: BotoTranslation,
            ApiSelector.JSON: BotoTranslation,
        },
        's3': {
            ApiSelector.XML: BotoTranslation,
        },
    }
61
62
@unittest.skipUnless(util.RUN_UNIT_TESTS,
                     'Not running unit tests.')
class GsUtilUnitTestCase(base.GsUtilTestCase):
  """Base class for gsutil unit tests.

  Commands and helpers run against a mock storage service. StorageUris are
  built with util.GSMockBucketStorageUri, and the Cloud API class map comes
  from GsutilApiUnitTestClassMapFactory. For the duration of each test,
  sys.stdout, sys.stderr and the root logger's handlers are redirected to
  temporary files. When the root logger is enabled for DEBUG, the captured
  output is echoed to the real stderr.
  """

  @classmethod
  def setUpClass(cls):
    """Builds the class-wide logger and the mock-backed CommandRunner."""
    # NOTE(review): the base classmethod is invoked through the base class, so
    # it runs with cls=base.GsUtilTestCase rather than this subclass; confirm
    # that is intended before switching to super().
    base.GsUtilTestCase.setUpClass()
    cls.mock_bucket_storage_uri = util.GSMockBucketStorageUri
    cls.mock_gsutil_api_class_map_factory = GsutilApiUnitTestClassMapFactory
    cls.logger = logging.getLogger()
    cls.command_runner = CommandRunner(
        bucket_storage_uri_class=cls.mock_bucket_storage_uri,
        gsutil_api_class_map_factory=cls.mock_gsutil_api_class_map_factory)

  def setUp(self):
    """Redirects stdout, stderr and root-logger output to temporary files."""
    super(GsUtilUnitTestCase, self).setUp()
    self.bucket_uris = []
    self.stdout_save = sys.stdout
    self.stderr_save = sys.stderr
    fd, self.stdout_file = tempfile.mkstemp()
    sys.stdout = os.fdopen(fd, 'w+')
    fd, self.stderr_file = tempfile.mkstemp()
    sys.stderr = os.fdopen(fd, 'w+')
    # RunCommand moves any output already in the capture files here before it
    # truncates them; tearDown joins these chunks back in.
    self.accumulated_stdout = []
    self.accumulated_stderr = []

    self.root_logger = logging.getLogger()
    self.is_debugging = self.root_logger.isEnabledFor(logging.DEBUG)
    # Replace (not extend) the root handlers so test logging only goes to the
    # temporary file; the original list is restored in tearDown.
    self.log_handlers_save = self.root_logger.handlers
    fd, self.log_handler_file = tempfile.mkstemp()
    self.log_handler_stream = os.fdopen(fd, 'w+')
    self.temp_log_handler = GsutilStreamHandler(self.log_handler_stream)
    self.root_logger.handlers = [self.temp_log_handler]

  def tearDown(self):
    """Restores stdout/stderr/logging and deletes the capture files.

    Restoration happens in a finally clause. A failure in the base class
    tearDown therefore cannot leave sys.stdout/sys.stderr redirected or leak
    the temporary files. When debugging, everything captured during the test
    is written to the real stderr.
    """
    try:
      super(GsUtilUnitTestCase, self).tearDown()
    finally:
      self.root_logger.handlers = self.log_handlers_save
      self.temp_log_handler.flush()
      self.temp_log_handler.close()
      self.log_handler_stream.seek(0)
      log_output = self.log_handler_stream.read()
      self.log_handler_stream.close()
      os.unlink(self.log_handler_file)

      sys.stdout.seek(0)
      sys.stderr.seek(0)
      # Accumulated chunks were captured before the current file contents, so
      # they go first to keep the output in chronological order.
      stdout = ''.join(self.accumulated_stdout) + sys.stdout.read()
      stderr = ''.join(self.accumulated_stderr) + sys.stderr.read()
      sys.stdout.close()
      sys.stderr.close()
      sys.stdout = self.stdout_save
      sys.stderr = self.stderr_save
      os.unlink(self.stdout_file)
      os.unlink(self.stderr_file)

      if self.is_debugging:
        for label, output in (('stdout', stdout),
                              ('stderr', stderr),
                              ('log output', log_output)):
          if output:
            sys.stderr.write('==== %s %s ====\n' % (label, self.id()))
            sys.stderr.write(output)
            sys.stderr.write('==== end %s ====\n' % label)

  def RunCommand(self, command_name, args=None, headers=None, debug=0,
                 return_stdout=False, return_stderr=False,
                 return_log_handler=False, cwd=None):
    """Method for calling gslib.command_runner.CommandRunner.

    Passes parallel_operations=False for all tests. Optionally returns the
    stdout, stderr and/or log handler produced while the command ran. We run
    all tests multi-threaded, to exercise those more complicated code paths.
    TODO: Change to run with parallel_operations=True for all tests. At
    present when you do this it causes many test failures.

    Args:
      command_name: The name of the command being run.
      args: Command-line args (arg0 = actual arg, not command name ala bash).
      headers: Dictionary containing optional HTTP headers to pass to boto.
      debug: Debug level to pass in to boto connection (range 0..3).
      return_stdout: If True, will save and return stdout produced by command.
      return_stderr: If True, will save and return stderr produced by command.
      return_log_handler: If True, will return a MockLoggingHandler instance
           that was attached to the command's logger while running.
      cwd: The working directory that should be switched to before running the
           command. The working directory will be reset back to its original
           value after running the command. If not specified, the working
           directory is left unchanged.

    Returns:
      The single requested value, or a tuple of the requested values in the
      order stdout, stderr, log handler. An empty tuple is returned if none
      of return_stdout, return_stderr or return_log_handler is set.
    """
    args = args or []

    # list() so that tuple args don't break the concatenation.
    command_line = ' '.join([command_name] + list(args))
    if self.is_debugging:
      self.stderr_save.write('\nRunCommand of %s\n' % command_line)

    # Save and truncate stdout and stderr for the lifetime of RunCommand. This
    # way, we can return just the stdout and stderr that was output during the
    # RunNamedCommand call below.
    sys.stdout.seek(0)
    sys.stderr.seek(0)
    stdout = sys.stdout.read()
    stderr = sys.stderr.read()
    if stdout:
      self.accumulated_stdout.append(stdout)
    if stderr:
      self.accumulated_stderr.append(stderr)
    sys.stdout.seek(0)
    sys.stderr.seek(0)
    sys.stdout.truncate()
    sys.stderr.truncate()

    mock_log_handler = MockLoggingHandler()
    logging.getLogger(command_name).addHandler(mock_log_handler)

    try:
      with WorkingDirectory(cwd):
        self.command_runner.RunNamedCommand(
            command_name, args=args, headers=headers, debug=debug,
            parallel_operations=False, do_shutdown=False)
    finally:
      sys.stdout.seek(0)
      stdout = sys.stdout.read()
      sys.stderr.seek(0)
      stderr = sys.stderr.read()
      logging.getLogger(command_name).removeHandler(mock_log_handler)
      mock_log_handler.close()

      # items() instead of the Python 2-only iteritems() so this also runs
      # under Python 3; iteration order is unchanged.
      log_output = '\n'.join(
          '%s:\n  ' % level + '\n  '.join(records)
          for level, records in mock_log_handler.messages.items()
          if records)
      if self.is_debugging and log_output:
        self.stderr_save.write(
            '==== logging RunCommand %s %s ====\n' % (self.id(), command_line))
        self.stderr_save.write(log_output)
        self.stderr_save.write('\n==== end logging ====\n')
      if self.is_debugging and stdout:
        self.stderr_save.write(
            '==== stdout RunCommand %s %s ====\n' % (self.id(), command_line))
        self.stderr_save.write(stdout)
        self.stderr_save.write('==== end stdout ====\n')
      if self.is_debugging and stderr:
        self.stderr_save.write(
            '==== stderr RunCommand %s %s ====\n' % (self.id(), command_line))
        self.stderr_save.write(stderr)
        self.stderr_save.write('==== end stderr ====\n')

      # Reset stdout and stderr files, so that we won't print them out again
      # in tearDown if debugging is enabled.
      sys.stdout.seek(0)
      sys.stderr.seek(0)
      sys.stdout.truncate()
      sys.stderr.truncate()

    to_return = []
    if return_stdout:
      to_return.append(stdout)
    if return_stderr:
      to_return.append(stderr)
    if return_log_handler:
      to_return.append(mock_log_handler)
    if len(to_return) == 1:
      return to_return[0]
    return tuple(to_return)

  @classmethod
  def MakeGsUtilApi(cls, debug=0):
    """Creates a CloudApiDelegator backed by the unit-test class map.

    gs supports the XML and JSON APIs (JSON by default); s3 supports only
    XML. The test class map implements every API with BotoTranslation.

    Args:
      debug: Debug level passed through to CloudApiDelegator.

    Returns:
      CloudApiDelegator using the mock bucket StorageUri class and the
      class-wide logger.
    """
    gsutil_api_map = {
        ApiMapConstants.API_MAP: (
            cls.mock_gsutil_api_class_map_factory.GetClassMap()),
        ApiMapConstants.SUPPORT_MAP: {
            'gs': [ApiSelector.XML, ApiSelector.JSON],
            's3': [ApiSelector.XML]
        },
        ApiMapConstants.DEFAULT_MAP: {
            'gs': ApiSelector.JSON,
            's3': ApiSelector.XML
        }
    }

    return CloudApiDelegator(
        cls.mock_bucket_storage_uri, gsutil_api_map, cls.logger, debug=debug)

  @classmethod
  def _test_wildcard_iterator(cls, uri_or_str, debug=0):
    """Convenience method for instantiating a test instance of WildcardIterator.

    This makes it unnecessary to specify all the params of that class
    (like bucket_storage_uri_class=mock_storage_service.MockBucketStorageUri).
    Also, naming the factory method this way makes it clearer in the test code
    that WildcardIterator needs to be set up for testing.

    Args are same as for wildcard_iterator.wildcard_iterator(), except
    there are no class args for bucket_storage_uri_class or gsutil_api_class.

    Args:
      uri_or_str: StorageUri or string representing the wildcard string.
      debug: debug level to pass to the underlying connection (0..3)

    Returns:
      WildcardIterator, over which caller can iterate.
    """
    # TODO: Remove when tests no longer pass StorageUri arguments.
    uri_string = uri_or_str
    if hasattr(uri_or_str, 'uri'):
      uri_string = uri_or_str.uri

    # Forward debug; previously it was accepted but silently dropped.
    return wildcard_iterator.CreateWildcardIterator(
        uri_string, cls.MakeGsUtilApi(debug=debug))

  @staticmethod
  def _test_storage_uri(uri_str, default_scheme='file', debug=0,
                        validate=True):
    """Convenience method for instantiating a testing instance of StorageUri.

    This makes it unnecessary to specify
    bucket_storage_uri_class=mock_storage_service.MockBucketStorageUri.
    Also naming the factory method this way makes it clearer in the test
    code that StorageUri needs to be set up for testing.

    Args, Returns, and Raises are same as for boto.storage_uri(), except there's
    no bucket_storage_uri_class arg.

    Args:
      uri_str: Uri string to create StorageUri for.
      default_scheme: Default scheme for the StorageUri
      debug: debug level to pass to the underlying connection (0..3)
      validate: If True, validate the resource that the StorageUri refers to.

    Returns:
      StorageUri based on the arguments.
    """
    return boto.storage_uri(uri_str, default_scheme, debug, validate,
                            util.GSMockBucketStorageUri)

  def CreateBucket(self, bucket_name=None, test_objects=0, storage_class=None,
                   provider='gs'):
    """Creates a test bucket.

    The bucket and all of its contents will be deleted after the test.

    Args:
      bucket_name: Create the bucket with this name. If not provided, a
                   temporary test bucket name is constructed. The name is
                   lower-cased before use.
      test_objects: The number of objects that should be placed in the bucket or
                    a list of object names to place in the bucket. Defaults to
                    0.
      storage_class: storage class to use. If not provided we use standard.
      provider: string provider to use, default gs.

    Returns:
      StorageUri for the created bucket.
    """
    bucket_name = bucket_name or self.MakeTempName('bucket')
    bucket_uri = boto.storage_uri(
        '%s://%s' % (provider, bucket_name.lower()),
        suppress_consec_slashes=False,
        bucket_storage_uri_class=util.GSMockBucketStorageUri)
    bucket_uri.create_bucket(storage_class=storage_class)
    self.bucket_uris.append(bucket_uri)
    # test_objects may be a count (not iterable) or an iterable of names.
    try:
      iter(test_objects)
    except TypeError:
      test_objects = [self.MakeTempName('obj') for _ in range(test_objects)]
    for i, name in enumerate(test_objects):
      self.CreateObject(bucket_uri=bucket_uri, object_name=name,
                        contents='test %d' % i)
    return bucket_uri

  def CreateObject(self, bucket_uri=None, object_name=None, contents=None):
    """Creates a test object.

    Args:
      bucket_uri: The URI of the bucket to place the object in. If not
                  specified, a new temporary bucket is created.
      object_name: The name to use for the object. If not specified, a temporary
                   test object name is constructed.
      contents: The contents to write to the object. If not specified, the key
                is not written to, which means that it isn't actually created
                yet on the server.

    Returns:
      A StorageUri for the created object.
    """
    bucket_uri = bucket_uri or self.CreateBucket()
    object_name = object_name or self.MakeTempName('obj')
    key_uri = bucket_uri.clone_replace_name(object_name)
    if contents is not None:
      key_uri.set_contents_from_string(contents)
    return key_uri
365