1#!/usr/bin/env python
2"""A class representing entity property range."""
3
4
5
6# pylint: disable=g-bad-name
7# pylint: disable=g-import-not-at-top
8
9import datetime
10
11from google.appengine.ext import ndb
12
13from google.appengine.ext import db
14from mapreduce import errors
15from mapreduce import util
16
17__all__ = [
18    "should_shard_by_property_range",
19    "PropertyRange"]
20
21
def should_shard_by_property_range(filters):
  """Tells whether the given filters call for sharding by property range.

  Args:
    filters: user supplied filters. Each filter should be a list or tuple of
      format (<property_name_as_str>, <query_operator_as_str>,
      <value_of_certain_type>). Value type is up to the property's type.

  Returns:
    True if at least one filter uses an operator other than "=". False if
    filters is empty (or None) or holds only equality filters.
  """
  if not filters:
    return False
  # Only the operator slot (index 1) of each filter matters here.
  return any(entry[1] != "=" for entry in filters)
41
42
class PropertyRange(object):
  """A class that represents a range on a db.Model's or ndb.Model's property.

  It supports splitting the range into n shards and generating a query that
  returns entities within that range.
  """

  def __init__(self,
               filters,
               model_class_path):
    """Init.

    Args:
      filters: user supplied filters. Each filter should be a list or tuple of
        format (<property_name_as_str>, <query_operator_as_str>,
        <value_of_certain_type>). Value type should satisfy the property's type.
      model_class_path: full path to the model class in str.

    Raises:
      BadReaderParamsError: if the filters fail the validation done by
        _get_range_from_filters.
    """
    self.filters = filters
    self.model_class_path = model_class_path
    self.model_class = util.for_name(self.model_class_path)
    # prop is the property object; start/end are the original filter tuples
    # (name, op, value) bounding the range. All three are None without a range.
    self.prop, self.start, self.end = self._get_range_from_filters(
        self.filters, self.model_class)

  @classmethod
  def _get_range_from_filters(cls, filters, model_class):
    """Get property range from filters user provided.

    This method also validates there is one and only one closed range on a
    single property.

    Args:
      filters: user supplied filters. Each filter should be a list or tuple of
        format (<property_name_as_str>, <query_operator_as_str>,
        <value_of_certain_type>). Value type should satisfy the property's type.
      model_class: the model class for the entity type to apply filters on.

    Returns:
      a tuple of (property, start_filter, end_filter). property is the model's
    field that the range is about. start_filter and end_filter define the
    start and the end of the range. (None, None, None) if no range is found.

    Raises:
      BadReaderParamsError: if any filter is invalid in any way.
    """
    if not filters:
      return None, None, None

    range_property = None
    start_val = None
    end_val = None
    start_filter = None
    end_filter = None
    for f in filters:
      prop, op, val = f

      if op in [">", ">=", "<", "<="]:
        if range_property and range_property != prop:
          raise errors.BadReaderParamsError(
              "Range on only one property is supported.")
        range_property = prop

        if val is None:
          # f is itself a tuple/list, so wrap it to format it as one value.
          raise errors.BadReaderParamsError(
              "Range can't be None in filter %s" % (f,))

        if op in [">", ">="]:
          if start_val is not None:
            raise errors.BadReaderParamsError(
                "Operation %s is specified more than once." % (op,))
          start_val = val
          start_filter = f
        else:
          if end_val is not None:
            raise errors.BadReaderParamsError(
                "Operation %s is specified more than once." % (op,))
          end_val = val
          end_filter = f
      elif op != "=":
        raise errors.BadReaderParamsError(
            "Only < <= > >= = are supported as operation. Got %s" % (op,))

    # Only equality filters were given: nothing to shard on.
    if not range_property:
      return None, None, None

    if start_val is None or end_val is None:
      raise errors.BadReaderParamsError(
          "Filter should contains a complete range on property %s" %
          (range_property,))
    if issubclass(model_class, db.Model):
      property_obj = model_class.properties()[range_property]
    else:
      property_obj = (
          model_class._properties[  # pylint: disable=protected-access
              range_property])
    # tuple(dict) yields the keys on both Python 2 and 3, unlike
    # keys() + keys() which only works when keys() returns a list.
    supported_properties = (
        tuple(_DISCRETE_PROPERTY_SPLIT_FUNCTIONS) +
        tuple(_CONTINUOUS_PROPERTY_SPLIT_FUNCTIONS))
    if not isinstance(property_obj, supported_properties):
      raise errors.BadReaderParamsError(
          "Filtered property %s is not supported by sharding." %
          (range_property,))
    if not start_val < end_val:
      raise errors.BadReaderParamsError(
          "Start value %s should be smaller than end value %s" %
          (start_val, end_val))

    return property_obj, start_filter, end_filter

  @staticmethod
  def _find_split_function(prop_cls):
    """Finds the split function registered for prop_cls or one of its bases.

    Validation accepts subclasses of the supported property classes (it uses
    isinstance), so the lookup walks the MRO instead of requiring an exact
    class match. An exact match is always found first.

    Args:
      prop_cls: class of the property the range is on.

    Returns:
      a tuple (split_function, is_discrete).

    Raises:
      KeyError: if neither prop_cls nor any of its bases is registered.
    """
    for klass in prop_cls.__mro__:
      if klass in _DISCRETE_PROPERTY_SPLIT_FUNCTIONS:
        return _DISCRETE_PROPERTY_SPLIT_FUNCTIONS[klass], True
      if klass in _CONTINUOUS_PROPERTY_SPLIT_FUNCTIONS:
        return _CONTINUOUS_PROPERTY_SPLIT_FUNCTIONS[klass], False
    raise KeyError(prop_cls)

  def split(self, n):
    """Evenly split this range into contiguous, non overlapping subranges.

    Only meaningful when the filters contain a range (self.start and self.end
    are not None).

    Args:
      n: number of splits.

    Returns:
      a list of contiguous, non overlapping sub PropertyRanges. Maybe less than
    n when not enough subranges.

    Raises:
      ValueError: propagated from the split function when the range is too
        small to split.
    """
    new_range_filters = []
    name = self.start[0]
    split_function, is_discrete = self._find_split_function(
        self.prop.__class__)
    if is_discrete:
      # Discrete split functions normalize the bounds themselves and return
      # both end points, each consecutive pair being a [start, end) range.
      splitpoints = split_function(
          self.start[2], self.end[2], n,
          self.start[1] == ">=", self.end[1] == "<=")
      start_filter = (name, ">=", splitpoints[0])
      for p in splitpoints[1:]:
        end_filter = (name, "<", p)
        new_range_filters.append([start_filter, end_filter])
        start_filter = (name, ">=", p)
    else:
      # Continuous split functions return interior points only, so the
      # user's original start and end filters bound the outer subranges.
      splitpoints = split_function(self.start[2], self.end[2], n)
      start_filter = self.start
      for p in splitpoints:
        end_filter = (name, "<", p)
        new_range_filters.append([start_filter, end_filter])
        start_filter = (name, ">=", p)
      new_range_filters.append([start_filter, self.end])

    # Every shard keeps the user's equality filters.
    for f in new_range_filters:
      f.extend(self._equality_filters)

    return [self.__class__(f, self.model_class_path) for f in new_range_filters]

  def make_query(self, ns):
    """Make a query of entities within this range.

    Query options are not supported. They should be specified when the query
    is run.

    Args:
      ns: namespace of this query.

    Returns:
      a db.Query or ndb.Query, depends on the model class's type.
    """
    if issubclass(self.model_class, db.Model):
      query = db.Query(self.model_class, namespace=ns)
      for f in self.filters:
        query.filter("%s %s" % (f[0], f[1]), f[2])
    else:
      query = self.model_class.query(namespace=ns)
      for f in self.filters:
        query = query.filter(ndb.FilterNode(*f))
    return query

  @property
  def _equality_filters(self):
    """The subset of self.filters whose operator is "="."""
    return [f for f in self.filters if f[1] == "="]

  def to_json(self):
    """Returns a dict holding the two constructor arguments."""
    return {"filters": self.filters,
            "model_class_path": self.model_class_path}

  @classmethod
  def from_json(cls, json):
    """Rebuilds an instance from the dict produced by to_json."""
    return cls(json["filters"], json["model_class_path"])
