#!/usr/bin/env python
# Copyright 2011 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Mapreduce shuffler implementation."""

from __future__ import with_statement




__all__ = [
    "ShufflePipeline",
    ]

# Using opensource naming conventions, pylint: disable=g-bad-name

import gc
import heapq
import logging
import pickle
import time

import pipeline
from pipeline import common as pipeline_common
from google.appengine.ext import db
from mapreduce import context
from mapreduce import errors
from mapreduce import input_readers
from mapreduce import kv_pb
from mapreduce import mapper_pipeline
from mapreduce import model
from mapreduce import operation
from mapreduce import output_writers
from mapreduce import pipeline_base
from mapreduce import records
from mapreduce import util

# pylint: disable=g-import-not-at-top
# TODO(user): Cleanup imports if/when cloudstorage becomes part of runtime.
try:
  # Check if the full cloudstorage package exists. The stub part is in runtime.
  import cloudstorage
  if hasattr(cloudstorage, "_STUB"):
    cloudstorage = None
except ImportError:
  pass  # CloudStorage library not available


# pylint: disable=g-bad-name
# pylint: disable=protected-access


class _OutputFile(db.Model):
  """Entity to store output filenames of pipelines.

  These entities are always children of the key returned by get_root_key().
  """

  @classmethod
  def kind(cls):
    """Returns entity kind."""
    return "_AE_MR_OutputFile"

  @classmethod
  def get_root_key(cls, job_id):
    """Get root key to store output files.

    Args:
      job_id: pipeline's job id.

    Returns:
      root key for a given job id to store output file entities.
    """
    return db.Key.from_path(cls.kind(), job_id)


def _compare_keys(key_record1, key_record2):
  """Compare two (key, records) protos by key."""
  return cmp(key_record1[0], key_record2[0])


class _BatchGCSRecordsReader(
    input_readers._GoogleCloudStorageRecordInputReader):
  """GCS Records reader that reads in big batches."""

  BATCH_SIZE = 1024 * 1024 * 3
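  # BATCH_SIZE is the approximate number of bytes of serialized records
  # accumulated before a batch is yielded (3 MB); _sort_records_map therefore
  # receives whole lists of records rather than single records.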

  def __iter__(self):
    # pylint: disable=redefined-outer-name
    records = []
    size = 0
    try:
      while True:
        record = super(_BatchGCSRecordsReader, self).next()
        records.append(record)
        size += len(record)
        if size > self.BATCH_SIZE:
          yield records
          size = 0
          records = []
          gc.collect()
    except StopIteration:
      pass
    if records:
      yield records
      records = []
      gc.collect()


# pylint: disable=redefined-outer-name
def _sort_records_map(records):
  """Map function sorting records.

  Converts records to KeyValue protos, sorts them by key and writes them
  into a new GCS file. Creates an _OutputFile entity to record the
  resulting file name.
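
  The output file is written to
  "/<bucket_name>/<mapreduce name>/<mapreduce id>/output-<shard id>-<timestamp>",
  where the bucket name is taken from the mapper's input reader parameters.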

  Args:
    records: list of records which are serialized KeyValue protos.
  """
  ctx = context.get()
  l = len(records)
  key_records = [None] * l

  logging.debug("Parsing")
  for i in range(l):
    proto = kv_pb.KeyValue()
    proto.ParseFromString(records[i])
    key_records[i] = (proto.key(), records[i])

  logging.debug("Sorting")
  key_records.sort(cmp=_compare_keys)

  logging.debug("Writing")
  mapper_spec = ctx.mapreduce_spec.mapper
  params = input_readers._get_params(mapper_spec)
  bucket_name = params.get("bucket_name")
  filename = (ctx.mapreduce_spec.name + "/" + ctx.mapreduce_id + "/output-" +
              ctx.shard_id + "-" + str(int(time.time())))
  full_filename = "/%s/%s" % (bucket_name, filename)
  filehandle = cloudstorage.open(full_filename, mode="w")
  with output_writers.GCSRecordsPool(filehandle, ctx=ctx) as pool:
    for key_record in key_records:
      pool.append(key_record[1])

  logging.debug("Finalizing")
  filehandle.close()

  entity = _OutputFile(key_name=full_filename,
                       parent=_OutputFile.get_root_key(ctx.mapreduce_id))
  entity.put()


class _SortChunksPipeline(pipeline_base.PipelineBase):
  """A pipeline to sort multiple key-value files.

  Args:
    job_name: root job name.
    bucket_name: The name of the Google Cloud Storage bucket.
    filenames: list of lists of filenames (hashed/bucketed) to sort,
      as produced by _HashingGCSOutputWriter.

  Returns:
    The list of lists of sorted filenames. Each inner list corresponds to one
    list of input files, and each filename contains a chunk of sorted data.
  """

  def run(self, job_name, bucket_name, filenames):
    sort_mappers = []
    for i in range(len(filenames)):
      filenames_only = util.strip_prefix_from_items("/%s/" % bucket_name,
                                                    filenames[i])
      sort_mapper = yield mapper_pipeline.MapperPipeline(
          "%s-shuffle-sort-%s" % (job_name, str(i)),
          __name__ + "._sort_records_map",
          __name__ + "._BatchGCSRecordsReader",
          None,
          {
              "input_reader": {
                  "bucket_name": bucket_name,
                  "objects": filenames_only,
              },
          },
          shards=1)
      sort_mappers.append(sort_mapper)
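    # Collect (and later clean up) the _OutputFile entities only after all
    # sort mapper jobs have completed.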
    with pipeline.After(*sort_mappers):
      job_ids = yield pipeline_common.Append(*[mapper.job_id for mapper in
                                               sort_mappers])
      result = yield _CollectOutputFiles(job_ids)
      with pipeline.After(result):
        yield _CleanupOutputFiles(job_ids)
      yield pipeline_common.Return(result)


class _CollectOutputFiles(pipeline_base.PipelineBase):
  """Collect output file names from _OutputFile entities for given jobs.

  Args:
    job_ids: list of job ids to load filenames.

  Returns:
    list of lists of filenames produced by specified job ids.
  """

  def run(self, job_ids):
    result = []
    for job_id in job_ids:
      entities = _OutputFile.all().ancestor(_OutputFile.get_root_key(job_id))
      result.append([entity.key().name() for entity in entities])
    return result


class _CleanupOutputFiles(pipeline_base.PipelineBase):
  """Cleanup _OutputFile entities for given job ids.

  Args:
    job_ids: list of job ids.
  """

  def run(self, job_ids):
    for job_id in job_ids:
      db.delete(_OutputFile.all().ancestor(_OutputFile.get_root_key(job_id)))


class _MergingReader(input_readers.InputReader):
  """Reader which merge-reads multiple sorted KeyValue files.

  Reads a list of lists of filenames. Each filename list constitutes one
  shard and is merged together.

  Yields (key, values) tuples. If neither the max_values_count nor the
  max_values_size parameter is specified, there will be a single tuple per
  key. Otherwise, multiple (key, values) tuples for the same key will be
  created, split according to those limits.
  """

  expand_parameters = True

  FILES_PARAM = "files"
  MAX_VALUES_COUNT_PARAM = "max_values_count"
  MAX_VALUES_SIZE_PARAM = "max_values_size"
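  # Illustrative shape of the reader parameters (paths are hypothetical):
  #   {
  #       "files": [["/bucket/job/chunk-0", "/bucket/job/chunk-1"],
  #                 ["/bucket/job/chunk-2"]],
  #       "max_values_count": 100000,
  #       "max_values_size": 1000000,
  #   }
  # Each inner list of files is merged by a single shard.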

  # Use a smaller buffer than the default.
  GCS_BUFFER_SIZE = 256 * 1024  # 256K.

  def __init__(self,
               offsets,
               max_values_count,
               max_values_size):
    """Constructor.

    Args:
      offsets: offsets for each input file to start from, as a list of ints.
      max_values_count: maximum number of values to yield for a single key at
        a time. Ignored if -1.
      max_values_size: maximum total size of yielded values. Ignored if -1.
    """
    self._offsets = offsets
    self._max_values_count = max_values_count
    self._max_values_size = max_values_size

  def __iter__(self):
    """Iterate over records in input files.

    self._offsets is always kept up to date, so stopping and resuming
    iteration neither skips records nor reads the same record twice.

    Raises:
      Exception: when Files list and offsets do not match.

    Yields:
      The result.
    """
    ctx = context.get()
    mapper_spec = ctx.mapreduce_spec.mapper
    shard_number = ctx._shard_state.shard_number
    filenames = mapper_spec.params[self.FILES_PARAM][shard_number]

    if len(filenames) != len(self._offsets):
      raise Exception("Files list and offsets do not match.")

    # Heap of (key, value, index, reader) tuples.
    readers = []

    # Initialize heap
    for (i, filename) in enumerate(filenames):
      offset = self._offsets[i]
      # TODO(user): Shrinking the buffer size is a workaround until
      # a tiered/segmented merge is implemented.
      reader = records.RecordsReader(
          cloudstorage.open(filename, read_buffer_size=self.GCS_BUFFER_SIZE))
      reader.seek(offset)
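      # The (None, None) sentinels sort before any real key, so each reader
      # is advanced once in the loop below before its records join the merge.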
      readers.append((None, None, i, reader))

    # Read records from heap and merge values with the same key.

    # current_result is yielded and consumed by _merge_map.
    # current_result = (key, value, is_partial)
    current_result = None
    current_count = 0
    current_size = 0
    while readers:
      (key, value, index, reader) = readers[0]

      if key is not None:
        current_count += 1
        current_size += len(value)

        should_yield = False
        if current_result:
          if key != current_result[0]:
            # New key encountered
            should_yield = True
          elif (self._max_values_count != -1 and
                current_count >= self._max_values_count):
            # Maximum number of values encountered.
            current_result[2] = True
            should_yield = True
          elif (self._max_values_size != -1 and
                current_size >= self._max_values_size):
            # Maximum size of values encountered
            current_result[2] = True
            should_yield = True

        if should_yield:
          # New key encountered or maximum count hit. Yield current key.
          yield current_result
        if not current_result or should_yield:
          current_result = [key, [], False]
          current_count = 0
          current_size = 0
        current_result[1].append(value)

      # Read next key/value from reader.
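      # Record the reader position first so that, if iteration stops here,
      # resuming from serialized state neither skips this record nor reads
      # it twice.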
      try:
        self._offsets[index] = reader.tell()
        start_time = time.time()
        binary_record = reader.read()
        # update counters
        if context.get():
          operation.counters.Increment(
              input_readers.COUNTER_IO_READ_BYTES,
              len(binary_record))(context.get())
          operation.counters.Increment(
              input_readers.COUNTER_IO_READ_MSEC,
              int((time.time() - start_time) * 1000))(context.get())
        proto = kv_pb.KeyValue()
        proto.ParseFromString(binary_record)
        # Put read data back into heap.
        heapq.heapreplace(readers,
                          (proto.key(), proto.value(), index, reader))
      except EOFError:
        heapq.heappop(readers)

    # Yield leftovers.
    if current_result:
      yield current_result

  @classmethod
  def from_json(cls, json):
    """Restore reader from json state."""
    return cls(json["offsets"],
               json["max_values_count"],
               json["max_values_size"])

  def to_json(self):
    """Serialize reader state to json."""
    return {"offsets": self._offsets,
            "max_values_count": self._max_values_count,
            "max_values_size": self._max_values_size}

  @classmethod
  def split_input(cls, mapper_spec):
    """Split input into multiple shards."""
    filelists = mapper_spec.params[cls.FILES_PARAM]
    max_values_count = mapper_spec.params.get(cls.MAX_VALUES_COUNT_PARAM, -1)
    max_values_size = mapper_spec.params.get(cls.MAX_VALUES_SIZE_PARAM, -1)
    return [cls([0] * len(files), max_values_count, max_values_size)
            for files in filelists]

  @classmethod
  def validate(cls, mapper_spec):
    """Validate reader parameters in mapper_spec."""
    if mapper_spec.input_reader_class() != cls:
      raise errors.BadReaderParamsError("Input reader class mismatch")
    params = mapper_spec.params
    if cls.FILES_PARAM not in params:
      raise errors.BadReaderParamsError("Missing files parameter.")


class _HashingGCSOutputWriter(output_writers.OutputWriter):
  """An OutputWriter which outputs data into GCS in key-value format.

  The output is tailored towards shuffler needs. Key/values are sharded by
  key hash modulo the number of output files: each map shard writes its keys
  into one of shard_count bucket files specific to that shard, and the same
  key is always hashed to the same bucket index across all shards.
  get_filenames() then groups together the files that share a bucket index
  and returns the list of those groups.
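
  For example (illustrative), with 2 shards there are 2 buckets per shard:
  shard 0 writes ".../shard-0-bucket-0" and ".../shard-0-bucket-1", shard 1
  writes ".../shard-1-bucket-0" and ".../shard-1-bucket-1", and
  get_filenames() returns
  [[".../shard-0-bucket-0", ".../shard-1-bucket-0"],
   [".../shard-0-bucket-1", ".../shard-1-bucket-1"]].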
  """

  # Supported parameters
  BUCKET_NAME_PARAM = "bucket_name"

  # pylint: disable=super-init-not-called
  def __init__(self, filehandles):
    """Constructor.

    Args:
      filehandles: list of file handles that this writer outputs to.
    """
    self._filehandles = filehandles
    self._pools = [None] * len(filehandles)

  @classmethod
  def validate(cls, mapper_spec):
    """Validates mapper specification.

    Args:
      mapper_spec: an instance of model.MapperSpec to validate.
    Raises:
      BadWriterParamsError: if the output writer class does not match or a
        required parameter is missing.
    """
    if mapper_spec.output_writer_class() != cls:
      raise errors.BadWriterParamsError("Output writer class mismatch")
    params = output_writers._get_params(mapper_spec)
    # Bucket Name is required
    if cls.BUCKET_NAME_PARAM not in params:
      raise errors.BadWriterParamsError(
          "%s is required for the _HashingGCSOutputWriter" %
          cls.BUCKET_NAME_PARAM)

  @classmethod
  def from_json(cls, json):
    """Creates an instance of the OutputWriter for the given json state.

    Args:
      json: The OutputWriter state as a dict-like object.

    Returns:
      An instance of the OutputWriter configured using the values of json.
    """
    return cls(pickle.loads(json["filehandles"]))

  def to_json(self):
    """Returns writer state to serialize in json.

    Returns:
      A json-izable version of the OutputWriter state.
    """
    # Use the member variable (since we don't have access to the context) to
    # flush each pool to minimize the size of each filehandle before we
    # serialize it.
    for pool in self._pools:
      if pool is not None:
        pool.flush(True)
    return {"filehandles": pickle.dumps(self._filehandles)}

  @classmethod
  def create(cls, mr_spec, shard_number, shard_attempt, _writer_state=None):
    """Inherit docs."""
    mapper_spec = mr_spec.mapper
    params = output_writers._get_params(mapper_spec)
    bucket_name = params.get(cls.BUCKET_NAME_PARAM)
    shards = mapper_spec.shard_count

    filehandles = []
    filename = (mr_spec.name + "/" + mr_spec.mapreduce_id +
                "/shard-" + str(shard_number) + "-bucket-")
    for i in range(shards):
      full_filename = "/%s/%s%d" % (bucket_name, filename, i)
      filehandles.append(cloudstorage.open(full_filename, mode="w"))
    return cls(filehandles)

  @classmethod
  def get_filenames(cls, mapreduce_state):
    """See parent class."""
    shards = mapreduce_state.mapreduce_spec.mapper.shard_count
    filenames = []
    for _ in range(shards):
      filenames.append([None] * shards)
    shard_states = model.ShardState.find_all_by_mapreduce_state(mapreduce_state)
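    # Each shard recorded its files as [bucket 0, ..., bucket n-1]; transpose
    # them into filenames[bucket][shard] so that every inner list holds one
    # hash bucket across all shards.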
    for x, shard_state in enumerate(shard_states):
      shard_filenames = shard_state.writer_state["shard_filenames"]
      for y in range(shards):
        filenames[y][x] = shard_filenames[y]
    return filenames

  def finalize(self, ctx, shard_state):
    """See parent class."""
    filenames = []
    for filehandle in self._filehandles:
      filenames.append(filehandle.name)
      filehandle.close()
    shard_state.writer_state = {"shard_filenames": filenames}

  def write(self, data):
    """Write data.

    Args:
      data: actual data yielded from handler. Type is writer-specific.
    """
    ctx = context.get()
    if len(data) != 2:
      logging.error("Got bad tuple of length %d (2-tuple expected): %s",
                    len(data), data)

    try:
      key = str(data[0])
      value = str(data[1])
    except TypeError:
      logging.error("Expecting a tuple, but got %s: %s",
                    data.__class__.__name__, data)
      return

    # Pick the output file by key hash so that equal keys from every shard
    # land in the same logical bucket (see get_filenames).
    file_index = hash(key) % len(self._filehandles)

    # Work-around: Since we don't have access to the context in the to_json()
    # function, but we need to flush each pool before we serialize the
    # filehandle, we rely on a member variable instead of using context for
    # pool management.
    pool = self._pools[file_index]
    if pool is None:
      filehandle = self._filehandles[file_index]
      pool = output_writers.GCSRecordsPool(filehandle=filehandle, ctx=ctx)
      self._pools[file_index] = pool

    proto = kv_pb.KeyValue()
    proto.set_key(key)
    proto.set_value(value)
    pool.append(proto.Encode())


class _ShardOutputs(pipeline_base.PipelineBase):
  """Shards the outputs.

  Takes a flat list of filenames and returns a list of lists, each
  containing a single filename.
  """

  def run(self, filenames):
    result = []
    for name in filenames:
      result.append([name])
    return result


# pylint: disable=unused-argument
def _merge_map(key, values, partial):
  """A map function used in merge phase.

  Stores (key, values) into KeyValues proto and yields its serialization.

  Args:
    key: values key.
    values: values themselves.
    partial: True if more values for this key will follow. False otherwise.

  Yields:
    The proto.
  """
  proto = kv_pb.KeyValues()
  proto.set_key(key)
  proto.value_list().extend(values)
  yield proto.Encode()


class _MergePipeline(pipeline_base.PipelineBase):
  """Pipeline to merge sorted chunks.

  This pipeline merges together individually sorted chunks of each shard.

  Args:
    filenames: list of lists of filenames. Each list corresponds to a single
      shard. Each file in a list should have its keys sorted and should
      contain records with serialized KeyValue protos.

  Yields:
    The list of filenames, where each file is fully merged and contains
    records with serialized KeyValues protos.
  """

  # Maximum number of values to produce in a single KeyValues proto.
  _MAX_VALUES_COUNT = 100000  # Combiners usually good for 5 orders of magnitude
  # Maximum size of values to produce in a single KeyValues proto.
  _MAX_VALUES_SIZE = 1000000

  def run(self, job_name, bucket_name, filenames):
    yield mapper_pipeline.MapperPipeline(
        job_name + "-shuffle-merge",
        __name__ + "._merge_map",
        __name__ + "._MergingReader",
        output_writer_spec=
        output_writers.__name__ + "._GoogleCloudStorageRecordOutputWriter",
        params={
            _MergingReader.FILES_PARAM: filenames,
            _MergingReader.MAX_VALUES_COUNT_PARAM: self._MAX_VALUES_COUNT,
            _MergingReader.MAX_VALUES_SIZE_PARAM: self._MAX_VALUES_SIZE,
            "output_writer": {
                "bucket_name": bucket_name,
            },
        },
        shards=len(filenames))


def _hashing_map(binary_record):
  """A map function used in hash phase.

  Reads KeyValue from binary record.

  Args:
    binary_record: The binary record.

  Yields:
    The (key, value).
  """
  proto = kv_pb.KeyValue()
  proto.ParseFromString(binary_record)
  yield (proto.key(), proto.value())


class _HashPipeline(pipeline_base.PipelineBase):
  """A pipeline to read mapper output and hash by key.

  Args:
    job_name: root mapreduce job name.
    bucket_name: The name of the Google Cloud Storage bucket.
    filenames: filenames of mapper output. Should be of records format
      with serialized KeyValue proto.
    shards: Optional. Number of output shards to generate. Defaults
      to the number of input files.

  Yields:
    The list of filenames. Each file is in records format with serialized
    KeyValue protos. Each proto's output file is chosen by key hash, so all
    equal keys end up in the same file.
  """

  def run(self, job_name, bucket_name, filenames, shards=None):
    filenames_only = (
        util.strip_prefix_from_items("/%s/" % bucket_name, filenames))
    if shards is None:
      shards = len(filenames)
    yield mapper_pipeline.MapperPipeline(
        job_name + "-shuffle-hash",
        __name__ + "._hashing_map",
        input_readers.__name__ + "._GoogleCloudStorageRecordInputReader",
        output_writer_spec=__name__ + "._HashingGCSOutputWriter",
        params={
            "input_reader": {
                "bucket_name": bucket_name,
                "objects": filenames_only,
            },
            "output_writer": {
                "bucket_name": bucket_name,
            },
        },
        shards=shards)


class ShufflePipeline(pipeline_base.PipelineBase):
  """A pipeline to shuffle multiple key-value files.

  Args:
    job_name: The descriptive name of the overall job.
    mapper_params: parameters to use for mapper phase.
    filenames: list of file names to sort. Files have to be of records format
      defined by Files API and contain serialized kv_pb.KeyValue
      protocol messages. The filenames may or may not contain the
      GCS bucket name in their path.
    shards: Optional. Number of output shards to generate. Defaults
      to the number of input files.

  Returns:
    default: a list of filenames as strings. The resulting files contain
      serialized kv_pb.KeyValues protocol messages with all values for a
      key collated together. When there is no output, the result is either
      an empty list (from the shuffle service) or a list of empty files
      (from the in-memory shuffler).
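
  A minimal usage sketch (names and values below are illustrative only):

    shuffle = ShufflePipeline(
        "my-job",
        {"bucket_name": "my-bucket"},
        ["/my-bucket/my-job/mapper-output-0",
         "/my-bucket/my-job/mapper-output-1"])
    shuffle.start()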
  """

  def run(self, job_name, mapper_params, filenames, shards=None):
    bucket_name = mapper_params["bucket_name"]
    hashed_files = yield _HashPipeline(job_name, bucket_name,
                                       filenames, shards=shards)
    sorted_files = yield _SortChunksPipeline(job_name, bucket_name,
                                             hashed_files)
    temp_files = [hashed_files, sorted_files]

    merged_files = yield _MergePipeline(job_name, bucket_name, sorted_files)

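    # Delete the intermediate hashed and sorted files only after the merge
    # phase has finished reading them.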
    with pipeline.After(merged_files):
      all_temp_files = yield pipeline_common.Extend(*temp_files)
      yield _GCSCleanupPipeline(all_temp_files)

    yield pipeline_common.Return(merged_files)


class _GCSCleanupPipeline(pipeline_base.PipelineBase):
  """A pipeline to do a cleanup for mapreduce jobs that use GCS.

  Args:
    filename_or_list: list of files or file lists to delete.
  """

  # The minimum number of retries for GCS to delete the file.
  _MIN_RETRIES = 5
  # The maximum number of retries for GCS to delete the file.
  _MAX_RETRIES = 10

  def delete_file_or_list(self, filename_or_list):
    if isinstance(filename_or_list, list):
      for filename in filename_or_list:
        self.delete_file_or_list(filename)
    else:
      filename = filename_or_list
      retry_params = cloudstorage.RetryParams(min_retries=self._MIN_RETRIES,
                                              max_retries=self._MAX_RETRIES)
      # pylint: disable=bare-except
      try:
        cloudstorage.delete(filename, retry_params)
      except:
        pass

  def run(self, temp_files):
    self.delete_file_or_list(temp_files)