import datetime
from dataclasses import dataclass, fields
from typing import Optional, List, get_origin

from .DataClassJson import DataClassJson
from modules.shared.DatabaseAbstraction import Cursor

# Maps a dataclass field annotation to the SQL column type used for it.
# Annotations missing from this map fall back to a caller-chosen default.
types = {
    bool: 'INTEGER',
    int: 'INTEGER',
    float: 'REAL',
    str: "TEXT",
    Optional[bool]: 'INTEGER',
    Optional[int]: 'INTEGER',
    Optional[float]: 'REAL',
    Optional[str]: "TEXT",
}

# Field annotations that are stored in a separate "<table>_<field>" side table
# (one row per list element) instead of a column of the main table.
# Shared by schema creation, loading, saving and comparison so they all agree.
_LIST_TYPES = frozenset({list, List, Optional[list], Optional[List]})


@dataclass
class DataClassDatabase(DataClassJson):
    """Dataclass base that can create its own tables and load/save itself.

    Storage layout, as produced by ``get_create_sqls``:

    * ``"<table>"`` -- one row per entity, keyed by ``(pk, fk)``, with one
      column per scalar field.
    * ``"<table>_archive"`` -- previous versions of a row, keyed by
      ``(pk, fk, save_date)``.
    * ``"<table>_<field>"`` -- one ``(fk, data)`` row per element of a list
      field.
    * Fields named in ``_forwarding`` are stored through the class they map
      to, in that class's own tables, with ``fk`` set to the parent's ``pk``.

    Subclasses are expected to set ``_table_name``, ``_forwarding`` and, where
    needed, ``_key_field`` / ``_standalone_entity`` in ``__post_init__``.
    ``_key_field``, ``_forwarding``, ``key``, ``from_dict`` and
    ``serializable_fields`` are not defined here; presumably they come from
    ``DataClassJson`` -- verify against that class.
    """

    # True for top-level entities, which are always stored with fk = 0.
    _standalone_entity: bool = None
    # Name of the main table; also the prefix of the archive and list tables.
    _table_name: str = None

    def __post_init__(self):
        super().__post_init__()

    @classmethod
    def get_create_sqls(cls, table_name = None):
        """Return the SQL statements that create the schema for this class.

        The list contains the ``CREATE TABLE`` statements for the main and
        archive tables, then, per public field: the statements of the
        forwarded class, a ``CREATE TABLE`` for a list side table, or a pair
        of ``ALTER TABLE ... ADD COLUMN`` statements (main + archive).

        :param table_name: overrides the instance's ``_table_name`` if truthy.
        :raises RuntimeError: if a ``_forwarding`` entry cannot produce its
            own create statements.
        """
        # A throwaway instance is needed because the table name, key field and
        # forwarding map are assigned per instance in __post_init__.
        tmp_instance = cls()
        if not table_name:
            table_name = tmp_instance._table_name

        pk_type = str
        for field in fields(tmp_instance):
            if field.name == tmp_instance._key_field:
                pk_type = field.type
        # Resolved outside the f-strings: nesting the same quote character
        # inside an f-string only parses on Python 3.12+.
        pk_sql_type = types.get(pk_type, 'INTEGER')

        result: list[str] = [
            f'CREATE TABLE IF NOT EXISTS "{table_name}" (fk INTEGER NOT NULL, pk {pk_sql_type} NOT NULL, PRIMARY KEY(pk, fk));',
            f'CREATE TABLE IF NOT EXISTS "{table_name}_archive" (fk INTEGER NOT NULL, pk {pk_sql_type} NOT NULL, save_date TEXT NOT NULL, PRIMARY KEY(pk, fk, save_date));',
        ]

        # Skip the bookkeeping fields declared on this base class and anything private.
        excluded_fields = {f.name for f in fields(DataClassDatabase)}
        all_fields = [f for f in fields(cls) if f.name not in excluded_fields and not f.name.startswith('_')]

        for field in all_fields:
            if field.name in tmp_instance._forwarding:
                inner_type: type = tmp_instance._forwarding[field.name]
                try:
                    result.extend(inner_type.get_create_sqls())
                except Exception as e:
                    raise RuntimeError('invalid forwarding type') from e
            elif field.type in _LIST_TYPES:
                result.append(f'CREATE TABLE IF NOT EXISTS "{table_name}_{field.name}" (fk TEXT NOT NULL, data TEXT NOT NULL, PRIMARY KEY(data, fk));')
            else:
                column_type = types.get(field.type, 'TEXT')
                result.append(f'ALTER TABLE "{table_name}" ADD COLUMN "{field.name}" {column_type};')
                result.append(f'ALTER TABLE "{table_name}_archive" ADD COLUMN "{field.name}" {column_type};')
        return result

    @classmethod
    def create(cls, cur: Cursor):
        """Execute every statement from ``get_create_sqls`` on ``cur``.

        Failures are printed and otherwise ignored, so one failing statement
        does not stop the remaining ones.
        """
        for sql in cls.get_create_sqls():
            try:
                cur.execute(sql)
            except Exception as e:
                # NOTE(review): presumably deliberate best-effort -- re-running
                # ADD COLUMN for an already existing column fails; confirm
                # before narrowing this handler.
                print(e)

    @classmethod
    def load(cls, cur: Cursor, pk=None, fk=None, depth = 5):
        """Load all entities matching ``pk`` and/or ``fk``.

        Filters are applied by truthiness: a falsy ``pk``/``fk`` (including 0)
        is treated as "not given". With neither given an empty list is
        returned, so standalone rows (stored with fk = 0) can only be selected
        by ``pk``.

        :param depth: how many levels of forwarded fields to resolve.
        :return: list of loaded instances (possibly empty).
        """
        if not pk and not fk:
            return list()

        instance = cls()
        conditions = []
        params = []
        if pk:
            conditions.append('pk = ?')
            params.append(pk)
        if fk:
            conditions.append('fk = ?')
            params.append(fk)
        sql = f'SELECT pk, fk FROM "{instance._table_name}" WHERE ' + ' AND '.join(conditions)

        res: list[dict] = cur.fetchall(sql, params)

        results = list()
        for r in res:
            item = cls._load(cur, r.get('pk', None), r.get('fk', None), depth)
            if item:
                results.append(item)
        return results

    @classmethod
    def _load(cls, cur: Cursor, pk, fk, depth = 5):
        """Load the single entity stored under exactly ``(pk, fk)``.

        :param depth: 0 returns only the scalar columns; otherwise forwarded
            fields are loaded with ``depth - 1`` and list fields are read from
            their side tables.
        :return: the instance, or ``None`` if both keys are falsy or no row
            exists.
        """
        if not pk and not fk:
            return None

        instance = cls()
        res: dict = cur.fetchone(f'SELECT * FROM "{instance._table_name}" WHERE pk = ? AND fk = ?', [pk, fk])
        if not res:
            return None

        # Strip the key columns so only field columns reach from_dict.
        rpk = res.pop('pk')
        res.pop('fk')
        result = cls.from_dict(res)
        if depth == 0:
            return result

        for field in fields(cls):
            if field.name in instance._forwarding:
                # Children reference their parent through fk = parent's pk.
                items = instance._forwarding[field.name].load(cur, fk=rpk, depth=depth - 1)
                if len(items) > 1:
                    setattr(result, field.name, items)  # TODO: remove this hack
                elif len(items) > 0:
                    setattr(result, field.name, items[0])
            elif field.type in _LIST_TYPES:
                items = cur.fetchall(f'SELECT data from "{instance._table_name}_{field.name}" WHERE fk=?', [rpk])
                if items:
                    items = [row['data'] for row in items]
                else:
                    items = list()
                setattr(result, field.name, items)
        return result

    def save(self, cur: Cursor, fk = None):
        """Persist this entity, archiving the previous version if it changed.

        Steps: merge missing values from the stored row into ``self``; if the
        merged scalar fields differ from the stored row, copy the stored row
        into the archive table under the current timestamp; then insert/update
        the main row, list side tables and forwarded children.

        :param fk: key of the parent entity; ignored (forced to 0) for
            standalone entities.
        :raises RuntimeError: if a non-standalone entity is saved without a
            truthy ``fk``.
        """
        if self._standalone_entity:
            fk = 0
        elif not fk:
            raise RuntimeError('Trying to save child entity as standalone')

        pk = self.key if self._key_field != 'key' else 0
        prev = self._load(cur, pk=pk, fk=fk, depth=0)

        if prev:
            # NOTE(review): the merge is truthiness-based, so False / 0 / ''
            # on self are replaced by the stored value -- confirm this is
            # intended.
            for field in self.serializable_fields():
                setattr(self, field.name, getattr(self, field.name) or getattr(prev, field.name))

        if prev and not self.equals_simple(prev):
            d = str(datetime.datetime.now())
            cur.execute(f'INSERT OR IGNORE INTO "{prev._table_name}_archive" (fk, pk, save_date) VALUES (?, ?, ?)', [fk, pk, d])
            for field in prev.serializable_fields():
                attr = getattr(prev, field.name)
                # Only scalar columns are archived; forwarded and list fields are skipped.
                if field.name in prev._forwarding:
                    continue
                if field.type in _LIST_TYPES or isinstance(attr, list):
                    continue
                # Column name is quoted, matching every other statement here.
                cur.execute(f'UPDATE "{prev._table_name}_archive" SET "{field.name}"=? WHERE fk=? AND pk=? AND save_date=?', [attr, fk, pk, d])

        cur.execute(f'INSERT OR IGNORE INTO "{self._table_name}" (fk, pk) VALUES (?, ?)', [fk, pk])

        for field in self.serializable_fields():
            attr = getattr(self, field.name)
            # NOTE(review): falsy values (None, False, 0, '', []) are never written.
            if not attr:
                continue

            if field.name in self._forwarding:
                if not isinstance(attr, list):
                    attr = [attr]
                for val in attr:
                    # NOTE(review): autosave is not defined in this file --
                    # presumably provided elsewhere; confirm.
                    val.autosave(cur, fk=pk)
            elif field.type in _LIST_TYPES or isinstance(attr, list):
                for val in attr:
                    cur.execute(f'INSERT OR IGNORE INTO "{self._table_name}_{field.name}" VALUES (?, ?)', [pk, val])
            else:
                cur.execute(f'UPDATE "{self._table_name}" SET "{field.name}"=? WHERE fk=? AND pk=?', [attr, fk, pk])

    def equals_simple(self, obj):
        """Return True if all scalar serializable fields equal those of ``obj``.

        Forwarded fields and fields annotated with a list type are not compared.
        """
        for field in self.serializable_fields():
            if field.name in self._forwarding:
                continue
            if field.type in _LIST_TYPES:
                continue
            if getattr(self, field.name) != getattr(obj, field.name):
                return False
        return True