221
222
def _split_datetime_property(start, end, n, include_start, include_end):
  """Splits a datetime range into at most n contiguous subranges.

  Args:
    start: datetime.datetime lower bound of the range.
    end: datetime.datetime upper bound of the range.
    n: desired number of subranges.
    include_start: whether start itself belongs to the range.
    include_end: whether end itself belongs to the range.

  Returns:
    A list of at most n+1 ascending datetimes. Each consecutive pair delimits
    one subrange that includes its first point and excludes its second.

  Raises:
    ValueError: if the range cannot be cut into n pieces of at least one
      microsecond each.
  """
  # datastore stored datetime precision is microsecond.
  one_microsecond = datetime.timedelta(microseconds=1)
  # Normalize both bounds to an include-start, exclude-end range.
  if not include_start:
    start += one_microsecond
  if include_end:
    end += one_microsecond
  stride = (end - start) // n
  if stride <= datetime.timedelta():
    raise ValueError(
        "Range too small to split: start %r end %r" % (start, end))
  splitpoints = [start]
  previous = start
  for _ in range(n - 1):
    point = previous + stride
    if point == previous or point > end:
      continue
    previous = point
    splitpoints.append(point)
  # Floor division may leave a remainder; the last subrange absorbs it.
  if end not in splitpoints:
    splitpoints.append(end)
  return splitpoints
