#!/usr/bin/env python
# Copyright 2010 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

17"""Defines executor tasks handlers for MapReduce implementation."""
18
19
20
21# pylint: disable=protected-access
22# pylint: disable=g-bad-name
23
24import datetime
25import logging
26import math
27import os
28import random
29import sys
30import time
31import traceback
32import zlib
33
34try:
35  import json
36except ImportError:
37  import simplejson as json

from google.appengine.ext import ndb

from google.appengine import runtime
from google.appengine.api import datastore_errors
from google.appengine.api import logservice
from google.appengine.api import taskqueue
from google.appengine.ext import db
from mapreduce import base_handler
from mapreduce import context
from mapreduce import errors
from mapreduce import input_readers
from mapreduce import map_job_context
from mapreduce import model
from mapreduce import operation
from mapreduce import output_writers
from mapreduce import parameters
from mapreduce import shard_life_cycle
from mapreduce import util
from mapreduce.api import map_job
from google.appengine.runtime import apiproxy_errors

# pylint: disable=g-import-not-at-top
try:
  import cloudstorage
  # In the python25 runtime, the above import will be scrubbed to import the
  # stub version of cloudstorage. All occurrences of the following if
  # condition in the MR codebase are there to tell the two apart.
  # TODO(user): Remove after python25 runtime MR is abandoned.
  if hasattr(cloudstorage, "_STUB"):
    cloudstorage = None
except ImportError:
  cloudstorage = None  # CloudStorage library not available


# A guide to logging.
# log.critical: messages the user absolutely should see, e.g. a failed job.
# log.error: exceptions while processing user data, or unexpected
# errors detected by the mr framework.
# log.warning: errors the mr framework knows how to handle.
# log.info: other expected events.


# Set of strings of various test-injected faults.
_TEST_INJECTED_FAULTS = set()
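
# Illustrative use only: tests can force rare code paths by injecting a fault
# name that the handlers check for, e.g. the active-state collision check in
# MapperWorkerCallbackHandler._save_state_and_schedule_next below:
#
#   _TEST_INJECTED_FAULTS.add("worker_active_state_collision")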


def _run_task_hook(hooks, method, task, queue_name):
  """Invokes hooks.method(task, queue_name).

  Args:
    hooks: A hooks.Hooks instance or None.
    method: The name of the method to invoke on the hooks class e.g.
        "enqueue_kickoff_task".
    task: The taskqueue.Task to pass to the hook method.
    queue_name: The name of the queue to pass to the hook method.

  Returns:
    True if the hooks.Hooks instance handled the method, False otherwise.
  """
  if hooks is not None:
    try:
      getattr(hooks, method)(task, queue_name)
    except NotImplementedError:
      # Use the default task addition implementation.
      return False

    return True
  return False

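# A minimal sketch of a hooks implementation that _run_task_hook can invoke.
# This assumes the typical shape of hooks.Hooks: methods that are not
# overridden raise NotImplementedError, so the default task addition runs.
#
#   class InstrumentedHooks(hooks.Hooks):  # hypothetical subclass
#     def enqueue_worker_task(self, task, queue_name):
#       logging.info("enqueueing %s on %s", task.name, queue_name)
#       task.add(queue_name)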

class MapperWorkerCallbackHandler(base_handler.HugeTaskHandler):
  """Callback handler for mapreduce worker task."""

  # These directives instruct self.__return() how to set state and enqueue a
  # task.
  _TASK_DIRECTIVE = util._enum(
      # Task is running as expected.
      PROCEED_TASK="proceed_task",
      # Need to retry the task. The lock was NOT acquired when the error
      # occurred. Don't change payload or datastore.
      RETRY_TASK="retry_task",
      # Need to retry the task. The lock was acquired when the error
      # occurred. Don't change payload or datastore.
      RETRY_SLICE="retry_slice",
      # Drop the task (due to a duplicated task). Must log the permanent drop.
      DROP_TASK="drop_task",
      # See handlers.MapperWorkerCallbackHandler._attempt_slice_recovery.
      RECOVER_SLICE="recover_slice",
      # Need to retry the shard.
      RETRY_SHARD="retry_shard",
      # Need to drop the task and fail the shard. Log permanent failure.
      FAIL_TASK="fail_task",
      # Need to abort the shard.
      ABORT_SHARD="abort_shard")
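
  # util._enum above is assumed to build a simple value namespace, so a
  # directive compares equal to its string value, e.g.:
  #
  #   self._TASK_DIRECTIVE.RETRY_TASK == "retry_task"  # True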

  def __init__(self, *args):
    """Constructor."""
    super(MapperWorkerCallbackHandler, self).__init__(*args)
    self._time = time.time
    self.slice_context = None
    self.shard_context = None

  def _drop_gracefully(self):
    """Drop the worker task gracefully.

    Sets the current shard_state to failed. Controller logic will take care
    of other shards and the entire MR.
    """
    shard_id = self.request.headers[util._MR_SHARD_ID_TASK_HEADER]
    mr_id = self.request.headers[util._MR_ID_TASK_HEADER]
    shard_state, mr_state = db.get([
        model.ShardState.get_key_by_shard_id(shard_id),
        model.MapreduceState.get_key_by_job_id(mr_id)])

    if shard_state and shard_state.active:
      shard_state.set_for_failure()
      config = util.create_datastore_write_config(mr_state.mapreduce_spec)
      shard_state.put(config=config)

  def _try_acquire_lease(self, shard_state, tstate):
    """Validate that datastore and the task payload are consistent.

    If so, attempt to get a lease on this slice's execution.
    See model.ShardState doc on slice_start_time.

    Args:
      shard_state: model.ShardState from datastore.
      tstate: model.TransientShardState from the taskqueue payload.

    Returns:
      A _TASK_DIRECTIVE enum. PROCEED_TASK if the lock is acquired.
    RETRY_TASK if the task should be retried, DROP_TASK if the task should
    be dropped. Only old tasks (compared to datastore state)
    will be dropped. Future tasks are retried until they naturally
    become old so that the MR never gets stuck.
    """
    # Controller will tally shard_states and properly handle the situation.
    if not shard_state:
      logging.warning("State not found for shard %s; Possible spurious task "
                      "execution. Dropping this task.",
                      tstate.shard_id)
      return self._TASK_DIRECTIVE.DROP_TASK

    if not shard_state.active:
      logging.warning("Shard %s is not active. Possible spurious task "
                      "execution. Dropping this task.", tstate.shard_id)
      logging.warning(str(shard_state))
      return self._TASK_DIRECTIVE.DROP_TASK

    # Validate shard retry count.
    if shard_state.retries > tstate.retries:
      logging.warning(
          "Got shard %s from previous shard retry %s. Possible spurious "
          "task execution. Dropping this task.",
          tstate.shard_id,
          tstate.retries)
      logging.warning(str(shard_state))
      return self._TASK_DIRECTIVE.DROP_TASK
    elif shard_state.retries < tstate.retries:
      # By the end of the last slice, task enqueue succeeded but the
      # datastore commit failed. That transaction will be retried and
      # adding the same task will pass.
      logging.warning(
          "ShardState for %s is behind slice. Waiting for it to catch up",
          shard_state.shard_id)
      return self._TASK_DIRECTIVE.RETRY_TASK

    # Validate slice id.
    # Taskqueue executes old successful tasks.
    if shard_state.slice_id > tstate.slice_id:
      logging.warning(
          "Task %s-%s is behind ShardState %s. Dropping task.",
          tstate.shard_id, tstate.slice_id, shard_state.slice_id)
      return self._TASK_DIRECTIVE.DROP_TASK
    # By the end of the last slice, task enqueue succeeded but the datastore
    # commit failed. That transaction will be retried and adding the same
    # task will pass. User data is duplicated in this case.
    elif shard_state.slice_id < tstate.slice_id:
      logging.warning(
          "Task %s-%s is ahead of ShardState %s. Waiting for it to catch up.",
          tstate.shard_id, tstate.slice_id, shard_state.slice_id)
      return self._TASK_DIRECTIVE.RETRY_TASK

    # Check for potential duplicated tasks for the same slice.
    # See model.ShardState doc.
    if shard_state.slice_start_time:
      countdown = self._wait_time(shard_state,
                                  parameters._LEASE_DURATION_SEC)
      if countdown > 0:
        logging.warning(
            "Last retry of slice %s-%s may still be running. "
            "Will try again in %s seconds.", tstate.shard_id, tstate.slice_id,
            countdown)
        # TODO(user): There might be a better way. Taskqueue's countdown
        # only applies to adding new tasks, not to retries of tasks.
        # Reduce contention.
        time.sleep(countdown)
        return self._TASK_DIRECTIVE.RETRY_TASK
      # The lease could have expired. Verify with the logs API.
      else:
        if self._wait_time(shard_state,
                           parameters._MAX_LEASE_DURATION_SEC):
          if not self._has_old_request_ended(shard_state):
            logging.warning(
                "Last retry of slice %s-%s is still in flight with request_id "
                "%s. Will try again later.", tstate.shard_id, tstate.slice_id,
                shard_state.slice_request_id)
            return self._TASK_DIRECTIVE.RETRY_TASK
        else:
          logging.warning(
              "Last retry of slice %s-%s has no log entry and has "
              "timed out after %s seconds.",
              tstate.shard_id, tstate.slice_id,
              parameters._MAX_LEASE_DURATION_SEC)

    # Lease expired or slice_start_time not set.
    config = util.create_datastore_write_config(tstate.mapreduce_spec)
    @db.transactional(retries=5)
    def _tx():
      """Use datastore to set slice_start_time to now.

      If this fails for any reason, raise an error to retry the task (hence
      all the previous validation code). The task would eventually die
      naturally.

      Raises:
        Rollback: If the shard state is missing.

      Returns:
        A _TASK_DIRECTIVE enum.
      """
      fresh_state = model.ShardState.get_by_shard_id(tstate.shard_id)
      if not fresh_state:
        logging.warning("ShardState missing.")
        raise db.Rollback()
      if (fresh_state.active and
          fresh_state.slice_id == shard_state.slice_id and
          fresh_state.slice_start_time == shard_state.slice_start_time):
        shard_state.slice_start_time = datetime.datetime.now()
        shard_state.slice_request_id = os.environ.get("REQUEST_LOG_ID")
        shard_state.acquired_once = True
        shard_state.put(config=config)
        return self._TASK_DIRECTIVE.PROCEED_TASK
      else:
        logging.warning(
            "Contention on slice %s-%s execution. Will retry again.",
            tstate.shard_id, tstate.slice_id)
        # One proposer should win. In case all lose, back off arbitrarily.
        time.sleep(random.randrange(1, 5))
        return self._TASK_DIRECTIVE.RETRY_TASK

    return _tx()

  def _has_old_request_ended(self, shard_state):
    """Whether the previous slice retry has ended, according to the Logs API.

    Args:
      shard_state: shard state.

    Returns:
      True if the request of the previous slice retry has ended. False if it
    has not ended or if that is unknown.
    """
    assert shard_state.slice_start_time is not None
    assert shard_state.slice_request_id is not None
    request_ids = [shard_state.slice_request_id]
    logs = None
    try:
      logs = list(logservice.fetch(request_ids=request_ids))
    except (apiproxy_errors.FeatureNotEnabledError,
        apiproxy_errors.CapabilityDisabledError) as e:
      # Managed VMs do not have access to the logservice API.
      # See https://groups.google.com/forum/#!topic/app-engine-managed-vms/r8i65uiFW0w
      logging.warning("Ignoring exception: %s", e)

    if not logs or not logs[0].finished:
      return False
    return True

  def _wait_time(self, shard_state, secs, now=datetime.datetime.now):
    """Time to wait until slice_start_time is secs seconds before now.

    Args:
      shard_state: shard state.
      secs: duration in seconds.
      now: a function that returns the current datetime.

    Returns:
      0 if no wait is needed. Otherwise a positive int in seconds, always
    rounded up.
    """
    assert shard_state.slice_start_time is not None
    delta = now() - shard_state.slice_start_time
    duration = datetime.timedelta(seconds=secs)
    if delta < duration:
      return util.total_seconds(duration - delta)
    else:
      return 0
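
  # Worked example for _wait_time (values assumed): if the slice lease was
  # taken 20 seconds ago and parameters._LEASE_DURATION_SEC is 30, then
  # delta = 20s < duration = 30s and the method returns
  # util.total_seconds(30s - 20s) = 10 (rounded up), so _try_acquire_lease
  # sleeps ~10 seconds before letting taskqueue retry the task.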

  def _try_free_lease(self, shard_state, slice_retry=False):
    """Try to free the lease.

    A lightweight transaction to update shard_state and unset
    slice_start_time so that the next retry can happen without blocking.
    We don't care whether this fails or not because the lease will expire
    anyway.

    Under normal execution, _save_state_and_schedule_next is the exit point.
    It updates/saves shard state and schedules the next slice or returns.
    Other exit points are:
    1. _are_states_consistent: at the beginning of handle, checks
      if datastore states and the task are in sync.
      If not, raise or return.
    2. _attempt_slice_retry: may raise exception to taskqueue.
    3. _save_state_and_schedule_next: may raise exception when taskqueue/db
       unreachable.

    This handler should try to free the lease on every exceptional exit point.

    Args:
      shard_state: model.ShardState.
      slice_retry: whether to count this as a failed slice execution.
    """
    @db.transactional
    def _tx():
      fresh_state = model.ShardState.get_by_shard_id(shard_state.shard_id)
      if fresh_state and fresh_state.active:
        # Free the lease.
        fresh_state.slice_start_time = None
        fresh_state.slice_request_id = None
        if slice_retry:
          fresh_state.slice_retries += 1
        fresh_state.put()
    try:
      _tx()
    # pylint: disable=broad-except
    except Exception as e:
      logging.warning(e)
      logging.warning(
          "Releasing the lock for shard %s failed. Waiting for the lease "
          "to expire.", shard_state.shard_id)

  def _maintain_LC(self, obj, slice_id, last_slice=False, begin_slice=True,
                   shard_ctx=None, slice_ctx=None):
    """Makes sure the shard life-cycle interface is respected.

    Args:
      obj: the object that may have implemented _ShardLifeCycle.
      slice_id: current slice_id.
      last_slice: whether this is the last slice.
      begin_slice: whether this is the beginning or the end of a slice.
      shard_ctx: shard ctx for dependency injection. If None, it will be read
        from self.
      slice_ctx: slice ctx for dependency injection. If None, it will be read
        from self.
    """
    if obj is None or not isinstance(obj, shard_life_cycle._ShardLifeCycle):
      return

    shard_context = shard_ctx or self.shard_context
    slice_context = slice_ctx or self.slice_context
    if begin_slice:
      if slice_id == 0:
        obj.begin_shard(shard_context)
      obj.begin_slice(slice_context)
    else:
      obj.end_slice(slice_context)
      if last_slice:
        obj.end_shard(shard_context)

  def _lc_start_slice(self, tstate, slice_id):
    self._maintain_LC(tstate.output_writer, slice_id)
    self._maintain_LC(tstate.input_reader, slice_id)
    self._maintain_LC(tstate.handler, slice_id)

  def _lc_end_slice(self, tstate, slice_id, last_slice=False):
    self._maintain_LC(tstate.handler, slice_id, last_slice=last_slice,
                      begin_slice=False)
    self._maintain_LC(tstate.input_reader, slice_id, last_slice=last_slice,
                      begin_slice=False)
    self._maintain_LC(tstate.output_writer, slice_id, last_slice=last_slice,
                      begin_slice=False)
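
  # Note the symmetric ordering: _lc_start_slice opens the output writer,
  # input reader, then handler, while _lc_end_slice closes them in reverse,
  # so the handler begins last and ends first. An object implementing
  # shard_life_cycle._ShardLifeCycle therefore sees, per shard:
  #
  #   begin_shard(shard_ctx)   # once, on slice 0
  #   begin_slice(slice_ctx)   # every slice
  #   end_slice(slice_ctx)     # every slice
  #   end_shard(shard_ctx)     # once, on the last slice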

  def handle(self):
    """Handle request.

    This method has to be careful to pass the same ShardState instance to
    its subroutine calls if the calls mutate or read from ShardState.
    Note especially that the Context instance caches and updates the
    ShardState instance.

    Returns:
      Sets the HTTP status code and always returns None.
    """
    # Reconstruct basic states.
    self._start_time = self._time()
    shard_id = self.request.headers[util._MR_SHARD_ID_TASK_HEADER]
    mr_id = self.request.headers[util._MR_ID_TASK_HEADER]
    spec = model.MapreduceSpec._get_mapreduce_spec(mr_id)
    shard_state, control = db.get([
        model.ShardState.get_key_by_shard_id(shard_id),
        model.MapreduceControl.get_key_by_job_id(mr_id),
    ])

    # Set context before any IO code is called.
    ctx = context.Context(spec, shard_state,
                          task_retry_count=self.task_retry_count())
    context.Context._set(ctx)

    # Unmarshal the input reader, output writer, and other transient states.
    tstate = model.TransientShardState.from_request(self.request)

    # Try to acquire a lease on the shard.
    if shard_state:
      is_this_a_retry = shard_state.acquired_once
    task_directive = self._try_acquire_lease(shard_state, tstate)
    if task_directive in (self._TASK_DIRECTIVE.RETRY_TASK,
                          self._TASK_DIRECTIVE.DROP_TASK):
      return self.__return(shard_state, tstate, task_directive)
    assert task_directive == self._TASK_DIRECTIVE.PROCEED_TASK

    # Abort the shard if an abort signal was received.
    if control and control.command == model.MapreduceControl.ABORT:
      task_directive = self._TASK_DIRECTIVE.ABORT_SHARD
      return self.__return(shard_state, tstate, task_directive)

    # Retry the shard if the user disabled slice retry.
    if (is_this_a_retry and
        parameters.config.TASK_MAX_DATA_PROCESSING_ATTEMPTS <= 1):
      task_directive = self._TASK_DIRECTIVE.RETRY_SHARD
      return self.__return(shard_state, tstate, task_directive)

    # TODO(user): Find a better way to set these per-thread configs.
    # E.g., what if the user changes it?
    util._set_ndb_cache_policy()

    job_config = map_job.JobConfig._to_map_job_config(
        spec,
        os.environ.get("HTTP_X_APPENGINE_QUEUENAME"))
    job_context = map_job_context.JobContext(job_config)
    self.shard_context = map_job_context.ShardContext(job_context, shard_state)
    self.slice_context = map_job_context.SliceContext(self.shard_context,
                                                      shard_state,
                                                      tstate)
    try:
      slice_id = tstate.slice_id
      self._lc_start_slice(tstate, slice_id)

      if shard_state.is_input_finished():
        self._lc_end_slice(tstate, slice_id, last_slice=True)
        # Finalize the stream and set status if there's no more input.
        if (tstate.output_writer and
            isinstance(tstate.output_writer, output_writers.OutputWriter)):
          # It's possible that finalization is successful but
          # saving state failed. In this case this shard will retry upon
          # finalization error.
          # TODO(user): make the finalize method idempotent!
          tstate.output_writer.finalize(ctx, shard_state)
        shard_state.set_for_success()
        return self.__return(shard_state, tstate, task_directive)

      if is_this_a_retry:
        task_directive = self._attempt_slice_recovery(shard_state, tstate)
        if task_directive != self._TASK_DIRECTIVE.PROCEED_TASK:
          return self.__return(shard_state, tstate, task_directive)

      last_slice = self._process_inputs(
          tstate.input_reader, shard_state, tstate, ctx)

      self._lc_end_slice(tstate, slice_id)

      ctx.flush()

      if last_slice:
        # We're done processing data but we still need to finalize the output
        # stream. We save this condition in datastore and force a new slice.
        # That way if finalize fails no input data will be retried.
        shard_state.set_input_finished()
    # pylint: disable=broad-except
    except Exception as e:
      logging.warning("Shard %s got error.", shard_state.shard_id)
      logging.error(traceback.format_exc())

      # Fail fast.
      if type(e) is errors.FailJobError:
        logging.error("Got FailJobError.")
        task_directive = self._TASK_DIRECTIVE.FAIL_TASK
      else:
        task_directive = self._TASK_DIRECTIVE.RETRY_SLICE

    self.__return(shard_state, tstate, task_directive)

  def __return(self, shard_state, tstate, task_directive):
    """Handler should always call this as the last statement."""
    task_directive = self._set_state(shard_state, tstate, task_directive)
    self._save_state_and_schedule_next(shard_state, tstate, task_directive)
    context.Context._set(None)

  def _process_inputs(self,
                      input_reader,
                      shard_state,
                      tstate,
                      ctx):
    """Read inputs, process them, and write out outputs.

    This is the core logic of MapReduce. It reads inputs from the input
    reader, invokes the user-specified mapper function, and writes output
    with the output writer. It also updates shard_state accordingly,
    e.g. if shard processing is done, sets shard_state.active to False.

    If errors.FailJobError is caught, it will fail this MR job.
    All other exceptions will be logged and raised to taskqueue for retry
    until the number of retries exceeds a limit.

    Args:
      input_reader: input reader.
      shard_state: shard state.
      tstate: transient shard state.
      ctx: mapreduce context.

    Returns:
      Whether this shard has finished processing all its input split.
    """
    processing_limit = self._processing_limit(tstate.mapreduce_spec)
    if processing_limit == 0:
      return False

    finished_shard = True
    # The input reader may not be an iterator. It is only a container.
    iterator = iter(input_reader)

    while True:
      try:
        entity = iterator.next()
      except StopIteration:
        break
      # Reading input may raise an exception. If we assume
      # 1. The input reader has done enough retries.
      # 2. The input reader can still serialize correctly after this exception.
      # 3. The input reader, upon resume, will try to re-read this failed
      #    record.
      # 4. This exception doesn't imply the input reader is permanently stuck.
      # then we can serialize the current slice immediately to avoid
      # duplicated outputs.
      # TODO(user): Validate these assumptions on all readers. MR should
      # also have a way to detect fake forward progress.

      if isinstance(entity, db.Model):
        shard_state.last_work_item = repr(entity.key())
      elif isinstance(entity, ndb.Model):
        shard_state.last_work_item = repr(entity.key)
      else:
        shard_state.last_work_item = repr(entity)[:100]

      processing_limit -= 1

      if not self._process_datum(
          entity, input_reader, ctx, tstate):
        finished_shard = False
        break
      elif processing_limit == 0:
        finished_shard = False
        break

    # Record the wall time spent mapping in this slice.
    self.slice_context.incr(
        context.COUNTER_MAPPER_WALLTIME_MS,
        int((self._time() - self._start_time)*1000))

    return finished_shard

  def _process_datum(self, data, input_reader, ctx, transient_shard_state):
    """Process a single data piece.

    Calls the mapper handler on the data.

    Args:
      data: a datum to process.
      input_reader: input reader.
      ctx: mapreduce context.
      transient_shard_state: transient shard state.

    Returns:
      True if scan should be continued, False if scan should be stopped.
    """
    if data is not input_readers.ALLOW_CHECKPOINT:
      self.slice_context.incr(context.COUNTER_MAPPER_CALLS)

      handler = transient_shard_state.handler

      if isinstance(handler, map_job.Mapper):
        handler(self.slice_context, data)
      else:
        if input_reader.expand_parameters:
          result = handler(*data)
        else:
          result = handler(data)

        if util.is_generator(result):
          for output in result:
            if isinstance(output, operation.Operation):
              output(ctx)
            else:
              output_writer = transient_shard_state.output_writer
              if not output_writer:
                logging.warning(
                    "Handler yielded %s, but no output writer is set.", output)
              else:
                output_writer.write(output)

    if self._time() - self._start_time >= parameters.config._SLICE_DURATION_SEC:
      return False
    return True
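
  # Illustrative mapper shapes accepted by _process_datum (my_map is
  # hypothetical). A plain generator handler can yield either operations,
  # which are applied to the context, or raw values, which go to the output
  # writer:
  #
  #   def my_map(entity):
  #     yield operation.counters.Increment("rows-seen")  # an Operation
  #     yield (entity.key, 1)                            # written as output
  #
  # A map_job.Mapper subclass is instead invoked as handler(slice_ctx, data)
  # and emits through the slice context itself.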

  def _set_state(self, shard_state, tstate, task_directive):
    """Set shard_state and tstate based on task_directive.

    Args:
      shard_state: model.ShardState for the current shard.
      tstate: model.TransientShardState for the current shard.
      task_directive: self._TASK_DIRECTIVE for the current shard.

    Returns:
      A _TASK_DIRECTIVE enum.
      PROCEED_TASK if task should proceed normally.
      RETRY_SHARD if shard should be retried.
      RETRY_SLICE if slice should be retried.
      FAIL_TASK if shard should fail.
      RECOVER_SLICE if slice should be recovered.
      ABORT_SHARD if shard should be aborted.
      RETRY_TASK if task should be retried.
      DROP_TASK if task should be dropped.
    """
    if task_directive in (self._TASK_DIRECTIVE.RETRY_TASK,
                          self._TASK_DIRECTIVE.DROP_TASK):
      return task_directive

    if task_directive == self._TASK_DIRECTIVE.ABORT_SHARD:
      shard_state.set_for_abort()
      return task_directive

    if task_directive == self._TASK_DIRECTIVE.PROCEED_TASK:
      shard_state.advance_for_next_slice()
      tstate.advance_for_next_slice()
      return task_directive

    if task_directive == self._TASK_DIRECTIVE.RECOVER_SLICE:
      tstate.advance_for_next_slice(recovery_slice=True)
      shard_state.advance_for_next_slice(recovery_slice=True)
      return task_directive

    if task_directive == self._TASK_DIRECTIVE.RETRY_SLICE:
      task_directive = self._attempt_slice_retry(shard_state, tstate)
    if task_directive == self._TASK_DIRECTIVE.RETRY_SHARD:
      task_directive = self._attempt_shard_retry(shard_state, tstate)
    if task_directive == self._TASK_DIRECTIVE.FAIL_TASK:
      shard_state.set_for_failure()

    return task_directive

  def _save_state_and_schedule_next(self, shard_state, tstate, task_directive):
    """Save state and schedule task.

    Saves shard state to datastore.
    Schedules the next slice if needed.
    Sets the HTTP response code.
    No modification to any shard_state or tstate.

    Args:
      shard_state: model.ShardState for the current shard.
      tstate: model.TransientShardState for the current shard.
      task_directive: enum _TASK_DIRECTIVE.

    Returns:
      The task to retry if applicable.
    """
    spec = tstate.mapreduce_spec

    if task_directive == self._TASK_DIRECTIVE.DROP_TASK:
      return
    if task_directive in (self._TASK_DIRECTIVE.RETRY_SLICE,
                          self._TASK_DIRECTIVE.RETRY_TASK):
      # Set HTTP code to 500.
      return self.retry_task()
    elif task_directive == self._TASK_DIRECTIVE.ABORT_SHARD:
      logging.info("Aborting shard %d of job '%s'",
                   shard_state.shard_number, shard_state.mapreduce_id)
      task = None
    elif task_directive == self._TASK_DIRECTIVE.FAIL_TASK:
      logging.critical("Shard %s failed permanently.", shard_state.shard_id)
      task = None
    elif task_directive == self._TASK_DIRECTIVE.RETRY_SHARD:
      logging.warning("Shard %s is going to be attempted for the %s time.",
                      shard_state.shard_id,
                      shard_state.retries + 1)
      task = self._state_to_task(tstate, shard_state)
    elif task_directive == self._TASK_DIRECTIVE.RECOVER_SLICE:
      logging.warning("Shard %s slice %s is being recovered.",
                      shard_state.shard_id,
                      shard_state.slice_id)
      task = self._state_to_task(tstate, shard_state)
    else:
      assert task_directive == self._TASK_DIRECTIVE.PROCEED_TASK
      countdown = self._get_countdown_for_next_slice(spec)
      task = self._state_to_task(tstate, shard_state, countdown=countdown)

    # Prepare parameters for the db transaction and taskqueue.
    queue_name = os.environ.get("HTTP_X_APPENGINE_QUEUENAME",
                                # For test only.
                                # TODO(user): Remove this.
                                "default")
    config = util.create_datastore_write_config(spec)

    @db.transactional(retries=5)
    def _tx():
      """The transaction helper."""
      fresh_shard_state = model.ShardState.get_by_shard_id(tstate.shard_id)
      if not fresh_shard_state:
        raise db.Rollback()
      if (not fresh_shard_state.active or
          "worker_active_state_collision" in _TEST_INJECTED_FAULTS):
        logging.warning("Shard %s is not active. Possible spurious task "
                        "execution. Dropping this task.", tstate.shard_id)
        logging.warning("Datastore's %s", str(fresh_shard_state))
        logging.warning("Slice's %s", str(shard_state))
        return
      fresh_shard_state.copy_from(shard_state)
      fresh_shard_state.put(config=config)
      # Add the task in the same datastore transaction.
      # This way we guarantee taskqueue is never behind datastore states.
      # Old tasks will be dropped.
      # Future tasks won't run until datastore states catch up.
      if fresh_shard_state.active:
        # Not adding the task transactionally.
        # Transactional enqueue requires tasks with no name.
        self._add_task(task, spec, queue_name)

    try:
      _tx()
    except (datastore_errors.Error,
            taskqueue.Error,
            runtime.DeadlineExceededError,
            apiproxy_errors.Error) as e:
      logging.warning(
          "Can't transactionally continue shard. "
          "Will retry slice %s %s for the %s time.",
          tstate.shard_id,
          tstate.slice_id,
          self.task_retry_count() + 1)
      self._try_free_lease(shard_state)
      raise e

  def _attempt_slice_recovery(self, shard_state, tstate):
    """Recover a slice.

    This is run when a slice had been previously attempted and output
    may have been written. If an output writer requires slice recovery,
    we run that logic to remove output duplicates. Otherwise we just retry
    the slice.

    If recovery is needed, then the entire slice will be dedicated
    to recovery logic. No data processing will take place. Thus we call
    the slice a "recovery slice". This is needed for correctness:
    An output writer instance can be out of sync from its physical
    medium only when the slice dies after acquiring the shard lock but before
    committing shard state to db. The worst failure case is when
    shard state failed to commit after the NAMED task for the next slice was
    added. Thus, the recovery slice has special logic to increment the
    current slice_id n to n+2. If the task for n+1 had been added, it will
    be dropped because it is behind shard state.

    Args:
      shard_state: an instance of Model.ShardState.
      tstate: an instance of Model.TransientShardState.

    Returns:
      _TASK_DIRECTIVE.PROCEED_TASK to continue with this retry.
      _TASK_DIRECTIVE.RECOVER_SLICE to recover this slice.
      The next slice will start at the same input as
      this slice but output to a new instance of the output writer.
      Combining outputs from all writer instances is up to the implementation.
    """
    mapper_spec = tstate.mapreduce_spec.mapper
    if not (tstate.output_writer and
            tstate.output_writer._supports_slice_recovery(mapper_spec)):
      return self._TASK_DIRECTIVE.PROCEED_TASK

    tstate.output_writer = tstate.output_writer._recover(
        tstate.mapreduce_spec, shard_state.shard_number,
        shard_state.retries + 1)
    return self._TASK_DIRECTIVE.RECOVER_SLICE

  def _attempt_shard_retry(self, shard_state, tstate):
    """Whether to retry the shard.

    This method may modify shard_state and tstate to prepare for a retry or
    failure.

    Args:
      shard_state: model.ShardState for the current shard.
      tstate: model.TransientShardState for the current shard.

    Returns:
      A _TASK_DIRECTIVE enum. RETRY_SHARD if the shard should be retried.
    FAIL_TASK otherwise.
    """
    shard_attempts = shard_state.retries + 1

    if shard_attempts >= parameters.config.SHARD_MAX_ATTEMPTS:
      logging.warning(
          "Shard attempt %s exceeded %s max attempts.",
          shard_attempts, parameters.config.SHARD_MAX_ATTEMPTS)
      return self._TASK_DIRECTIVE.FAIL_TASK
    if tstate.output_writer and (
        not tstate.output_writer._supports_shard_retry(tstate)):
      logging.warning("Output writer %s does not support shard retry.",
                      tstate.output_writer.__class__.__name__)
      return self._TASK_DIRECTIVE.FAIL_TASK

    shard_state.reset_for_retry()
    logging.warning("Shard %s attempt %s failed. Will retry "
                    "(up to %s attempts).",
                    shard_state.shard_id,
                    shard_state.retries,
                    parameters.config.SHARD_MAX_ATTEMPTS)
    output_writer = None
    if tstate.output_writer:
      output_writer = tstate.output_writer.create(
          tstate.mapreduce_spec, shard_state.shard_number, shard_attempts + 1)
    tstate.reset_for_retry(output_writer)
    return self._TASK_DIRECTIVE.RETRY_SHARD

  def _attempt_slice_retry(self, shard_state, tstate):
    """Attempt to retry this slice.

    This method may modify shard_state and tstate to prepare for a retry or
    failure.

    Args:
      shard_state: model.ShardState for the current shard.
      tstate: model.TransientShardState for the current shard.

    Returns:
      A _TASK_DIRECTIVE enum. RETRY_SLICE if the slice should be retried.
    RETRY_SHARD if a shard retry should be attempted.
    """
    if (shard_state.slice_retries + 1 <
        parameters.config.TASK_MAX_DATA_PROCESSING_ATTEMPTS):
      logging.warning(
          "Slice %s %s failed on attempt %s of up to %s attempts "
          "(%s of %s taskqueue execution attempts). "
          "Will retry now.",
          tstate.shard_id,
          tstate.slice_id,
          shard_state.slice_retries + 1,
          parameters.config.TASK_MAX_DATA_PROCESSING_ATTEMPTS,
          self.task_retry_count() + 1,
          parameters.config.TASK_MAX_ATTEMPTS)
      # Clear info related to the current exception. Otherwise, the real
      # call stack that includes a frame for this method will show up
      # in the log.
      sys.exc_clear()
      self._try_free_lease(shard_state, slice_retry=True)
      return self._TASK_DIRECTIVE.RETRY_SLICE

    if parameters.config.TASK_MAX_DATA_PROCESSING_ATTEMPTS > 0:
      logging.warning("Slice attempt %s exceeded %s max attempts.",
                      self.task_retry_count() + 1,
                      parameters.config.TASK_MAX_DATA_PROCESSING_ATTEMPTS)
    return self._TASK_DIRECTIVE.RETRY_SHARD

  @staticmethod
  def get_task_name(shard_id, slice_id, retry=0):
    """Compute single worker task name.

    Args:
      shard_id: shard id.
      slice_id: slice id.
      retry: current shard retry count.

    Returns:
      task name which should be used to process specified shard/slice.
    """
    # Prefix the task name with something unique to this framework's
    # namespace so we don't conflict with user tasks on the queue.
    return "appengine-mrshard-%s-%s-retry-%s" % (
        shard_id, slice_id, retry)
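
  # For example, get_task_name("1234567890-0", 5, retry=1) returns
  # "appengine-mrshard-1234567890-0-5-retry-1".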

  def _get_countdown_for_next_slice(self, spec):
    """Get the countdown for the next slice's task.

    When the user sets a processing rate, we set a countdown to delay task
    execution.

    Args:
      spec: model.MapreduceSpec

    Returns:
      The countdown in seconds as an int.
    """
    countdown = 0
    if self._processing_limit(spec) != -1:
      countdown = max(
          int(parameters.config._SLICE_DURATION_SEC -
              (self._time() - self._start_time)), 0)
    return countdown
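
  # Worked example (values assumed): with _SLICE_DURATION_SEC = 15 and a
  # slice that hit its processing limit after 4 seconds, the next task is
  # scheduled with countdown = max(int(15 - 4), 0) = 11, which smooths the
  # effective processing rate across slices.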

  @classmethod
  def _state_to_task(cls,
                     tstate,
                     shard_state,
                     eta=None,
                     countdown=None):
    """Generate task for slice according to current states.

    Args:
      tstate: An instance of TransientShardState.
      shard_state: An instance of ShardState.
      eta: Absolute time when the MR should execute. May not be specified
        if 'countdown' is also supplied. This may be timezone-aware or
        timezone-naive.
      countdown: Time in seconds into the future that this MR should execute.
        Defaults to zero.

    Returns:
      A model.HugeTask instance for the slice specified by current states.
    """
    base_path = tstate.base_path

    task_name = MapperWorkerCallbackHandler.get_task_name(
        tstate.shard_id,
        tstate.slice_id,
        tstate.retries)

    headers = util._get_task_headers(tstate.mapreduce_spec.mapreduce_id)
    headers[util._MR_SHARD_ID_TASK_HEADER] = tstate.shard_id

    worker_task = model.HugeTask(
        url=base_path + "/worker_callback/" + tstate.shard_id,
        params=tstate.to_dict(),
        name=task_name,
        eta=eta,
        countdown=countdown,
        parent=shard_state,
        headers=headers)
    return worker_task

  @classmethod
  def _add_task(cls,
                worker_task,
                mapreduce_spec,
                queue_name):
    """Schedule slice scanning by adding it to the task queue.

    Args:
      worker_task: a model.HugeTask task for the slice. This is NOT a
        taskqueue task.
      mapreduce_spec: an instance of model.MapreduceSpec.
      queue_name: Optional queue to run on; uses the current queue of
        execution or the default queue if unspecified.
    """
    if not _run_task_hook(mapreduce_spec.get_hooks(),
                          "enqueue_worker_task",
                          worker_task,
                          queue_name):
      try:
        # Not adding transactionally because worker_task has a name.
        # Named tasks are not allowed for transactional add.
        worker_task.add(queue_name)
      except (taskqueue.TombstonedTaskError,
              taskqueue.TaskAlreadyExistsError) as e:
        logging.warning("Task %r already exists. %s: %s",
                        worker_task.name,
                        e.__class__,
                        e)

  def _processing_limit(self, spec):
    """Get the limit on the number of map calls allowed by this slice.

    Args:
      spec: a Mapreduce spec.

    Returns:
      The limit as a positive int if specified by user. -1 otherwise.
    """
    processing_rate = float(spec.mapper.params.get("processing_rate", 0))
    slice_processing_limit = -1
    if processing_rate > 0:
      slice_processing_limit = int(math.ceil(
          parameters.config._SLICE_DURATION_SEC*processing_rate/
          int(spec.mapper.shard_count)))
    return slice_processing_limit
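
  # Worked example (values assumed): processing_rate = 1000 entities/sec,
  # _SLICE_DURATION_SEC = 15, and shard_count = 8 give
  # ceil(15 * 1000 / 8) = 1875 map calls allowed per slice for this shard.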

  # Deprecated. Only used by old test cases.
  # TODO(user): clean up tests.
  @classmethod
  def _schedule_slice(cls,
                      shard_state,
                      tstate,
                      queue_name=None,
                      eta=None,
                      countdown=None):
    """Schedule slice scanning by adding it to the task queue.

    Args:
      shard_state: An instance of ShardState.
      tstate: An instance of TransientShardState.
      queue_name: Optional queue to run on; uses the current queue of
        execution or the default queue if unspecified.
      eta: Absolute time when the MR should execute. May not be specified
        if 'countdown' is also supplied. This may be timezone-aware or
        timezone-naive.
      countdown: Time in seconds into the future that this MR should execute.
        Defaults to zero.
    """
    queue_name = queue_name or os.environ.get("HTTP_X_APPENGINE_QUEUENAME",
                                              "default")
    task = cls._state_to_task(tstate, shard_state, eta, countdown)
    cls._add_task(task, tstate.mapreduce_spec, queue_name)


class ControllerCallbackHandler(base_handler.HugeTaskHandler):
  """Supervises mapreduce execution.

  It is also responsible for aggregating execution status from the shards.

  This task keeps itself "continuously" running by re-adding itself to the
  taskqueue if and only if the mapreduce is still active. A mapreduce is
  active if it has actively running shards.
  """

  def __init__(self, *args):
    """Constructor."""
    super(ControllerCallbackHandler, self).__init__(*args)
    self._time = time.time

  def _drop_gracefully(self):
    """Gracefully drop the controller task.

    This method is called when decoding the controller task payload failed.
    Upon this we mark ShardState and MapreduceState as failed so all
    tasks can stop.

    Writing to datastore is forced (ignoring read-only mode) because we
    badly want the tasks to stop, and if force_writes had been False,
    the job would never have been started.
    """
    mr_id = self.request.headers[util._MR_ID_TASK_HEADER]
    state = model.MapreduceState.get_by_job_id(mr_id)
    if not state or not state.active:
      return

    state.active = False
    state.result_status = model.MapreduceState.RESULT_FAILED
    config = util.create_datastore_write_config(state.mapreduce_spec)
    puts = []
    for ss in model.ShardState.find_all_by_mapreduce_state(state):
      if ss.active:
        ss.set_for_failure()
        puts.append(ss)
        # Avoid holding too many shard states in memory.
        if len(puts) > model.ShardState._MAX_STATES_IN_MEMORY:
          db.put(puts, config=config)
          puts = []
    db.put(puts, config=config)
    # Put mr_state only after all shard_states are put.
    db.put(state, config=config)

  def handle(self):
    """Handle request."""
    spec = model.MapreduceSpec.from_json_str(
        self.request.get("mapreduce_spec"))
    state, control = db.get([
        model.MapreduceState.get_key_by_job_id(spec.mapreduce_id),
        model.MapreduceControl.get_key_by_job_id(spec.mapreduce_id),
    ])

    if not state:
      logging.warning("State not found for MR '%s'; dropping controller task.",
                      spec.mapreduce_id)
      return
    if not state.active:
      logging.warning(
          "MR %r is not active. Looks like spurious controller task execution.",
          spec.mapreduce_id)
      self._clean_up_mr(spec)
      return

    shard_states = model.ShardState.find_all_by_mapreduce_state(state)
    self._update_state_from_shard_states(state, shard_states, control)

    if state.active:
      ControllerCallbackHandler.reschedule(
          state, spec, self.serial_id() + 1)

  def _update_state_from_shard_states(self, state, shard_states, control):
    """Update the MR state by examining shard states.

    Args:
      state: current mapreduce state as MapreduceState.
      shard_states: an iterator over shard states.
      control: model.MapreduceControl entity.
    """
    # Initialize vars.
    state.active_shards, state.aborted_shards, state.failed_shards = 0, 0, 0
    total_shards = 0
    processed_counts = []
    processed_status = []
    state.counters_map.clear()

    # Tally across shard states once.
    for s in shard_states:
      total_shards += 1
      status = 'unknown'
      if s.active:
        state.active_shards += 1
        status = 'running'
      if s.result_status == model.ShardState.RESULT_SUCCESS:
        status = 'success'
      elif s.result_status == model.ShardState.RESULT_ABORTED:
        state.aborted_shards += 1
        status = 'aborted'
      elif s.result_status == model.ShardState.RESULT_FAILED:
        state.failed_shards += 1
        status = 'failed'

      # Update stats in mapreduce state by aggregating stats from shard
      # states.
      state.counters_map.add_map(s.counters_map)
      processed_counts.append(s.counters_map.get(context.COUNTER_MAPPER_CALLS))
      processed_status.append(status)

    state.set_processed_counts(processed_counts, processed_status)
    state.last_poll_time = datetime.datetime.utcfromtimestamp(self._time())

    spec = state.mapreduce_spec

    if total_shards != spec.mapper.shard_count:
      logging.error("Found %d shard states. Expected %d. "
                    "Issuing abort command to job '%s'",
                    total_shards, spec.mapper.shard_count,
                    spec.mapreduce_id)
      # We issue the abort command to allow shards to stop themselves.
      model.MapreduceControl.abort(spec.mapreduce_id)

    # If any shard is active then the mr is active.
    # This way, the controller won't prematurely stop before all the shards
    # have.
    state.active = bool(state.active_shards)
    if not control and (state.failed_shards or state.aborted_shards):
      # Issue the abort command if there are failed shards.
      model.MapreduceControl.abort(spec.mapreduce_id)

    if not state.active:
      # Set final result status derived from shard states.
      if state.failed_shards or not total_shards:
        state.result_status = model.MapreduceState.RESULT_FAILED
      # It's important that failed shards are checked before aborted shards
      # because failed shards will trigger other shards to abort.
      elif state.aborted_shards:
        state.result_status = model.MapreduceState.RESULT_ABORTED
      else:
        state.result_status = model.MapreduceState.RESULT_SUCCESS
      self._finalize_outputs(spec, state)
      self._finalize_job(spec, state)
    else:
      @db.transactional(retries=5)
      def _put_state():
        """The helper for storing the state."""
        fresh_state = model.MapreduceState.get_by_job_id(spec.mapreduce_id)
        # We don't check anything other than active because we are only
        # updating stats. It's OK if they are briefly inconsistent.
        if not fresh_state.active:
          logging.warning(
              "Job %s is not active. Looks like spurious task execution. "
              "Dropping controller task.", spec.mapreduce_id)
          return
        config = util.create_datastore_write_config(spec)
        state.put(config=config)

      _put_state()

  def serial_id(self):
    """Get serial unique identifier of this task from request.

    Returns:
      serial identifier as int.
    """
    return int(self.request.get("serial_id"))

  @classmethod
  def _finalize_outputs(cls, mapreduce_spec, mapreduce_state):
    """Finalize outputs.

    Args:
      mapreduce_spec: an instance of MapreduceSpec.
      mapreduce_state: an instance of MapreduceState.
    """
    # Only finalize the output writers if the job is successful.
    if (mapreduce_spec.mapper.output_writer_class() and
        mapreduce_state.result_status == model.MapreduceState.RESULT_SUCCESS):
      mapreduce_spec.mapper.output_writer_class().finalize_job(mapreduce_state)

  @classmethod
  def _finalize_job(cls, mapreduce_spec, mapreduce_state):
    """Finalize job execution.

    Invokes the done callback and saves mapreduce state in a transaction,
    and schedules necessary clean-ups. This method is idempotent.

    Args:
      mapreduce_spec: an instance of MapreduceSpec
      mapreduce_state: an instance of MapreduceState
    """
    config = util.create_datastore_write_config(mapreduce_spec)
    queue_name = util.get_queue_name(mapreduce_spec.params.get(
        model.MapreduceSpec.PARAM_DONE_CALLBACK_QUEUE))
    done_callback = mapreduce_spec.params.get(
        model.MapreduceSpec.PARAM_DONE_CALLBACK)
    done_task = None
    if done_callback:
      done_task = taskqueue.Task(
          url=done_callback,
          headers=util._get_task_headers(mapreduce_spec.mapreduce_id,
                                         util.CALLBACK_MR_ID_TASK_HEADER),
          method=mapreduce_spec.params.get("done_callback_method", "POST"))

    @db.transactional(retries=5)
    def _put_state():
      """Helper to store state."""
      fresh_state = model.MapreduceState.get_by_job_id(
          mapreduce_spec.mapreduce_id)
      if not fresh_state.active:
        logging.warning(
            "Job %s is not active. Looks like spurious task execution. "
            "Dropping task.", mapreduce_spec.mapreduce_id)
        return
      mapreduce_state.put(config=config)
      # Enqueue done_callback if needed.
      if done_task and not _run_task_hook(
          mapreduce_spec.get_hooks(),
          "enqueue_done_task",
          done_task,
          queue_name):
        done_task.add(queue_name, transactional=True)

    _put_state()
    logging.info("Final result for job '%s' is '%s'",
                 mapreduce_spec.mapreduce_id, mapreduce_state.result_status)
    cls._clean_up_mr(mapreduce_spec)

  @classmethod
  def _clean_up_mr(cls, mapreduce_spec):
    FinalizeJobHandler.schedule(mapreduce_spec)

  @staticmethod
  def get_task_name(mapreduce_spec, serial_id):
    """Compute single controller task name.

    Args:
      mapreduce_spec: specification of the mapreduce.
      serial_id: id of the invocation as int.

    Returns:
      task name which should be used for the given controller invocation.
    """
    # Prefix the task name with something unique to this framework's
    # namespace so we don't conflict with user tasks on the queue.
    return "appengine-mrcontrol-%s-%s" % (
        mapreduce_spec.mapreduce_id, serial_id)
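
  # For example, a job with mapreduce_id "1234567890" at serial id 3 yields
  # the task name "appengine-mrcontrol-1234567890-3".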

  @staticmethod
  def controller_parameters(mapreduce_spec, serial_id):
    """Fill in controller task parameters.

    The returned parameter map is to be used as the task payload, and it
    contains all the data required by the controller to perform its function.

    Args:
      mapreduce_spec: specification of the mapreduce.
      serial_id: id of the invocation as int.

    Returns:
      string->string map of parameters to be used as task payload.
    """
    return {"mapreduce_spec": mapreduce_spec.to_json_str(),
            "serial_id": str(serial_id)}

  @classmethod
  def reschedule(cls,
                 mapreduce_state,
                 mapreduce_spec,
                 serial_id,
                 queue_name=None):
    """Schedule a new update status callback task.

    Args:
      mapreduce_state: mapreduce state as model.MapreduceState
      mapreduce_spec: mapreduce specification as MapreduceSpec.
      serial_id: id of the invocation as int.
      queue_name: The queue to schedule this task on. Will use the current
        queue of execution if not supplied.
    """
    task_name = ControllerCallbackHandler.get_task_name(
        mapreduce_spec, serial_id)
    task_params = ControllerCallbackHandler.controller_parameters(
        mapreduce_spec, serial_id)
    if not queue_name:
      queue_name = os.environ.get("HTTP_X_APPENGINE_QUEUENAME", "default")

    controller_callback_task = model.HugeTask(
        url=(mapreduce_spec.params["base_path"] + "/controller_callback/" +
             mapreduce_spec.mapreduce_id),
        name=task_name, params=task_params,
        countdown=parameters.config._CONTROLLER_PERIOD_SEC,
        parent=mapreduce_state,
        headers=util._get_task_headers(mapreduce_spec.mapreduce_id))

    if not _run_task_hook(mapreduce_spec.get_hooks(),
                          "enqueue_controller_task",
                          controller_callback_task,
                          queue_name):
      try:
        controller_callback_task.add(queue_name)
      except (taskqueue.TombstonedTaskError,
              taskqueue.TaskAlreadyExistsError) as e:
        logging.warning("Task %r with params %r already exists. %s: %s",
                        task_name, task_params, e.__class__, e)
1358
1359
1360class KickOffJobHandler(base_handler.TaskQueueHandler):
1361  """Taskqueue handler which kicks off a mapreduce processing.
1362
1363  This handler is idempotent.
1364
1365  Precondition:
1366    The Model.MapreduceState entity for this mr is already created and
1367    saved to datastore by StartJobHandler._start_map.
1368
1369  Request Parameters:
1370    mapreduce_id: in string.
1371  """
1372
1373  # Datastore key used to save json serialized input readers.
1374  _SERIALIZED_INPUT_READERS_KEY = "input_readers_for_mr_%s"
1375
1376  def handle(self):
1377    """Handles kick off request."""
1378    # Get and verify mr state.
1379    mr_id = self.request.get("mapreduce_id")
1380    # Log the mr_id since this is started in an unnamed task
1381    logging.info("Processing kickoff for job %s", mr_id)
1382    state = model.MapreduceState.get_by_job_id(mr_id)
1383    if not self._check_mr_state(state, mr_id):
1384      return
1385
1386    # Create input readers.
1387    readers, serialized_readers_entity = self._get_input_readers(state)
1388    if readers is None:
1389      # We don't have any data. Finish map.
1390      logging.warning("Found no mapper input data to process.")
1391      state.active = False
1392      state.result_status = model.MapreduceState.RESULT_SUCCESS
1393      ControllerCallbackHandler._finalize_job(
1394          state.mapreduce_spec, state)
1395      return False
1396
1397    # Create output writers.
1398    self._setup_output_writer(state)
1399
1400    # Save states and make sure we use the saved input readers for
1401    # subsequent operations.
1402    result = self._save_states(state, serialized_readers_entity)
1403    if result is None:
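      # A previous attempt of this transaction already saved the states;
      # reload the persisted readers so every retry sees the same splits.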
1404      readers, _ = self._get_input_readers(state)
1405    elif not result:
1406      return
1407
1408    queue_name = self.request.headers.get("X-AppEngine-QueueName")
1409    KickOffJobHandler._schedule_shards(state.mapreduce_spec, readers,
1410                                       queue_name,
1411                                       state.mapreduce_spec.params["base_path"],
1412                                       state)
1413
1414    ControllerCallbackHandler.reschedule(
1415        state, state.mapreduce_spec, serial_id=0, queue_name=queue_name)
1416
1417  def _drop_gracefully(self):
1418    """See parent."""
1419    mr_id = self.request.get("mapreduce_id")
1420    logging.error("Failed to kick off job %s", mr_id)
1421
1422    state = model.MapreduceState.get_by_job_id(mr_id)
1423    if not self._check_mr_state(state, mr_id):
1424      return
1425
1426    # Issue abort command just in case there are running tasks.
1427    config = util.create_datastore_write_config(state.mapreduce_spec)
1428    model.MapreduceControl.abort(mr_id, config=config)
1429
1430    # Finalize job and invoke callback.
1431    state.active = False
1432    state.result_status = model.MapreduceState.RESULT_FAILED
1433    ControllerCallbackHandler._finalize_job(state.mapreduce_spec, state)
1434
1435  def _get_input_readers(self, state):
1436    """Get input readers.
1437
1438    Args:
1439      state: a MapreduceState model.
1440
1441    Returns:
1442      A tuple: (a list of input readers, a model._HugeTaskPayload entity).
1443    The payload entity contains the json serialized input readers.
    (None, None) when input reader splitting returned no data to process.
1445    """
1446    serialized_input_readers_key = (self._SERIALIZED_INPUT_READERS_KEY %
1447                                    state.key().id_or_name())
1448    serialized_input_readers = model._HugeTaskPayload.get_by_key_name(
1449        serialized_input_readers_key, parent=state)
1450
    # Initialize input readers. New map_job-style readers split on a
    # map_job.JobConfig; legacy readers split on the MapperSpec.
1452    input_reader_class = state.mapreduce_spec.mapper.input_reader_class()
1453    split_param = state.mapreduce_spec.mapper
1454    if issubclass(input_reader_class, map_job.InputReader):
1455      split_param = map_job.JobConfig._to_map_job_config(
1456          state.mapreduce_spec,
1457          os.environ.get("HTTP_X_APPENGINE_QUEUENAME"))
1458    if serialized_input_readers is None:
1459      readers = input_reader_class.split_input(split_param)
1460    else:
1461      readers = [input_reader_class.from_json_str(_json) for _json in
1462                 json.loads(zlib.decompress(
1463                 serialized_input_readers.payload))]
1464
1465    if not readers:
1466      return None, None
1467
1468    # Update state and spec with actual shard count.
1469    state.mapreduce_spec.mapper.shard_count = len(readers)
1470    state.active_shards = len(readers)
1471
1472    # Prepare to save serialized input readers.
1473    if serialized_input_readers is None:
1474      # Use mr_state as parent so it can be easily cleaned up later.
1475      serialized_input_readers = model._HugeTaskPayload(
1476          key_name=serialized_input_readers_key, parent=state)
1477      readers_json_str = [i.to_json_str() for i in readers]
1478      serialized_input_readers.payload = zlib.compress(json.dumps(
1479                                                       readers_json_str))
1480    return readers, serialized_input_readers
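
  # The serialization round-trip used above, in isolation (illustrative):
  #
  #   payload = zlib.compress(json.dumps([r.to_json_str() for r in readers]))
  #   readers = [input_reader_class.from_json_str(s)
  #              for s in json.loads(zlib.decompress(payload))]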
1481
1482  def _setup_output_writer(self, state):
1483    if not state.writer_state:
1484      output_writer_class = state.mapreduce_spec.mapper.output_writer_class()
1485      if output_writer_class:
1486        output_writer_class.init_job(state)
1487
1488  @db.transactional
1489  def _save_states(self, state, serialized_readers_entity):
1490    """Run transaction to save state.
1491
1492    Args:
1493      state: a model.MapreduceState entity.
1494      serialized_readers_entity: a model._HugeTaskPayload entity containing
1495        json serialized input readers.
1496
1497    Returns:
1498      False if a fatal error is encountered and this task should be dropped
    immediately. True if the transaction is successful. None if a previous
1500    attempt of this same transaction has already succeeded.
1501    """
1502    mr_id = state.key().id_or_name()
1503    fresh_state = model.MapreduceState.get_by_job_id(mr_id)
1504    if not self._check_mr_state(fresh_state, mr_id):
1505      return False
1506    if fresh_state.active_shards != 0:
1507      logging.warning(
1508          "Mapreduce %s already has active shards. Looks like spurious task "
1509          "execution.", mr_id)
1510      return None
1511    config = util.create_datastore_write_config(state.mapreduce_spec)
1512    db.put([state, serialized_readers_entity], config=config)
1513    return True
1514
1515  @classmethod
1516  def _schedule_shards(cls,
1517                       spec,
1518                       readers,
1519                       queue_name,
1520                       base_path,
1521                       mr_state):
1522    """Prepares shard states and schedules their execution.
1523
    Even though this method does not schedule shard tasks and save shard
    states transactionally, it is safe for taskqueue to retry this logic
    because the initial shard_state for each shard is the same on every
    retry. This is an important yet reasonable assumption on
    model.ShardState.
1528
1529    Args:
1530      spec: mapreduce specification as MapreduceSpec.
1531      readers: list of InputReaders describing shard splits.
1532      queue_name: The queue to run this job on.
1533      base_path: The base url path of mapreduce callbacks.
1534      mr_state: The MapReduceState of current job.
1535    """
1536    # Create shard states.
1537    shard_states = []
1538    for shard_number, input_reader in enumerate(readers):
1539      shard_state = model.ShardState.create_new(spec.mapreduce_id, shard_number)
1540      shard_state.shard_description = str(input_reader)
1541      shard_states.append(shard_state)
1542
    # Retrieve shard states that already exist (e.g. from a prior attempt).
1544    existing_shard_states = db.get(shard.key() for shard in shard_states)
1545    existing_shard_keys = set(shard.key() for shard in existing_shard_states
1546                              if shard is not None)
1547
    # Save shard states that do not already exist.
1549    # Note: we could do this transactionally if necessary.
1550    db.put((shard for shard in shard_states
1551            if shard.key() not in existing_shard_keys),
1552           config=util.create_datastore_write_config(spec))
1553
1554    # Create output writers.
1555    writer_class = spec.mapper.output_writer_class()
1556    writers = [None] * len(readers)
1557    if writer_class:
1558      for shard_number, shard_state in enumerate(shard_states):
1559        writers[shard_number] = writer_class.create(
1560            mr_state.mapreduce_spec,
1561            shard_state.shard_number, shard_state.retries + 1,
1562            mr_state.writer_state)
1563
1564    # Schedule ALL shard tasks.
1565    # Since each task is named, _add_task will fall back gracefully if a
1566    # task already exists.
1567    for shard_number, (input_reader, output_writer) in enumerate(
1568        zip(readers, writers)):
1569      shard_id = model.ShardState.shard_id_from_number(
1570          spec.mapreduce_id, shard_number)
1571      task = MapperWorkerCallbackHandler._state_to_task(
1572          model.TransientShardState(
1573              base_path, spec, shard_id, 0, input_reader, input_reader,
1574              output_writer=output_writer,
1575              handler=spec.mapper.handler),
1576          shard_states[shard_number])
1577      MapperWorkerCallbackHandler._add_task(task,
1578                                            spec,
1579                                            queue_name)
1580
1581  @classmethod
1582  def _check_mr_state(cls, state, mr_id):
1583    """Check MapreduceState.
1584
1585    Args:
      state: a MapreduceState instance.
1587      mr_id: mapreduce id.
1588
1589    Returns:
1590      True if state is valid. False if not and this task should be dropped.
1591    """
1592    if state is None:
1593      logging.warning(
1594          "Mapreduce State for job %s is missing. Dropping Task.",
1595          mr_id)
1596      return False
1597    if not state.active:
1598      logging.warning(
1599          "Mapreduce %s is not active. Looks like spurious task "
1600          "execution. Dropping Task.", mr_id)
1601      return False
1602    return True
1603
1604
1605class StartJobHandler(base_handler.PostJsonHandler):
1606  """Command handler starts a mapreduce job.
1607
1608  This handler allows user to start a mr via a web form. It's _start_map
1609  method can also be used independently to start a mapreduce.
1610  """
1611
1612  def handle(self):
1613    """Handles start request."""
1614    # Mapper spec as form arguments.
1615    mapreduce_name = self._get_required_param("name")
1616    mapper_input_reader_spec = self._get_required_param("mapper_input_reader")
1617    mapper_handler_spec = self._get_required_param("mapper_handler")
1618    mapper_output_writer_spec = self.request.get("mapper_output_writer")
1619    mapper_params = self._get_params(
1620        "mapper_params_validator", "mapper_params.")
1621    params = self._get_params(
1622        "params_validator", "params.")
1623
1624    # Default values.
1625    mr_params = map_job.JobConfig._get_default_mr_params()
1626    mr_params.update(params)
1627    if "queue_name" in mapper_params:
1628      mr_params["queue_name"] = mapper_params["queue_name"]
1629
1630    # Set some mapper param defaults if not present.
1631    mapper_params["processing_rate"] = int(mapper_params.get(
1632        "processing_rate") or parameters.config.PROCESSING_RATE_PER_SEC)
1633
1634    # Validate the Mapper spec, handler, and input reader.
1635    mapper_spec = model.MapperSpec(
1636        mapper_handler_spec,
1637        mapper_input_reader_spec,
1638        mapper_params,
1639        int(mapper_params.get("shard_count", parameters.config.SHARD_COUNT)),
1640        output_writer_spec=mapper_output_writer_spec)
1641
1642    mapreduce_id = self._start_map(
1643        mapreduce_name,
1644        mapper_spec,
1645        mr_params,
1646        queue_name=mr_params["queue_name"],
1647        _app=mapper_params.get("_app"))
1648    self.json_response["mapreduce_id"] = mapreduce_id
1649
1650  def _get_params(self, validator_parameter, name_prefix):
1651    """Retrieves additional user-supplied params for the job and validates them.
1652
1653    Args:
1654      validator_parameter: name of the request parameter which supplies
1655        validator for this parameter set.
1656      name_prefix: common prefix for all parameter names in the request.
1657
1658    Raises:
      Any exception raised by the validator referenced by the
      'params_validator' request parameter if the params fail to validate.
1661
1662    Returns:
1663      The user parameters.
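
    Example:
      With name_prefix "mapper_params." a request argument
      "mapper_params.entity_kind=Foo" yields {"entity_kind": "Foo"};
      repeated arguments with the same key yield a list of values.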
1664    """
1665    params_validator = self.request.get(validator_parameter)
1666
1667    user_params = {}
1668    for key in self.request.arguments():
1669      if key.startswith(name_prefix):
1670        values = self.request.get_all(key)
1671        adjusted_key = key[len(name_prefix):]
1672        if len(values) == 1:
1673          user_params[adjusted_key] = values[0]
1674        else:
1675          user_params[adjusted_key] = values
1676
1677    if params_validator:
1678      resolved_validator = util.for_name(params_validator)
1679      resolved_validator(user_params)
1680
1681    return user_params
1682
1683  def _get_required_param(self, param_name):
1684    """Get a required request parameter.
1685
1686    Args:
1687      param_name: name of request parameter to fetch.
1688
1689    Returns:
1690      parameter value
1691
1692    Raises:
1693      errors.NotEnoughArgumentsError: if parameter is not specified.
1694    """
1695    value = self.request.get(param_name)
1696    if not value:
1697      raise errors.NotEnoughArgumentsError(param_name + " not specified")
1698    return value
1699
1700  @classmethod
1701  def _start_map(cls,
1702                 name,
1703                 mapper_spec,
1704                 mapreduce_params,
1705                 queue_name,
1706                 eta=None,
1707                 countdown=None,
1708                 hooks_class_name=None,
1709                 _app=None,
1710                 in_xg_transaction=False):
1711    # pylint: disable=g-doc-args
1712    # pylint: disable=g-doc-return-or-yield
1713    """See control.start_map.
1714
1715    Requirements for this method:
    1. The request that invokes this method can be either a regular request
       or one from taskqueue, so taskqueue specific headers cannot be used.
1718    2. Each invocation transactionally starts an isolated mapreduce job with
1719       a unique id. MapreduceState should be immediately available after
1720       returning. See control.start_map's doc on transactional.
1721    3. Method should be lightweight.
1722    """
1723    # Validate input reader.
1724    mapper_input_reader_class = mapper_spec.input_reader_class()
1725    mapper_input_reader_class.validate(mapper_spec)
1726
1727    # Validate output writer.
1728    mapper_output_writer_class = mapper_spec.output_writer_class()
1729    if mapper_output_writer_class:
1730      mapper_output_writer_class.validate(mapper_spec)
1731
1732    # Create a new id and mr spec.
1733    mapreduce_id = model.MapreduceState.new_mapreduce_id()
1734    mapreduce_spec = model.MapreduceSpec(
1735        name,
1736        mapreduce_id,
1737        mapper_spec.to_json(),
1738        mapreduce_params,
1739        hooks_class_name)
1740
1741    # Validate mapper handler.
1742    ctx = context.Context(mapreduce_spec, None)
1743    context.Context._set(ctx)
1744    try:
1745      # pylint: disable=pointless-statement
1746      mapper_spec.handler
1747    finally:
1748      context.Context._set(None)
1749
1750    # Save states and enqueue task.
1751    if in_xg_transaction:
1752      propagation = db.MANDATORY
1753    else:
1754      propagation = db.INDEPENDENT
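
    # db.MANDATORY joins the caller's existing (cross-group) transaction
    # (and raises if there is none); db.INDEPENDENT always runs the writes
    # in a fresh, separate transaction.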
1755
1756    @db.transactional(propagation=propagation)
1757    def _txn():
1758      cls._create_and_save_state(mapreduce_spec, _app)
1759      cls._add_kickoff_task(mapreduce_params["base_path"], mapreduce_spec, eta,
1760                            countdown, queue_name)
1761    _txn()
1762
1763    return mapreduce_id
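
  # Illustrative use of _start_map (argument values are hypothetical;
  # application code should normally go through control.start_map):
  #
  #   spec = model.MapperSpec(
  #       "my_module.my_map_func",
  #       "mapreduce.input_readers.DatastoreInputReader",
  #       {"entity_kind": "my_module.MyEntity"},
  #       shard_count=8)
  #   mr_id = StartJobHandler._start_map(
  #       "my job", spec, map_job.JobConfig._get_default_mr_params(),
  #       queue_name="default")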
1764
1765  @classmethod
1766  def _create_and_save_state(cls, mapreduce_spec, _app):
1767    """Save mapreduce state to datastore.
1768
1769    Save state to datastore so that UI can see it immediately.
1770
1771    Args:
      mapreduce_spec: the mapreduce specification as model.MapreduceSpec.
1773      _app: app id if specified. None otherwise.
1774
1775    Returns:
      The saved MapreduceState.
1777    """
1778    state = model.MapreduceState.create_new(mapreduce_spec.mapreduce_id)
1779    state.mapreduce_spec = mapreduce_spec
1780    state.active = True
1781    state.active_shards = 0
1782    if _app:
1783      state.app_id = _app
1784    config = util.create_datastore_write_config(mapreduce_spec)
1785    state.put(config=config)
1786    return state
1787
1788  @classmethod
1789  def _add_kickoff_task(cls,
1790                        base_path,
1791                        mapreduce_spec,
1792                        eta,
1793                        countdown,
1794                        queue_name):
1795    """Enqueues a new kickoff task."""
1796    params = {"mapreduce_id": mapreduce_spec.mapreduce_id}
1797    # Task is not named so that it can be added within a transaction.
1798    kickoff_task = taskqueue.Task(
1799        url=base_path + "/kickoffjob_callback/" + mapreduce_spec.mapreduce_id,
1800        headers=util._get_task_headers(mapreduce_spec.mapreduce_id),
1801        params=params,
1802        eta=eta,
1803        countdown=countdown)
1804    hooks = mapreduce_spec.get_hooks()
1805    if hooks is not None:
1806      try:
1807        hooks.enqueue_kickoff_task(kickoff_task, queue_name)
1808        return
1809      except NotImplementedError:
1810        pass
1811    kickoff_task.add(queue_name, transactional=True)
1812
1813
1814class FinalizeJobHandler(base_handler.TaskQueueHandler):
1815  """Finalize map job by deleting all temporary entities."""
1816
1817  def handle(self):
1818    mapreduce_id = self.request.get("mapreduce_id")
1819    mapreduce_state = model.MapreduceState.get_by_job_id(mapreduce_id)
1820    if mapreduce_state:
1821      config = (
1822          util.create_datastore_write_config(mapreduce_state.mapreduce_spec))
1823      keys = [model.MapreduceControl.get_key_by_job_id(mapreduce_id)]
1824      for ss in model.ShardState.find_all_by_mapreduce_state(mapreduce_state):
1825        keys.extend(list(
1826            model._HugeTaskPayload.all().ancestor(ss).run(keys_only=True)))
1827      keys.extend(list(model._HugeTaskPayload.all().ancestor(
1828          mapreduce_state).run(keys_only=True)))
1829      db.delete(keys, config=config)
1830
1831  @classmethod
1832  def schedule(cls, mapreduce_spec):
1833    """Schedule finalize task.
1834
1835    Args:
1836      mapreduce_spec: mapreduce specification as MapreduceSpec.
1837    """
1838    task_name = mapreduce_spec.mapreduce_id + "-finalize"
1839    finalize_task = taskqueue.Task(
1840        name=task_name,
1841        url=(mapreduce_spec.params["base_path"] + "/finalizejob_callback/" +
1842             mapreduce_spec.mapreduce_id),
1843        params={"mapreduce_id": mapreduce_spec.mapreduce_id},
1844        headers=util._get_task_headers(mapreduce_spec.mapreduce_id))
1845    queue_name = util.get_queue_name(None)
1846    if not _run_task_hook(mapreduce_spec.get_hooks(),
1847                          "enqueue_controller_task",
1848                          finalize_task,
1849                          queue_name):
1850      try:
1851        finalize_task.add(queue_name)
1852      except (taskqueue.TombstonedTaskError,
1853              taskqueue.TaskAlreadyExistsError), e:
1854        logging.warning("Task %r already exists. %s: %s",
1855                        task_name, e.__class__, e)
1856
1857
1858class CleanUpJobHandler(base_handler.PostJsonHandler):
1859  """Command to kick off tasks to clean up a job's data."""
1860
1861  def handle(self):
1862    mapreduce_id = self.request.get("mapreduce_id")
1863
1864    mapreduce_state = model.MapreduceState.get_by_job_id(mapreduce_id)
1865    if mapreduce_state:
1866      shard_keys = model.ShardState.calculate_keys_by_mapreduce_state(
1867          mapreduce_state)
1868      db.delete(shard_keys)
1869      db.delete(mapreduce_state)
1870    self.json_response["status"] = ("Job %s successfully cleaned up." %
1871                                    mapreduce_id)
1872
1873
1874class AbortJobHandler(base_handler.PostJsonHandler):
1875  """Command to abort a running job."""
1876
1877  def handle(self):
1878    model.MapreduceControl.abort(self.request.get("mapreduce_id"))
1879    self.json_response["status"] = "Abort signal sent."
1880