if __name__ == '__main__':
    # Manual smoke test: print the schema, load Model pk=42 from one database
    # and save it twice (second time with a changed description) into another.

    @dataclass
    class ModelStats(DataClassDatabase):
        downloadCount: Optional[int] = None
        favoriteCount: Optional[int] = None
        thumbsUpCount: Optional[int] = None
        thumbsDownCount: Optional[int] = None
        commentCount: Optional[int] = None
        ratingCount: Optional[int] = None
        rating: Optional[int] = None

        def __post_init__(self):
            super().__post_init__()
            self._forwarding = {}
            self._table_name = 'model_stats'

    @dataclass
    class Model(DataClassDatabase):
        id: Optional[int] = None
        name: Optional[str] = None
        description: Optional[str] = None
        allowNoCredit: Optional[bool] = None
        allowCommercialUse: Optional[list] = None
        allowDerivatives: Optional[bool] = None
        allowDifferentLicense: Optional[bool] = None
        type: Optional[str] = None
        minor: Optional[bool] = None
        sfwOnly: Optional[bool] = None
        poi: Optional[bool] = None
        nsfw: Optional[bool] = None
        nsfwLevel: Optional[int] = None
        availability: Optional[str] = None
        cosmetic: Optional[str] = None
        supportsGeneration: Optional[bool] = None
        stats: Optional[ModelStats] = None

        def __post_init__(self):
            super().__post_init__()
            self._forwarding = {
                'stats': ModelStats,
            }
            self._key_field = 'id'
            self._table_name = 'model'
            self._standalone_entity = True

    for s in Model.get_create_sqls():
        print(s)

    from modules.shared.DatabaseAbstraction import Database, Cursor
    from modules.shared.DatabaseSqlite import SQLiteDatabase, SQLiteCursor

    db = SQLiteDatabase('gagaga', '/tmp')
    Model.create(db.cursor())
    db.commit()
    m = Model.load(db.cursor(), pk=42)

    pdb = SQLiteDatabase('pidoras', '/tmp')
    Model.create(pdb.cursor())
    pdb.commit()

    m0: Model = m[0]
    m0.save(pdb.cursor())
    pdb.commit()

    # Second save with a changed field should move the first version to the archive.
    m0.description = 'Abobus - avtobus'
    m0.save(pdb.cursor())
    pdb.commit()