1""" 2Extensions to Django's model logic. 3""" 4 5import django.core.exceptions 6from django.db import connection 7from django.db import connections 8from django.db import models as dbmodels 9from django.db import transaction 10from django.db.models.sql import query 11import django.db.models.sql.where 12# TODO(akeshet): Replace with monarch stats once we know how to instrument rpc 13# handling with ts_mon. 14from autotest_lib.client.common_lib.cros.graphite import autotest_stats 15from autotest_lib.frontend.afe import rdb_model_extensions 16 17 18class ValidationError(django.core.exceptions.ValidationError): 19 """\ 20 Data validation error in adding or updating an object. The associated 21 value is a dictionary mapping field names to error strings. 22 """ 23 24def _quote_name(name): 25 """Shorthand for connection.ops.quote_name().""" 26 return connection.ops.quote_name(name) 27 28 29class LeasedHostManager(dbmodels.Manager): 30 """Query manager for unleased, unlocked hosts. 31 """ 32 def get_query_set(self): 33 return (super(LeasedHostManager, self).get_query_set().filter( 34 leased=0, locked=0)) 35 36 37class ExtendedManager(dbmodels.Manager): 38 """\ 39 Extended manager supporting subquery filtering. 40 """ 41 42 class CustomQuery(query.Query): 43 def __init__(self, *args, **kwargs): 44 super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs) 45 self._custom_joins = [] 46 47 48 def clone(self, klass=None, **kwargs): 49 obj = super(ExtendedManager.CustomQuery, self).clone(klass) 50 obj._custom_joins = list(self._custom_joins) 51 return obj 52 53 54 def combine(self, rhs, connector): 55 super(ExtendedManager.CustomQuery, self).combine(rhs, connector) 56 if hasattr(rhs, '_custom_joins'): 57 self._custom_joins.extend(rhs._custom_joins) 58 59 60 def add_custom_join(self, table, condition, join_type, 61 condition_values=(), alias=None): 62 if alias is None: 63 alias = table 64 join_dict = dict(table=table, 65 condition=condition, 66 condition_values=condition_values, 67 join_type=join_type, 68 alias=alias) 69 self._custom_joins.append(join_dict) 70 71 72 @classmethod 73 def convert_query(self, query_set): 74 """ 75 Convert the query set's "query" attribute to a CustomQuery. 76 """ 77 # Make a copy of the query set 78 query_set = query_set.all() 79 query_set.query = query_set.query.clone( 80 klass=ExtendedManager.CustomQuery, 81 _custom_joins=[]) 82 return query_set 83 84 85 class _WhereClause(object): 86 """Object allowing us to inject arbitrary SQL into Django queries. 87 88 By using this instead of extra(where=...), we can still freely combine 89 queries with & and |. 90 """ 91 def __init__(self, clause, values=()): 92 self._clause = clause 93 self._values = values 94 95 96 def as_sql(self, qn=None, connection=None): 97 return self._clause, self._values 98 99 100 def relabel_aliases(self, change_map): 101 return 102 103 104 def add_join(self, query_set, join_table, join_key, join_condition='', 105 join_condition_values=(), join_from_key=None, alias=None, 106 suffix='', exclude=False, force_left_join=False): 107 """Add a join to query_set. 108 109 Join looks like this: 110 (INNER|LEFT) JOIN <join_table> AS <alias> 111 ON (<this table>.<join_from_key> = <join_table>.<join_key> 112 and <join_condition>) 113 114 @param join_table table to join to 115 @param join_key field referencing back to this model to use for the join 116 @param join_condition extra condition for the ON clause of the join 117 @param join_condition_values values to substitute into join_condition 118 @param join_from_key column on this model to join from. 119 @param alias alias to use for for join 120 @param suffix suffix to add to join_table for the join alias, if no 121 alias is provided 122 @param exclude if true, exclude rows that match this join (will use a 123 LEFT OUTER JOIN and an appropriate WHERE condition) 124 @param force_left_join - if true, a LEFT OUTER JOIN will be used 125 instead of an INNER JOIN regardless of other options 126 """ 127 join_from_table = query_set.model._meta.db_table 128 if join_from_key is None: 129 join_from_key = self.model._meta.pk.name 130 if alias is None: 131 alias = join_table + suffix 132 full_join_key = _quote_name(alias) + '.' + _quote_name(join_key) 133 full_join_condition = '%s = %s.%s' % (full_join_key, 134 _quote_name(join_from_table), 135 _quote_name(join_from_key)) 136 if join_condition: 137 full_join_condition += ' AND (' + join_condition + ')' 138 if exclude or force_left_join: 139 join_type = query_set.query.LOUTER 140 else: 141 join_type = query_set.query.INNER 142 143 query_set = self.CustomQuery.convert_query(query_set) 144 query_set.query.add_custom_join(join_table, 145 full_join_condition, 146 join_type, 147 condition_values=join_condition_values, 148 alias=alias) 149 150 if exclude: 151 query_set = query_set.extra(where=[full_join_key + ' IS NULL']) 152 153 return query_set 154 155 156 def _info_for_many_to_one_join(self, field, join_to_query, alias): 157 """ 158 @param field: the ForeignKey field on the related model 159 @param join_to_query: the query over the related model that we're 160 joining to 161 @param alias: alias of joined table 162 """ 163 info = {} 164 rhs_table = join_to_query.model._meta.db_table 165 info['rhs_table'] = rhs_table 166 info['rhs_column'] = field.column 167 info['lhs_column'] = field.rel.get_related_field().column 168 rhs_where = join_to_query.query.where 169 rhs_where.relabel_aliases({rhs_table: alias}) 170 compiler = join_to_query.query.get_compiler(using=join_to_query.db) 171 initial_clause, values = compiler.as_sql() 172 # initial_clause is compiled from `join_to_query`, which is a SELECT 173 # query returns at most one record. For it to be used in WHERE clause, 174 # it must be converted to a boolean value using EXISTS. 175 all_clauses = ('EXISTS (%s)' % initial_clause,) 176 if hasattr(join_to_query.query, 'extra_where'): 177 all_clauses += join_to_query.query.extra_where 178 info['where_clause'] = ( 179 ' AND '.join('(%s)' % clause for clause in all_clauses)) 180 info['values'] = values 181 return info 182 183 184 def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias, 185 m2m_is_on_this_model): 186 """ 187 @param m2m_field: a Django field representing the M2M relationship. 188 It uses a pivot table with the following structure: 189 this model table <---> M2M pivot table <---> joined model table 190 @param join_to_query: the query over the related model that we're 191 joining to. 192 @param alias: alias of joined table 193 """ 194 if m2m_is_on_this_model: 195 # referenced field on this model 196 lhs_id_field = self.model._meta.pk 197 # foreign key on the pivot table referencing lhs_id_field 198 m2m_lhs_column = m2m_field.m2m_column_name() 199 # foreign key on the pivot table referencing rhd_id_field 200 m2m_rhs_column = m2m_field.m2m_reverse_name() 201 # referenced field on related model 202 rhs_id_field = m2m_field.rel.get_related_field() 203 else: 204 lhs_id_field = m2m_field.rel.get_related_field() 205 m2m_lhs_column = m2m_field.m2m_reverse_name() 206 m2m_rhs_column = m2m_field.m2m_column_name() 207 rhs_id_field = join_to_query.model._meta.pk 208 209 info = {} 210 info['rhs_table'] = m2m_field.m2m_db_table() 211 info['rhs_column'] = m2m_lhs_column 212 info['lhs_column'] = lhs_id_field.column 213 214 # select the ID of related models relevant to this join. we can only do 215 # a single join, so we need to gather this information up front and 216 # include it in the join condition. 217 rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True) 218 assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only ' 219 'match a single related object.') 220 rhs_id = rhs_ids[0] 221 222 info['where_clause'] = '%s.%s = %s' % (_quote_name(alias), 223 _quote_name(m2m_rhs_column), 224 rhs_id) 225 info['values'] = () 226 return info 227 228 229 def join_custom_field(self, query_set, join_to_query, alias, 230 left_join=True): 231 """Join to a related model to create a custom field in the given query. 232 233 This method is used to construct a custom field on the given query based 234 on a many-valued relationsip. join_to_query should be a simple query 235 (no joins) on the related model which returns at most one related row 236 per instance of this model. 237 238 For many-to-one relationships, the joined table contains the matching 239 row from the related model it one is related, NULL otherwise. 240 241 For many-to-many relationships, the joined table contains the matching 242 row if it's related, NULL otherwise. 243 """ 244 relationship_type, field = self.determine_relationship( 245 join_to_query.model) 246 247 if relationship_type == self.MANY_TO_ONE: 248 info = self._info_for_many_to_one_join(field, join_to_query, alias) 249 elif relationship_type == self.M2M_ON_RELATED_MODEL: 250 info = self._info_for_many_to_many_join( 251 m2m_field=field, join_to_query=join_to_query, alias=alias, 252 m2m_is_on_this_model=False) 253 elif relationship_type ==self.M2M_ON_THIS_MODEL: 254 info = self._info_for_many_to_many_join( 255 m2m_field=field, join_to_query=join_to_query, alias=alias, 256 m2m_is_on_this_model=True) 257 258 return self.add_join(query_set, info['rhs_table'], info['rhs_column'], 259 join_from_key=info['lhs_column'], 260 join_condition=info['where_clause'], 261 join_condition_values=info['values'], 262 alias=alias, 263 force_left_join=left_join) 264 265 266 def add_where(self, query_set, where, values=()): 267 query_set = query_set.all() 268 query_set.query.where.add(self._WhereClause(where, values), 269 django.db.models.sql.where.AND) 270 return query_set 271 272 273 def _get_quoted_field(self, table, field): 274 return _quote_name(table) + '.' + _quote_name(field) 275 276 277 def get_key_on_this_table(self, key_field=None): 278 if key_field is None: 279 # default to primary key 280 key_field = self.model._meta.pk.column 281 return self._get_quoted_field(self.model._meta.db_table, key_field) 282 283 284 def escape_user_sql(self, sql): 285 return sql.replace('%', '%%') 286 287 288 def _custom_select_query(self, query_set, selects): 289 """Execute a custom select query. 290 291 @param query_set: query set as returned by query_objects. 292 @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id. 293 294 @returns: Result of the query as returned by cursor.fetchall(). 295 """ 296 compiler = query_set.query.get_compiler(using=query_set.db) 297 sql, params = compiler.as_sql() 298 from_ = sql[sql.find(' FROM'):] 299 300 if query_set.query.distinct: 301 distinct = 'DISTINCT ' 302 else: 303 distinct = '' 304 305 sql_query = ('SELECT ' + distinct + ','.join(selects) + from_) 306 # Chose the connection that's responsible for this type of object 307 cursor = connections[query_set.db].cursor() 308 cursor.execute(sql_query, params) 309 return cursor.fetchall() 310 311 312 def _is_relation_to(self, field, model_class): 313 return field.rel and field.rel.to is model_class 314 315 316 MANY_TO_ONE = object() 317 M2M_ON_RELATED_MODEL = object() 318 M2M_ON_THIS_MODEL = object() 319 320 def determine_relationship(self, related_model): 321 """ 322 Determine the relationship between this model and related_model. 323 324 related_model must have some sort of many-valued relationship to this 325 manager's model. 326 @returns (relationship_type, field), where relationship_type is one of 327 MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field 328 is the Django field object for the relationship. 329 """ 330 # look for a foreign key field on related_model relating to this model 331 for field in related_model._meta.fields: 332 if self._is_relation_to(field, self.model): 333 return self.MANY_TO_ONE, field 334 335 # look for an M2M field on related_model relating to this model 336 for field in related_model._meta.many_to_many: 337 if self._is_relation_to(field, self.model): 338 return self.M2M_ON_RELATED_MODEL, field 339 340 # maybe this model has the many-to-many field 341 for field in self.model._meta.many_to_many: 342 if self._is_relation_to(field, related_model): 343 return self.M2M_ON_THIS_MODEL, field 344 345 raise ValueError('%s has no relation to %s' % 346 (related_model, self.model)) 347 348 349 def _get_pivot_iterator(self, base_objects_by_id, related_model): 350 """ 351 Determine the relationship between this model and related_model, and 352 return a pivot iterator. 353 @param base_objects_by_id: dict of instances of this model indexed by 354 their IDs 355 @returns a pivot iterator, which yields a tuple (base_object, 356 related_object) for each relationship between a base object and a 357 related object. all base_object instances come from base_objects_by_id. 358 Note -- this depends on Django model internals. 359 """ 360 relationship_type, field = self.determine_relationship(related_model) 361 if relationship_type == self.MANY_TO_ONE: 362 return self._many_to_one_pivot(base_objects_by_id, 363 related_model, field) 364 elif relationship_type == self.M2M_ON_RELATED_MODEL: 365 return self._many_to_many_pivot( 366 base_objects_by_id, related_model, field.m2m_db_table(), 367 field.m2m_reverse_name(), field.m2m_column_name()) 368 else: 369 assert relationship_type == self.M2M_ON_THIS_MODEL 370 return self._many_to_many_pivot( 371 base_objects_by_id, related_model, field.m2m_db_table(), 372 field.m2m_column_name(), field.m2m_reverse_name()) 373 374 375 def _many_to_one_pivot(self, base_objects_by_id, related_model, 376 foreign_key_field): 377 """ 378 @returns a pivot iterator - see _get_pivot_iterator() 379 """ 380 filter_data = {foreign_key_field.name + '__pk__in': 381 base_objects_by_id.keys()} 382 for related_object in related_model.objects.filter(**filter_data): 383 # lookup base object in the dict, rather than grabbing it from the 384 # related object. we need to return instances from the dict, not 385 # fresh instances of the same models (and grabbing model instances 386 # from the related models incurs a DB query each time). 387 base_object_id = getattr(related_object, foreign_key_field.attname) 388 base_object = base_objects_by_id[base_object_id] 389 yield base_object, related_object 390 391 392 def _query_pivot_table(self, base_objects_by_id, pivot_table, 393 pivot_from_field, pivot_to_field, related_model): 394 """ 395 @param id_list list of IDs of self.model objects to include 396 @param pivot_table the name of the pivot table 397 @param pivot_from_field a field name on pivot_table referencing 398 self.model 399 @param pivot_to_field a field name on pivot_table referencing the 400 related model. 401 @param related_model the related model 402 403 @returns pivot list of IDs (base_id, related_id) 404 """ 405 query = """ 406 SELECT %(from_field)s, %(to_field)s 407 FROM %(table)s 408 WHERE %(from_field)s IN (%(id_list)s) 409 """ % dict(from_field=pivot_from_field, 410 to_field=pivot_to_field, 411 table=pivot_table, 412 id_list=','.join(str(id_) for id_ 413 in base_objects_by_id.iterkeys())) 414 415 # Chose the connection that's responsible for this type of object 416 # The databases for related_model and the current model will always 417 # be the same, related_model is just easier to obtain here because 418 # self is only a ExtendedManager, not the object. 419 cursor = connections[related_model.objects.db].cursor() 420 cursor.execute(query) 421 return cursor.fetchall() 422 423 424 def _many_to_many_pivot(self, base_objects_by_id, related_model, 425 pivot_table, pivot_from_field, pivot_to_field): 426 """ 427 @param pivot_table: see _query_pivot_table 428 @param pivot_from_field: see _query_pivot_table 429 @param pivot_to_field: see _query_pivot_table 430 @returns a pivot iterator - see _get_pivot_iterator() 431 """ 432 id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table, 433 pivot_from_field, pivot_to_field, 434 related_model) 435 436 all_related_ids = list(set(related_id for base_id, related_id 437 in id_pivot)) 438 related_objects_by_id = related_model.objects.in_bulk(all_related_ids) 439 440 for base_id, related_id in id_pivot: 441 yield base_objects_by_id[base_id], related_objects_by_id[related_id] 442 443 444 def populate_relationships(self, base_objects, related_model, 445 related_list_name): 446 """ 447 For each instance of this model in base_objects, add a field named 448 related_list_name listing all the related objects of type related_model. 449 related_model must be in a many-to-one or many-to-many relationship with 450 this model. 451 @param base_objects - list of instances of this model 452 @param related_model - model class related to this model 453 @param related_list_name - attribute name in which to store the related 454 object list. 455 """ 456 if not base_objects: 457 # if we don't bail early, we'll get a SQL error later 458 return 459 460 base_objects_by_id = dict((base_object._get_pk_val(), base_object) 461 for base_object in base_objects) 462 pivot_iterator = self._get_pivot_iterator(base_objects_by_id, 463 related_model) 464 465 for base_object in base_objects: 466 setattr(base_object, related_list_name, []) 467 468 for base_object, related_object in pivot_iterator: 469 getattr(base_object, related_list_name).append(related_object) 470 471 472class ModelWithInvalidQuerySet(dbmodels.query.QuerySet): 473 """ 474 QuerySet that handles delete() properly for models with an "invalid" bit 475 """ 476 def delete(self): 477 for model in self: 478 model.delete() 479 480 481class ModelWithInvalidManager(ExtendedManager): 482 """ 483 Manager for objects with an "invalid" bit 484 """ 485 def get_query_set(self): 486 return ModelWithInvalidQuerySet(self.model) 487 488 489class ValidObjectsManager(ModelWithInvalidManager): 490 """ 491 Manager returning only objects with invalid=False. 492 """ 493 def get_query_set(self): 494 queryset = super(ValidObjectsManager, self).get_query_set() 495 return queryset.filter(invalid=False) 496 497 498class ModelExtensions(rdb_model_extensions.ModelValidators): 499 """\ 500 Mixin with convenience functions for models, built on top of 501 the model validators in rdb_model_extensions. 502 """ 503 # TODO: at least some of these functions really belong in a custom 504 # Manager class 505 506 507 SERIALIZATION_LINKS_TO_FOLLOW = set() 508 """ 509 To be able to send jobs and hosts to shards, it's necessary to find their 510 dependencies. 511 The most generic approach for this would be to traverse all relationships 512 to other objects recursively. This would list all objects that are related 513 in any way. 514 But this approach finds too many objects: If a host should be transferred, 515 all it's relationships would be traversed. This would find an acl group. 516 If then the acl group's relationships are traversed, the relationship 517 would be followed backwards and many other hosts would be found. 518 519 This mapping tells that algorithm which relations to follow explicitly. 520 """ 521 522 523 SERIALIZATION_LINKS_TO_KEEP = set() 524 """This set stores foreign keys which we don't want to follow, but 525 still want to include in the serialized dictionary. For 526 example, we follow the relationship `Host.hostattribute_set`, 527 but we do not want to follow `HostAttributes.host_id` back to 528 to Host, which would otherwise lead to a circle. However, we still 529 like to serialize HostAttribute.`host_id`.""" 530 531 SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set() 532 """ 533 On deserializion, if the object to persist already exists, local fields 534 will only be updated, if their name is in this set. 535 """ 536 537 538 @classmethod 539 def convert_human_readable_values(cls, data, to_human_readable=False): 540 """\ 541 Performs conversions on user-supplied field data, to make it 542 easier for users to pass human-readable data. 543 544 For all fields that have choice sets, convert their values 545 from human-readable strings to enum values, if necessary. This 546 allows users to pass strings instead of the corresponding 547 integer values. 548 549 For all foreign key fields, call smart_get with the supplied 550 data. This allows the user to pass either an ID value or 551 the name of the object as a string. 552 553 If to_human_readable=True, perform the inverse - i.e. convert 554 numeric values to human readable values. 555 556 This method modifies data in-place. 557 """ 558 field_dict = cls.get_field_dict() 559 for field_name in data: 560 if field_name not in field_dict or data[field_name] is None: 561 continue 562 field_obj = field_dict[field_name] 563 # convert enum values 564 if field_obj.choices: 565 for choice_data in field_obj.choices: 566 # choice_data is (value, name) 567 if to_human_readable: 568 from_val, to_val = choice_data 569 else: 570 to_val, from_val = choice_data 571 if from_val == data[field_name]: 572 data[field_name] = to_val 573 break 574 # convert foreign key values 575 elif field_obj.rel: 576 dest_obj = field_obj.rel.to.smart_get(data[field_name], 577 valid_only=False) 578 if to_human_readable: 579 # parameterized_jobs do not have a name_field 580 if (field_name != 'parameterized_job' and 581 dest_obj.name_field is not None): 582 data[field_name] = getattr(dest_obj, 583 dest_obj.name_field) 584 else: 585 data[field_name] = dest_obj 586 587 588 589 590 def _validate_unique(self): 591 """\ 592 Validate that unique fields are unique. Django manipulators do 593 this too, but they're a huge pain to use manually. Trust me. 594 """ 595 errors = {} 596 cls = type(self) 597 field_dict = self.get_field_dict() 598 manager = cls.get_valid_manager() 599 for field_name, field_obj in field_dict.iteritems(): 600 if not field_obj.unique: 601 continue 602 603 value = getattr(self, field_name) 604 if value is None and field_obj.auto_created: 605 # don't bother checking autoincrement fields about to be 606 # generated 607 continue 608 609 existing_objs = manager.filter(**{field_name : value}) 610 num_existing = existing_objs.count() 611 612 if num_existing == 0: 613 continue 614 if num_existing == 1 and existing_objs[0].id == self.id: 615 continue 616 errors[field_name] = ( 617 'This value must be unique (%s)' % (value)) 618 return errors 619 620 621 def _validate(self): 622 """ 623 First coerces all fields on this instance to their proper Python types. 624 Then runs validation on every field. Returns a dictionary of 625 field_name -> error_list. 626 627 Based on validate() from django.db.models.Model in Django 0.96, which 628 was removed in Django 1.0. It should reappear in a later version. See: 629 http://code.djangoproject.com/ticket/6845 630 """ 631 error_dict = {} 632 for f in self._meta.fields: 633 try: 634 python_value = f.to_python( 635 getattr(self, f.attname, f.get_default())) 636 except django.core.exceptions.ValidationError, e: 637 error_dict[f.name] = str(e) 638 continue 639 640 if not f.blank and not python_value: 641 error_dict[f.name] = 'This field is required.' 642 continue 643 644 setattr(self, f.attname, python_value) 645 646 return error_dict 647 648 649 def do_validate(self): 650 errors = self._validate() 651 unique_errors = self._validate_unique() 652 for field_name, error in unique_errors.iteritems(): 653 errors.setdefault(field_name, error) 654 if errors: 655 raise ValidationError(errors) 656 657 658 # actually (externally) useful methods follow 659 660 @classmethod 661 def add_object(cls, data={}, **kwargs): 662 """\ 663 Returns a new object created with the given data (a dictionary 664 mapping field names to values). Merges any extra keyword args 665 into data. 666 """ 667 data = dict(data) 668 data.update(kwargs) 669 data = cls.prepare_data_args(data) 670 cls.convert_human_readable_values(data) 671 data = cls.provide_default_values(data) 672 673 obj = cls(**data) 674 obj.do_validate() 675 obj.save() 676 return obj 677 678 679 def update_object(self, data={}, **kwargs): 680 """\ 681 Updates the object with the given data (a dictionary mapping 682 field names to values). Merges any extra keyword args into 683 data. 684 """ 685 data = dict(data) 686 data.update(kwargs) 687 data = self.prepare_data_args(data) 688 self.convert_human_readable_values(data) 689 for field_name, value in data.iteritems(): 690 setattr(self, field_name, value) 691 self.do_validate() 692 self.save() 693 694 695 # see query_objects() 696 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by', 697 'extra_args', 'extra_where', 'no_distinct') 698 699 700 @classmethod 701 def _extract_special_params(cls, filter_data): 702 """ 703 @returns a tuple of dicts (special_params, regular_filters), where 704 special_params contains the parameters we handle specially and 705 regular_filters is the remaining data to be handled by Django. 706 """ 707 regular_filters = dict(filter_data) 708 special_params = {} 709 for key in cls._SPECIAL_FILTER_KEYS: 710 if key in regular_filters: 711 special_params[key] = regular_filters.pop(key) 712 return special_params, regular_filters 713 714 715 @classmethod 716 def apply_presentation(cls, query, filter_data): 717 """ 718 Apply presentation parameters -- sorting and paging -- to the given 719 query. 720 @returns new query with presentation applied 721 """ 722 special_params, _ = cls._extract_special_params(filter_data) 723 sort_by = special_params.get('sort_by', None) 724 if sort_by: 725 assert isinstance(sort_by, list) or isinstance(sort_by, tuple) 726 query = query.extra(order_by=sort_by) 727 728 query_start = special_params.get('query_start', None) 729 query_limit = special_params.get('query_limit', None) 730 if query_start is not None: 731 if query_limit is None: 732 raise ValueError('Cannot pass query_start without query_limit') 733 # query_limit is passed as a page size 734 query_limit += query_start 735 return query[query_start:query_limit] 736 737 738 @classmethod 739 def query_objects(cls, filter_data, valid_only=True, initial_query=None, 740 apply_presentation=True): 741 """\ 742 Returns a QuerySet object for querying the given model_class 743 with the given filter_data. Optional special arguments in 744 filter_data include: 745 -query_start: index of first return to return 746 -query_limit: maximum number of results to return 747 -sort_by: list of fields to sort on. prefixing a '-' onto a 748 field name changes the sort to descending order. 749 -extra_args: keyword args to pass to query.extra() (see Django 750 DB layer documentation) 751 -extra_where: extra WHERE clause to append 752 -no_distinct: if True, a DISTINCT will not be added to the SELECT 753 """ 754 special_params, regular_filters = cls._extract_special_params( 755 filter_data) 756 757 if initial_query is None: 758 if valid_only: 759 initial_query = cls.get_valid_manager() 760 else: 761 initial_query = cls.objects 762 763 query = initial_query.filter(**regular_filters) 764 765 use_distinct = not special_params.get('no_distinct', False) 766 if use_distinct: 767 query = query.distinct() 768 769 extra_args = special_params.get('extra_args', {}) 770 extra_where = special_params.get('extra_where', None) 771 if extra_where: 772 # escape %'s 773 extra_where = cls.objects.escape_user_sql(extra_where) 774 extra_args.setdefault('where', []).append(extra_where) 775 if extra_args: 776 query = query.extra(**extra_args) 777 # TODO: Use readonly connection for these queries. 778 # This has been disabled, because it's not used anyway, as the 779 # configured readonly user is the same as the real user anyway. 780 781 if apply_presentation: 782 query = cls.apply_presentation(query, filter_data) 783 784 return query 785 786 787 @classmethod 788 def query_count(cls, filter_data, initial_query=None): 789 """\ 790 Like query_objects, but retreive only the count of results. 791 """ 792 filter_data.pop('query_start', None) 793 filter_data.pop('query_limit', None) 794 query = cls.query_objects(filter_data, initial_query=initial_query) 795 return query.count() 796 797 798 @classmethod 799 def clean_object_dicts(cls, field_dicts): 800 """\ 801 Take a list of dicts corresponding to object (as returned by 802 query.values()) and clean the data to be more suitable for 803 returning to the user. 804 """ 805 for field_dict in field_dicts: 806 cls.clean_foreign_keys(field_dict) 807 cls._convert_booleans(field_dict) 808 cls.convert_human_readable_values(field_dict, 809 to_human_readable=True) 810 811 812 @classmethod 813 def list_objects(cls, filter_data, initial_query=None): 814 """\ 815 Like query_objects, but return a list of dictionaries. 816 """ 817 query = cls.query_objects(filter_data, initial_query=initial_query) 818 extra_fields = query.query.extra_select.keys() 819 field_dicts = [model_object.get_object_dict(extra_fields=extra_fields) 820 for model_object in query] 821 return field_dicts 822 823 824 @classmethod 825 def smart_get(cls, id_or_name, valid_only=True): 826 """\ 827 smart_get(integer) -> get object by ID 828 smart_get(string) -> get object by name_field 829 """ 830 if valid_only: 831 manager = cls.get_valid_manager() 832 else: 833 manager = cls.objects 834 835 if isinstance(id_or_name, (int, long)): 836 return manager.get(pk=id_or_name) 837 if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'): 838 return manager.get(**{cls.name_field : id_or_name}) 839 raise ValueError( 840 'Invalid positional argument: %s (%s)' % (id_or_name, 841 type(id_or_name))) 842 843 844 @classmethod 845 def smart_get_bulk(cls, id_or_name_list): 846 invalid_inputs = [] 847 result_objects = [] 848 for id_or_name in id_or_name_list: 849 try: 850 result_objects.append(cls.smart_get(id_or_name)) 851 except cls.DoesNotExist: 852 invalid_inputs.append(id_or_name) 853 if invalid_inputs: 854 raise cls.DoesNotExist('The following %ss do not exist: %s' 855 % (cls.__name__.lower(), 856 ', '.join(invalid_inputs))) 857 return result_objects 858 859 860 def get_object_dict(self, extra_fields=None): 861 """\ 862 Return a dictionary mapping fields to this object's values. @param 863 extra_fields: list of extra attribute names to include, in addition to 864 the fields defined on this object. 865 """ 866 fields = self.get_field_dict().keys() 867 if extra_fields: 868 fields += extra_fields 869 object_dict = dict((field_name, getattr(self, field_name)) 870 for field_name in fields) 871 self.clean_object_dicts([object_dict]) 872 self._postprocess_object_dict(object_dict) 873 return object_dict 874 875 876 def _postprocess_object_dict(self, object_dict): 877 """For subclasses to override.""" 878 pass 879 880 881 @classmethod 882 def get_valid_manager(cls): 883 return cls.objects 884 885 886 def _record_attributes(self, attributes): 887 """ 888 See on_attribute_changed. 889 """ 890 assert not isinstance(attributes, basestring) 891 self._recorded_attributes = dict((attribute, getattr(self, attribute)) 892 for attribute in attributes) 893 894 895 def _check_for_updated_attributes(self): 896 """ 897 See on_attribute_changed. 898 """ 899 for attribute, original_value in self._recorded_attributes.iteritems(): 900 new_value = getattr(self, attribute) 901 if original_value != new_value: 902 self.on_attribute_changed(attribute, original_value) 903 self._record_attributes(self._recorded_attributes.keys()) 904 905 906 def on_attribute_changed(self, attribute, old_value): 907 """ 908 Called whenever an attribute is updated. To be overridden. 909 910 To use this method, you must: 911 * call _record_attributes() from __init__() (after making the super 912 call) with a list of attributes for which you want to be notified upon 913 change. 914 * call _check_for_updated_attributes() from save(). 915 """ 916 pass 917 918 919 def serialize(self, include_dependencies=True): 920 """Serializes the object with dependencies. 921 922 The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies 923 this function will serialize with the object. 924 925 @param include_dependencies: Whether or not to follow relations to 926 objects this object depends on. 927 This parameter is used when uploading 928 jobs from a shard to the master, as the 929 master already has all the dependent 930 objects. 931 932 @returns: Dictionary representation of the object. 933 """ 934 serialized = {} 935 timer = autotest_stats.Timer('serialize_latency.%s' % ( 936 type(self).__name__)) 937 with timer.get_client('local'): 938 for field in self._meta.concrete_model._meta.local_fields: 939 if field.rel is None: 940 serialized[field.name] = field._get_val_from_obj(self) 941 elif field.name in self.SERIALIZATION_LINKS_TO_KEEP: 942 # attname will contain "_id" suffix for foreign keys, 943 # e.g. HostAttribute.host will be serialized as 'host_id'. 944 # Use it for easy deserialization. 945 serialized[field.attname] = field._get_val_from_obj(self) 946 947 if include_dependencies: 948 with timer.get_client('related'): 949 for link in self.SERIALIZATION_LINKS_TO_FOLLOW: 950 serialized[link] = self._serialize_relation(link) 951 952 return serialized 953 954 955 def _serialize_relation(self, link): 956 """Serializes dependent objects given the name of the relation. 957 958 @param link: Name of the relation to take objects from. 959 960 @returns For To-Many relationships a list of the serialized related 961 objects, for To-One relationships the serialized related object. 962 """ 963 try: 964 attr = getattr(self, link) 965 except AttributeError: 966 # One-To-One relationships that point to None may raise this 967 return None 968 969 if attr is None: 970 return None 971 if hasattr(attr, 'all'): 972 return [obj.serialize() for obj in attr.all()] 973 return attr.serialize() 974 975 976 @classmethod 977 def _split_local_from_foreign_values(cls, data): 978 """This splits local from foreign values in a serialized object. 979 980 @param data: The serialized object. 981 982 @returns A tuple of two lists, both containing tuples in the form 983 (link_name, link_value). The first list contains all links 984 for local fields, the second one contains those for foreign 985 fields/objects. 986 """ 987 links_to_local_values, links_to_related_values = [], [] 988 for link, value in data.iteritems(): 989 if link in cls.SERIALIZATION_LINKS_TO_FOLLOW: 990 # It's a foreign key 991 links_to_related_values.append((link, value)) 992 else: 993 # It's a local attribute or a foreign key 994 # we don't want to follow. 995 links_to_local_values.append((link, value)) 996 return links_to_local_values, links_to_related_values 997 998 999 @classmethod 1000 def _filter_update_allowed_fields(cls, data): 1001 """Filters data and returns only files that updates are allowed on. 1002 1003 This is i.e. needed for syncing aborted bits from the master to shards. 1004 1005 Local links are only allowed to be updated, if they are in 1006 SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1007 Overwriting existing values is allowed in order to be able to sync i.e. 1008 the aborted bit from the master to a shard. 1009 1010 The whitelisting mechanism is in place to prevent overwriting local 1011 status: If all fields were overwritten, jobs would be completely be 1012 set back to their original (unstarted) state. 1013 1014 @param data: List with tuples of the form (link_name, link_value), as 1015 returned by _split_local_from_foreign_values. 1016 1017 @returns List of the same format as data, but only containing data for 1018 fields that updates are allowed on. 1019 """ 1020 return [pair for pair in data 1021 if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE] 1022 1023 1024 @classmethod 1025 def delete_matching_record(cls, **filter_args): 1026 """Delete records matching the filter. 1027 1028 @param filter_args: Arguments for the django filter 1029 used to locate the record to delete. 1030 """ 1031 try: 1032 existing_record = cls.objects.get(**filter_args) 1033 except cls.DoesNotExist: 1034 return 1035 existing_record.delete() 1036 1037 1038 def _deserialize_local(self, data): 1039 """Set local attributes from a list of tuples. 1040 1041 @param data: List of tuples like returned by 1042 _split_local_from_foreign_values. 1043 """ 1044 if not data: 1045 return 1046 1047 for link, value in data: 1048 setattr(self, link, value) 1049 # Overwridden save() methods are prone to errors, so don't execute them. 1050 # This is because: 1051 # - the overwritten methods depend on ACL groups that don't yet exist 1052 # and don't handle errors 1053 # - the overwritten methods think this object already exists in the db 1054 # because the id is already set 1055 super(type(self), self).save() 1056 1057 1058 def _deserialize_relations(self, data): 1059 """Set foreign attributes from a list of tuples. 1060 1061 This deserialized the related objects using their own deserialize() 1062 function and then sets the relation. 1063 1064 @param data: List of tuples like returned by 1065 _split_local_from_foreign_values. 1066 """ 1067 for link, value in data: 1068 self._deserialize_relation(link, value) 1069 # See comment in _deserialize_local 1070 super(type(self), self).save() 1071 1072 1073 @classmethod 1074 def get_record(cls, data): 1075 """Retrieve a record with the data in the given input arg. 1076 1077 @param data: A dictionary containing the information to use in a query 1078 for data. If child models have different constraints of 1079 uniqueness they should override this model. 1080 1081 @return: An object with matching data. 1082 1083 @raises DoesNotExist: If a record with the given data doesn't exist. 1084 """ 1085 return cls.objects.get(id=data['id']) 1086 1087 1088 @classmethod 1089 def deserialize(cls, data): 1090 """Recursively deserializes and saves an object with it's dependencies. 1091 1092 This takes the result of the serialize method and creates objects 1093 in the database that are just like the original. 1094 1095 If an object of the same type with the same id already exists, it's 1096 local values will be left untouched, unless they are explicitly 1097 whitelisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1098 1099 Deserialize will always recursively propagate to all related objects 1100 present in data though. 1101 I.e. this is necessary to add users to an already existing acl-group. 1102 1103 @param data: Representation of an object and its dependencies, as 1104 returned by serialize. 1105 1106 @returns: The object represented by data if it didn't exist before, 1107 otherwise the object that existed before and has the same type 1108 and id as the one described by data. 1109 """ 1110 if data is None: 1111 return None 1112 1113 local, related = cls._split_local_from_foreign_values(data) 1114 try: 1115 instance = cls.get_record(data) 1116 local = cls._filter_update_allowed_fields(local) 1117 except cls.DoesNotExist: 1118 instance = cls() 1119 1120 timer = autotest_stats.Timer('deserialize_latency.%s' % ( 1121 type(instance).__name__)) 1122 with timer.get_client('local'): 1123 instance._deserialize_local(local) 1124 with timer.get_client('related'): 1125 instance._deserialize_relations(related) 1126 1127 return instance 1128 1129 1130 def sanity_check_update_from_shard(self, shard, updated_serialized, 1131 *args, **kwargs): 1132 """Check if an update sent from a shard is legitimate. 1133 1134 @raises error.UnallowedRecordsSentToMaster if an update is not 1135 legitimate. 1136 """ 1137 raise NotImplementedError( 1138 'sanity_check_update_from_shard must be implemented by subclass %s ' 1139 'for type %s' % type(self)) 1140 1141 1142 @transaction.commit_on_success 1143 def update_from_serialized(self, serialized): 1144 """Updates local fields of an existing object from a serialized form. 1145 1146 This is different than the normal deserialize() in the way that it 1147 does update local values, which deserialize doesn't, but doesn't 1148 recursively propagate to related objects, which deserialize() does. 1149 1150 The use case of this function is to update job records on the master 1151 after the jobs have been executed on a slave, as the master is not 1152 interested in updates for users, labels, specialtasks, etc. 1153 1154 @param serialized: Representation of an object and its dependencies, as 1155 returned by serialize. 1156 1157 @raises ValueError: if serialized contains related objects, i.e. not 1158 only local fields. 1159 """ 1160 local, related = ( 1161 self._split_local_from_foreign_values(serialized)) 1162 if related: 1163 raise ValueError('Serialized must not contain foreign ' 1164 'objects: %s' % related) 1165 1166 self._deserialize_local(local) 1167 1168 1169 def custom_deserialize_relation(self, link, data): 1170 """Allows overriding the deserialization behaviour by subclasses.""" 1171 raise NotImplementedError( 1172 'custom_deserialize_relation must be implemented by subclass %s ' 1173 'for relation %s' % (type(self), link)) 1174 1175 1176 def _deserialize_relation(self, link, data): 1177 """Deserializes related objects and sets references on this object. 1178 1179 Relations that point to a list of objects are handled automatically. 1180 For many-to-one or one-to-one relations custom_deserialize_relation 1181 must be overridden by the subclass. 1182 1183 Related objects are deserialized using their deserialize() method. 1184 Thereby they and their dependencies are created if they don't exist 1185 and saved to the database. 1186 1187 @param link: Name of the relation. 1188 @param data: Serialized representation of the related object(s). 1189 This means a list of dictionaries for to-many relations, 1190 just a dictionary for to-one relations. 1191 """ 1192 field = getattr(self, link) 1193 1194 if field and hasattr(field, 'all'): 1195 self._deserialize_2m_relation(link, data, field.model) 1196 else: 1197 self.custom_deserialize_relation(link, data) 1198 1199 1200 def _deserialize_2m_relation(self, link, data, related_class): 1201 """Deserialize related objects for one to-many relationship. 1202 1203 @param link: Name of the relation. 1204 @param data: Serialized representation of the related objects. 1205 This is a list with of dictionaries. 1206 @param related_class: A class representing a django model, with which 1207 this class has a one-to-many relationship. 1208 """ 1209 relation_set = getattr(self, link) 1210 if related_class == self.get_attribute_model(): 1211 # When deserializing a model together with 1212 # its attributes, clear all the exising attributes to ensure 1213 # db consistency. Note 'update' won't be sufficient, as we also 1214 # want to remove any attributes that no longer exist in |data|. 1215 # 1216 # core_filters is a dictionary of filters, defines how 1217 # RelatedMangager would query for the 1-to-many relationship. E.g. 1218 # Host.objects.get( 1219 # id=20).hostattribute_set.core_filters = {host_id:20} 1220 # We use it to delete objects related to the current object. 1221 related_class.objects.filter(**relation_set.core_filters).delete() 1222 for serialized in data: 1223 relation_set.add(related_class.deserialize(serialized)) 1224 1225 1226 @classmethod 1227 def get_attribute_model(cls): 1228 """Return the attribute model. 1229 1230 Subclass with attribute-like model should override this to 1231 return the attribute model class. This method will be 1232 called by _deserialize_2m_relation to determine whether 1233 to clear the one-to-many relations first on deserialization of object. 1234 """ 1235 return None 1236 1237 1238class ModelWithInvalid(ModelExtensions): 1239 """ 1240 Overrides model methods save() and delete() to support invalidation in 1241 place of actual deletion. Subclasses must have a boolean "invalid" 1242 field. 1243 """ 1244 1245 def save(self, *args, **kwargs): 1246 first_time = (self.id is None) 1247 if first_time: 1248 # see if this object was previously added and invalidated 1249 my_name = getattr(self, self.name_field) 1250 filters = {self.name_field : my_name, 'invalid' : True} 1251 try: 1252 old_object = self.__class__.objects.get(**filters) 1253 self.resurrect_object(old_object) 1254 except self.DoesNotExist: 1255 # no existing object 1256 pass 1257 1258 super(ModelWithInvalid, self).save(*args, **kwargs) 1259 1260 1261 def resurrect_object(self, old_object): 1262 """ 1263 Called when self is about to be saved for the first time and is actually 1264 "undeleting" a previously deleted object. Can be overridden by 1265 subclasses to copy data as desired from the deleted entry (but this 1266 superclass implementation must normally be called). 1267 """ 1268 self.id = old_object.id 1269 1270 1271 def clean_object(self): 1272 """ 1273 This method is called when an object is marked invalid. 1274 Subclasses should override this to clean up relationships that 1275 should no longer exist if the object were deleted. 1276 """ 1277 pass 1278 1279 1280 def delete(self): 1281 self.invalid = self.invalid 1282 assert not self.invalid 1283 self.invalid = True 1284 self.save() 1285 self.clean_object() 1286 1287 1288 @classmethod 1289 def get_valid_manager(cls): 1290 return cls.valid_objects 1291 1292 1293 class Manipulator(object): 1294 """ 1295 Force default manipulators to look only at valid objects - 1296 otherwise they will match against invalid objects when checking 1297 uniqueness. 1298 """ 1299 @classmethod 1300 def _prepare(cls, model): 1301 super(ModelWithInvalid.Manipulator, cls)._prepare(model) 1302 cls.manager = model.valid_objects 1303 1304 1305class ModelWithAttributes(object): 1306 """ 1307 Mixin class for models that have an attribute model associated with them. 1308 The attribute model is assumed to have its value field named "value". 1309 """ 1310 1311 def _get_attribute_model_and_args(self, attribute): 1312 """ 1313 Subclasses should override this to return a tuple (attribute_model, 1314 keyword_args), where attribute_model is a model class and keyword_args 1315 is a dict of args to pass to attribute_model.objects.get() to get an 1316 instance of the given attribute on this object. 1317 """ 1318 raise NotImplementedError 1319 1320 1321 def set_attribute(self, attribute, value): 1322 attribute_model, get_args = self._get_attribute_model_and_args( 1323 attribute) 1324 attribute_object, _ = attribute_model.objects.get_or_create(**get_args) 1325 attribute_object.value = value 1326 attribute_object.save() 1327 1328 1329 def delete_attribute(self, attribute): 1330 attribute_model, get_args = self._get_attribute_model_and_args( 1331 attribute) 1332 try: 1333 attribute_model.objects.get(**get_args).delete() 1334 except attribute_model.DoesNotExist: 1335 pass 1336 1337 1338 def set_or_delete_attribute(self, attribute, value): 1339 if value is None: 1340 self.delete_attribute(attribute) 1341 else: 1342 self.set_attribute(attribute, value) 1343 1344 1345class ModelWithHashManager(dbmodels.Manager): 1346 """Manager for use with the ModelWithHash abstract model class""" 1347 1348 def create(self, **kwargs): 1349 raise Exception('ModelWithHash manager should use get_or_create() ' 1350 'instead of create()') 1351 1352 1353 def get_or_create(self, **kwargs): 1354 kwargs['the_hash'] = self.model._compute_hash(**kwargs) 1355 return super(ModelWithHashManager, self).get_or_create(**kwargs) 1356 1357 1358class ModelWithHash(dbmodels.Model): 1359 """Superclass with methods for dealing with a hash column""" 1360 1361 the_hash = dbmodels.CharField(max_length=40, unique=True) 1362 1363 objects = ModelWithHashManager() 1364 1365 class Meta: 1366 abstract = True 1367 1368 1369 @classmethod 1370 def _compute_hash(cls, **kwargs): 1371 raise NotImplementedError('Subclasses must override _compute_hash()') 1372 1373 1374 def save(self, force_insert=False, **kwargs): 1375 """Prevents saving the model in most cases 1376 1377 We want these models to be immutable, so the generic save() operation 1378 will not work. These models should be instantiated through their the 1379 model.objects.get_or_create() method instead. 1380 1381 The exception is that save(force_insert=True) will be allowed, since 1382 that creates a new row. However, the preferred way to make instances of 1383 these models is through the get_or_create() method. 1384 """ 1385 if not force_insert: 1386 # Allow a forced insert to happen; if it's a duplicate, the unique 1387 # constraint will catch it later anyways 1388 raise Exception('ModelWithHash is immutable') 1389 super(ModelWithHash, self).save(force_insert=force_insert, **kwargs) 1390