|
|
@ -0,0 +1,167 @@ |
|
|
|
""" |
|
|
|
Provide a "Generic ForeignKey", similar to Django. A "GFK" is composed of two |
|
|
|
columns: an object ID and an object type identifier. The object types are |
|
|
|
collected in a global registry (all_models), so all you need to do is subclass |
|
|
|
``gfk.Model`` and your model will be added to the registry. |
|
|
|
|
|
|
|
Example: |
|
|
|
|
|
|
|
class Tag(Model): |
|
|
|
tag = CharField() |
|
|
|
object_type = CharField(null=True) |
|
|
|
object_id = IntegerField(null=True) |
|
|
|
object = GFKField('object_type', 'object_id') |
|
|
|
|
|
|
|
class Blog(Model): |
|
|
|
tags = ReverseGFK(Tag, 'object_type', 'object_id') |
|
|
|
|
|
|
|
class Photo(Model): |
|
|
|
tags = ReverseGFK(Tag, 'object_type', 'object_id') |
|
|
|
|
|
|
|
tag.object -> a blog or photo |
|
|
|
blog.tags -> select query of tags for ``blog`` instance |
|
|
|
Blog.tags -> select query of all tags for Blog instances |
|
|
|
""" |
|
|
|
|
|
|
|
from peewee import * |
|
|
|
from peewee import BaseModel as _BaseModel |
|
|
|
from peewee import Model as _Model |
|
|
|
from peewee import SelectQuery |
|
|
|
from peewee import UpdateQuery |
|
|
|
from peewee import with_metaclass |
|
|
|
|
|
|
|
|
|
|
|
all_models = set() |
|
|
|
table_cache = {} |
|
|
|
|
|
|
|
|
|
|
|
class BaseModel(_BaseModel): |
|
|
|
def __new__(cls, name, bases, attrs): |
|
|
|
cls = super(BaseModel, cls).__new__(cls, name, bases, attrs) |
|
|
|
if name not in ('_metaclass_helper_', 'Model'): |
|
|
|
all_models.add(cls) |
|
|
|
return cls |
|
|
|
|
|
|
|
class Model(with_metaclass(BaseModel, _Model)): |
|
|
|
pass |
|
|
|
|
|
|
|
def get_model(tbl_name): |
|
|
|
if tbl_name not in table_cache: |
|
|
|
for model in all_models: |
|
|
|
if model._meta.db_table == tbl_name: |
|
|
|
table_cache[tbl_name] = model |
|
|
|
break |
|
|
|
return table_cache.get(tbl_name) |
|
|
|
|
|
|
|
class BoundGFKField(object): |
|
|
|
__slots__ = ('model_class', 'gfk_field') |
|
|
|
|
|
|
|
def __init__(self, model_class, gfk_field): |
|
|
|
self.model_class = model_class |
|
|
|
self.gfk_field = gfk_field |
|
|
|
|
|
|
|
@property |
|
|
|
def unique(self): |
|
|
|
indexes = self.model_class._meta.indexes |
|
|
|
fields = set((self.gfk_field.model_type_field, |
|
|
|
self.gfk_field.model_id_field)) |
|
|
|
for (indexed_columns, is_unique) in indexes: |
|
|
|
if not fields - set(indexed_columns): |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
@property |
|
|
|
def primary_key(self): |
|
|
|
pk = self.model_class._meta.primary_key |
|
|
|
if isinstance(pk, CompositeKey): |
|
|
|
fields = set((self.gfk_field.model_type_field, |
|
|
|
self.gfk_field.model_id_field)) |
|
|
|
if not fields - set(pk.field_names): |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
def __eq__(self, other): |
|
|
|
meta = self.model_class._meta |
|
|
|
type_field = meta.fields[self.gfk_field.model_type_field] |
|
|
|
id_field = meta.fields[self.gfk_field.model_id_field] |
|
|
|
return ( |
|
|
|
(type_field == other._meta.db_table) & |
|
|
|
(id_field == other._get_pk_value())) |
|
|
|
|
|
|
|
def __ne__(self, other): |
|
|
|
other_cls = type(other) |
|
|
|
type_field = other._meta.fields[self.gfk_field.model_type_field] |
|
|
|
id_field = other._meta.fields[self.gfk_field.model_id_field] |
|
|
|
return ( |
|
|
|
(type_field == other._meta.db_table) & |
|
|
|
(id_field != other._get_pk_value())) |
|
|
|
|
|
|
|
|
|
|
|
class GFKField(object): |
|
|
|
def __init__(self, model_type_field='object_type', |
|
|
|
model_id_field='object_id'): |
|
|
|
self.model_type_field = model_type_field |
|
|
|
self.model_id_field = model_id_field |
|
|
|
self.att_name = '.'.join((self.model_type_field, self.model_id_field)) |
|
|
|
|
|
|
|
def get_obj(self, instance): |
|
|
|
data = instance._data |
|
|
|
if data.get(self.model_type_field) and data.get(self.model_id_field): |
|
|
|
tbl_name = data[self.model_type_field] |
|
|
|
model_class = get_model(tbl_name) |
|
|
|
if not model_class: |
|
|
|
raise AttributeError('Model for table "%s" not found in GFK ' |
|
|
|
'lookup.' % tbl_name) |
|
|
|
query = model_class.select().where( |
|
|
|
model_class._meta.primary_key == data[self.model_id_field]) |
|
|
|
return query.get() |
|
|
|
|
|
|
|
def __get__(self, instance, instance_type=None): |
|
|
|
if instance: |
|
|
|
if self.att_name not in instance._obj_cache: |
|
|
|
rel_obj = self.get_obj(instance) |
|
|
|
if rel_obj: |
|
|
|
instance._obj_cache[self.att_name] = rel_obj |
|
|
|
return instance._obj_cache.get(self.att_name) |
|
|
|
return BoundGFKField(instance_type, self) |
|
|
|
|
|
|
|
def __set__(self, instance, value): |
|
|
|
instance._obj_cache[self.att_name] = value |
|
|
|
instance._data[self.model_type_field] = value._meta.db_table |
|
|
|
instance._data[self.model_id_field] = value._get_pk_value() |
|
|
|
|
|
|
|
|
|
|
|
class ReverseGFK(object): |
|
|
|
def __init__(self, model, model_type_field='object_type', |
|
|
|
model_id_field='object_id'): |
|
|
|
self.model_class = model |
|
|
|
self.model_type_field = model._meta.fields[model_type_field] |
|
|
|
self.model_id_field = model._meta.fields[model_id_field] |
|
|
|
|
|
|
|
def __get__(self, instance, instance_type=None): |
|
|
|
if instance: |
|
|
|
return self.model_class.select().where( |
|
|
|
(self.model_type_field == instance._meta.db_table) & |
|
|
|
(self.model_id_field == instance._get_pk_value()) |
|
|
|
) |
|
|
|
else: |
|
|
|
return self.model_class.select().where( |
|
|
|
self.model_type_field == instance_type._meta.db_table |
|
|
|
) |
|
|
|
|
|
|
|
def __set__(self, instance, value): |
|
|
|
mtv = instance._meta.db_table |
|
|
|
miv = instance._get_pk_value() |
|
|
|
if (isinstance(value, SelectQuery) and |
|
|
|
value.model_class == self.model_class): |
|
|
|
UpdateQuery(self.model_class, { |
|
|
|
self.model_type_field: mtv, |
|
|
|
self.model_id_field: miv, |
|
|
|
}).where(value._where).execute() |
|
|
|
elif all(map(lambda i: isinstance(i, self.model_class), value)): |
|
|
|
for obj in value: |
|
|
|
setattr(obj, self.model_type_field.name, mtv) |
|
|
|
setattr(obj, self.model_id_field.name, miv) |
|
|
|
obj.save() |
|
|
|
else: |
|
|
|
raise ValueError('ReverseGFK field unable to handle "%s"' % value) |