244
245
def _split_float_property(start, end, n):
  """Computes the interior points that cut [start, end] into n equal parts.

  Args:
    start: lower bound of the range.
    end: upper bound of the range.
    n: desired number of subranges.

  Returns:
    A list of n-1 floats strictly between start and end. Unlike the discrete
    split functions, start and end themselves are not included.

  Raises:
    ValueError: if end is not greater than start.
  """
  stride = float(end - start) / n
  if stride <= 0:
    raise ValueError(
        "Range too small to split: start %r end %r" % (start, end))
  return [start + i * stride for i in range(1, n)]
255
256
def _split_integer_property(start, end, n, include_start, include_end):
  """Splits an integer range into at most n contiguous subranges.

  Args:
    start: lower bound of the range.
    end: upper bound of the range.
    n: desired number of subranges.
    include_start: whether start itself belongs to the range.
    include_end: whether end itself belongs to the range.

  Returns:
    A list of at most n+1 strictly ascending ints. Each consecutive pair
    delimits one subrange that includes its first point and excludes its
    second. Fewer points are returned when the range holds fewer than n
    integers.

  Raises:
    ValueError: if the normalized range is empty.
  """
  # Normalize both bounds to an include-start, exclude-end range.
  if not include_start:
    start += 1
  if include_end:
    end += 1
  stride = float(end - start) / n
  if stride <= 0:
    raise ValueError(
        "Range too small to split: start %r end %r" % (start, end))
  splitpoints = [start]
  previous = start
  for i in range(1, n):
    point = start + int(round(i * stride))
    # Rounding can map neighbouring multiples of stride to the same int.
    if point == previous or point > end:
      continue
    previous = point
    splitpoints.append(point)
  # Points are strictly ascending and never exceed end, so end is already
  # present exactly when it is the last point.
  if splitpoints[-1] != end:
    splitpoints.append(end)
  return splitpoints
277
278
def _split_string_property(start, end, n, include_start, include_end):
  """Splits a range of text strings by encoding them to ASCII first.

  Args:
    start: lower bound of the range.
    end: upper bound of the range.
    n: desired number of subranges.
    include_start: whether start itself belongs to the range.
    include_end: whether end itself belongs to the range.

  Returns:
    Whatever _split_byte_string_property returns for the encoded bounds.

  Raises:
    ValueError: if start or end cannot be encoded as ASCII. The original
      UnicodeEncodeError is attached as the second exception argument.
  """
  try:
    start = start.encode("ascii")
    end = end.encode("ascii")
  except UnicodeEncodeError as e:  # "as" form works on Python 2.6+ and 3.
    raise ValueError("Only ascii str is supported.", e)

  return _split_byte_string_property(start, end, n, include_start, include_end)
