1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Utility functions for use with the Google App Engine Pipeline API."""
18
# Names exported by ``from <this module> import *``.
__all__ = [
    "for_name",
    "JsonEncoder",
    "JsonDecoder",
]
22
23#pylint: disable=g-bad-name
24
25import datetime
26import inspect
27import logging
28import os
29
30try:
31  import json
32except ImportError:
33  import simplejson as json
34
35# pylint: disable=protected-access
36
37
def _get_task_target():
  """Compute the default target for a pipeline task.

  CURRENT_VERSION_ID has the form "user_defined_version.minor_version_number";
  only the part before the first "." is used. CURRENT_MODULE_ID is the
  module's name, which may be "default".

  Returns:
    "version.module", or just "version" when the module is "default".
    None when pipeline._TEST_MODE is set or either env var is missing.
  """
  # Imported here rather than at file top to break a circular dependency.
  # pylint: disable=g-import-not-at-top
  import pipeline
  if pipeline._TEST_MODE:
    return None

  # Guard against test setups that run outside TEST_MODE without setting
  # these env vars properly.
  env = os.environ
  if not all(key in env for key in ("CURRENT_VERSION_ID",
                                    "CURRENT_MODULE_ID")):
    logging.warning("Running Pipeline in non TEST_MODE but important "
                    "env vars are not set.")
    return None

  major_version = env["CURRENT_VERSION_ID"].partition(".")[0]
  module_id = env["CURRENT_MODULE_ID"]
  if module_id == "default":
    return major_version
  return "%s.%s" % (major_version, module_id)
67
68
def for_name(fq_name, recursive=False):
  """Find class/function/method specified by its fully qualified name.

  Fully qualified can be specified as:
    * <module_name>.<class_name>
    * <module_name>.<function_name>
    * <module_name>.<class_name>.<method_name> (the attribute is fetched from
      the class with getattr in this case).

  for_name works by doing __import__ for <module_name>, and looks for
  <class_name>/<function_name> in module's __dict__/attrs. If fully qualified
  name doesn't contain '.', the current module will be used.

  Args:
    fq_name: fully qualified name of something to find.
    recursive: True when called from inside another for_name() lookup. A name
      missing from a successfully imported module is then reported as
      KeyError, so the outer call can build an accurate message.

  Returns:
    The object (class, function, method, module, ...) named by fq_name.

  Raises:
    ImportError: when specified module could not be loaded or the name
      was not found in the module.
    KeyError: only when recursive is True and the module was imported but
      does not contain the requested name.
  """
  fq_name = str(fq_name)
  module_name = __name__
  short_name = fq_name

  if "." in fq_name:
    module_name, _, short_name = fq_name.rpartition(".")

  try:
    result = __import__(module_name, None, None, [short_name])
    return result.__dict__[short_name]
  except KeyError:
    # If we're recursively inside a for_name() chain, then we want to raise
    # this error as a key error so we can report the actual source of the
    # problem. If we're *not* recursively being called, that means the
    # module was found and the specific item could not be loaded, and thus
    # we want to raise an ImportError directly.
    if recursive:
      raise
    raise ImportError("Could not find '%s' on path '%s'" % (
        short_name, module_name))
  except ImportError:
    # module_name is not actually a module (e.g. it names a class). Resolve
    # it recursively and fetch short_name as an attribute of it.
    found, value = _for_name_via_parent(module_name, short_name)
    if found:
      return value
    # Raise the original import error that caused all of this, since it is
    # likely the real cause of the overall problem. The recursive lookup ran
    # in a helper function, so on Python 2 the ImportError it swallowed did
    # not replace this frame's current exception. The bare ``raise``
    # therefore re-raises the original error with its traceback.
    raise


def _for_name_via_parent(parent_name, short_name):
  """Resolve short_name as an attribute of the object named parent_name.

  Helper for for_name(), used when parent_name could not be imported as a
  module.

  Args:
    parent_name: fully qualified name of the object expected to own
      short_name; resolved with for_name(..., recursive=True).
    short_name: attribute name to fetch from that object.

  Returns:
    A (found, value) tuple. found is False when resolving parent_name raised
    ImportError; the caller should then re-raise its own original error.

  Raises:
    ImportError: parent_name was resolved but lacks short_name, or the
      recursive lookup reported a missing name via KeyError.
  """
  try:
    parent = for_name(parent_name, recursive=True)
    if hasattr(parent, short_name):
      return True, getattr(parent, short_name)
    # The parent was found, but the requested component is missing.
    raise KeyError(short_name)
  except KeyError:
    raise ImportError("Could not find '%s' on path '%s'" % (
        short_name, parent_name))
  except ImportError:
    # Recursive import attempts failed; the caller raises the first
    # ImportError it encountered, since it's likely the most accurate.
    return False, None
134
135
def is_generator_function(obj):
  """Return true if the object is a user-defined generator function.

  Generator function objects provide the same attributes as functions.
  See isfunction.__doc__ for attributes listing.

  Adapted from Python 2.6.

  Args:
    obj: an object to test.

  Returns:
    True if obj is a function or method whose code object has the
    CO_GENERATOR flag set, False otherwise.
  """
  CO_GENERATOR = 0x20
  if not (inspect.isfunction(obj) or inspect.ismethod(obj)):
    return False
  # ``__code__`` exists on functions (and is forwarded by methods) on
  # Python 2.6+ and Python 3; ``func_code`` is Python 2 only.
  return bool(obj.__code__.co_flags & CO_GENERATOR)
153
154
class JsonEncoder(json.JSONEncoder):
  """JSON encoder that understands Pipeline's registered extra types.

  An object whose exact type (subclasses are not matched) has an encoder in
  _TYPE_TO_ENCODER is converted to that encoder's dict of json primitives.
  The dict is tagged with TYPE_ID -> the type's __name__ so JsonDecoder can
  restore it. Everything else is handed to the base encoder's default().
  """

  # Dict key that carries the original type name through serialization.
  TYPE_ID = "__pipeline_json_type"

  def default(self, o):
    """Encode o if its exact type is registered, else defer to the base."""
    obj_type = type(o)
    if obj_type not in _TYPE_TO_ENCODER:
      return super(JsonEncoder, self).default(o)
    encoded = _TYPE_TO_ENCODER[obj_type](o)
    encoded[self.TYPE_ID] = obj_type.__name__
    return encoded
168
169
class JsonDecoder(json.JSONDecoder):
  """Pipeline customized json decoder.

  Dicts tagged with JsonEncoder.TYPE_ID are turned back into Python objects
  via the decoders in _TYPE_NAME_TO_DECODER; untagged dicts are returned
  unchanged.
  """

  def __init__(self, **kwargs):
    # Install our hook only when the caller did not supply their own.
    if "object_hook" not in kwargs:
      kwargs["object_hook"] = self._dict_to_obj
    super(JsonDecoder, self).__init__(**kwargs)

  def _dict_to_obj(self, d):
    """Converts a dictionary of json object to a Python object.

    Args:
      d: a dict produced by the json parser for one JSON object.

    Returns:
      d unchanged if it has no JsonEncoder.TYPE_ID key; otherwise the result
      of the decoder registered for the tagged type name.

    Raises:
      TypeError: d is tagged with a type name that has no registered decoder.
    """
    if JsonEncoder.TYPE_ID not in d:
      return d

    # Remove the tag so the decoder only sees the encoder's own fields.
    type_name = d.pop(JsonEncoder.TYPE_ID)
    if type_name not in _TYPE_NAME_TO_DECODER:
      # Format the message here: passing type_name as a separate exception
      # argument would leave the "%s" placeholder unfilled.
      raise TypeError("Invalid type %s." % (type_name,))
    return _TYPE_NAME_TO_DECODER[type_name](d)
189
190
# strftime/strptime pattern shared by the datetime encoder and decoder below.
# %f always writes six zero-padded microsecond digits, so the decoder's %f
# always has input to parse.
# NOTE(review): the pattern has no timezone field, so tzinfo is dropped on
# encode and decoded values are naive. On Python 2, strftime() also rejects
# years before 1900.
_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"


def _json_encode_datetime(o):
  """Serialize a datetime into a dict of json primitives.

  Args:
    o: a datetime object.

  Returns:
    A one-entry dict mapping "isostr" to o formatted with _DATETIME_FORMAT.
  """
  formatted = o.strftime(_DATETIME_FORMAT)
  return dict(isostr=formatted)


def _json_decode_datetime(d):
  """Rebuild a datetime from a dict produced by _json_encode_datetime."""
  isostr = d["isostr"]
  return datetime.datetime.strptime(isostr, _DATETIME_FORMAT)
209
210
def _register_json_primitive(object_type, encoder, decoder):
  """Extend what Pipeline can serialize.

  Registering a type that is already registered is a no-op: the first
  encoder/decoder pair wins.

  Args:
    object_type: type of the object. JsonEncoder matches it by exact type
      (subclasses are not covered); the decoder is stored under its __name__.
    encoder: a function that takes in an object and returns
      a dict of json primitives.
    decoder: inverse function of encoder.
  """
  # The registries are only mutated, never rebound, so no ``global``
  # declaration is needed.
  if object_type in _TYPE_TO_ENCODER:
    return
  _TYPE_TO_ENCODER[object_type] = encoder
  # NOTE(review): decoders are keyed by the bare class name, so a distinct
  # type with the same __name__ as an already-registered one overwrites that
  # type's decoder here.
  _TYPE_NAME_TO_DECODER[object_type.__name__] = decoder
225
226
# Registries read by JsonEncoder.default (exact type -> encoder) and by
# JsonDecoder._dict_to_obj (type __name__ -> decoder); filled in through
# _register_json_primitive.
_TYPE_TO_ENCODER = dict()
_TYPE_NAME_TO_DECODER = dict()

# datetime.datetime is serializable out of the box.
_register_json_primitive(datetime.datetime,
                         encoder=_json_encode_datetime,
                         decoder=_json_decode_datetime)
232