287
288
# Characters that string splitting understands: the 128 ASCII code points, in
# ascending order.
_ALPHABET = "".join(map(chr, range(128)))
# Number of characters taken into account when splitting. It bounds how many
# unique strings we can choose split points from, and therefore the shard
# count: we can't split into more shards than len(_ALPHABET)^_STRING_LENGTH.
_STRING_LENGTH = 4
294
295
def _split_byte_string_property(start, end, n, include_start, include_end):
  """Splits a range of byte strings into at most n contiguous subranges.

  The leading characters shared by start and end are set aside. The next
  _STRING_LENGTH characters of each bound are converted to integer ordinals,
  the ordinal range is split evenly, and the points are converted back to
  strings. Whatever followed those _STRING_LENGTH characters is re-attached to
  the first and the last split point respectively.

  Args:
    start: lower bound of the range, of type str.
    end: upper bound of the range, of type str.
    n: desired number of subranges.
    include_start: whether start itself belongs to the range.
    include_end: whether end itself belongs to the range.

  Returns:
    A list of at most n+1 strings. Each consecutive pair delimits one subrange
    that includes its first point and excludes its second.

  Raises:
    ValueError: if the ordinal range is too small to split.
  """
  # Get prefix, suffix, and the real start/end to split on.
  # When one bound is a prefix of the other the loop ends without a break, so
  # the last shared character stays inside the split window.
  i = 0
  for i, (s, e) in enumerate(zip(start, end)):
    if s != e:
      break
  common_prefix = start[:i]
  start_suffix = start[i+_STRING_LENGTH:]
  end_suffix = end[i+_STRING_LENGTH:]
  start = start[i:i+_STRING_LENGTH]
  end = end[i:i+_STRING_LENGTH]

  # Convert str to ord, normalizing to an include-start, exclude-end range.
  weights = _get_weights(_STRING_LENGTH)
  start_ord = _str_to_ord(start, weights)
  if not include_start:
    start_ord += 1
  end_ord = _str_to_ord(end, weights)
  if include_end:
    end_ord += 1

  # Do split.
  stride = (end_ord - start_ord) / float(n)
  if stride <= 0:
    raise ValueError(
        "Range too small to split: start %s end %s" % (start, end))
  splitpoints = [_ord_to_str(start_ord, weights)]
  previous = start_ord
  for step in range(1, n):
    point = start_ord + int(round(stride * step))
    # Rounding can map neighbouring multiples of stride to the same ordinal.
    if point == previous or point > end_ord:
      continue
    previous = point
    splitpoints.append(_ord_to_str(point, weights))
  end_str = _ord_to_str(end_ord, weights)
  if end_str not in splitpoints:
    splitpoints.append(end_str)

  # Append suffix.
  splitpoints[0] += start_suffix
  splitpoints[-1] += end_suffix

  return [common_prefix + point for point in splitpoints]
338
339
def _get_weights(max_length):
  """Computes the per-offset weights for strings up to max_length long.

  Args:
    max_length: max length of the strings.

  Returns:
    A list of ints, one weight per character offset, the first offset's weight
    coming first.

  Example:
    If max_length is 2 and alphabet is "ab", then we have order "", "a", "aa",
  "ab", "b", "ba", "bb". So the weight for the first char is 3.
  """
  # Built from the last offset (weight 1) towards the first, then flipped.
  ascending = [1]
  while len(ascending) < max_length:
    ascending.append(ascending[-1] * len(_ALPHABET) + 1)
  return ascending[::-1]
358
359
def _str_to_ord(content, weights):
  """Maps a string to its position in lexicographical order.

  Args:
    content: the string to convert. Of type str.
    weights: weights from _get_weights.

  Returns:
    an int or long that represents the order of this string. "" has order 0.
  """
  # Each character contributes its weighted alphabet index plus one.
  return sum(
      weights[offset] * _ALPHABET.index(char) + 1
      for offset, char in enumerate(content))
374
375
def _ord_to_str(ordinal, weights):
  """Reverse function of _str_to_ord.

  Args:
    ordinal: the order of a string, as produced by _str_to_ord.
    weights: weights from _get_weights.

  Returns:
    the string having that order. Order 0 maps to "".
  """
  pieces = []
  for weight in weights:
    # A remaining ordinal of 0 means the string ends here.
    if ordinal == 0:
      break
    index, ordinal = divmod(ordinal - 1, weight)
    pieces.append(_ALPHABET[index])
  return "".join(pieces)
386
387
# Registry of split functions for discrete properties, keyed by property class.
# All of them share one interface: they take start, end, shard_number n,
# include_start, include_end and return at most n+1 points, forming n ranges.
# Each range should be include_start, exclude_end.
# The db and ndb flavours of a property share the same function.
_DISCRETE_PROPERTY_SPLIT_FUNCTIONS = {
    db.DateTimeProperty: _split_datetime_property,
    ndb.DateTimeProperty: _split_datetime_property,
    db.IntegerProperty: _split_integer_property,
    ndb.IntegerProperty: _split_integer_property,
    db.StringProperty: _split_string_property,
    ndb.StringProperty: _split_string_property,
    db.ByteStringProperty: _split_byte_string_property,
    ndb.BlobProperty: _split_byte_string_property,
}

# Registry of split functions for continuous properties. These take only
# start, end and n.
_CONTINUOUS_PROPERTY_SPLIT_FUNCTIONS = {
    db.FloatProperty: _split_float_property,
    ndb.FloatProperty: _split_float_property,
}
409