Files
vaiola/modules/shared/DataClassDatabase.py
2025-10-16 18:42:32 +07:00

239 lines
8.7 KiB
Python

import datetime
from dataclasses import dataclass, fields
from types import UnionType
from typing import Optional, List, Union, get_args, get_origin

from .DataClassJson import DataClassJson
from modules.shared.DatabaseAbstraction import Cursor
# SQLite column type for each supported field annotation. Every Optional[...]
# variant maps to the same type as its bare annotation. Lookups elsewhere use
# .get() with a default, so unlisted annotations still get a column type.
types = {bool: 'INTEGER', int: 'INTEGER', float: 'REAL', str: "TEXT"}
types.update({Optional[py_type]: sql_type for py_type, sql_type in types.items()})
@dataclass
class DataClassDatabase(DataClassJson):
    """Dataclass mixin that persists instances into SQL tables through a ``Cursor``.

    Tables produced by :meth:`get_create_sqls` for a table name ``T``:

    * ``T`` -- one row per entity keyed by ``(pk, fk)``; each simple field is
      a column added with ``ALTER TABLE``.
    * ``T_archive`` -- the same columns plus ``save_date``; :meth:`save` copies
      the previously stored row here when the simple fields changed.
    * ``T_<field>`` -- one side table per list-typed field holding
      ``(fk, data)`` rows, where ``fk`` is the owning row's ``pk``.

    Fields named in ``self._forwarding`` (field name -> another
    ``DataClassDatabase`` subclass) live in that class's own tables; their
    rows use this entity's ``pk`` as ``fk``.

    Subclasses set ``_table_name``, ``_forwarding``, ``_key_field`` and
    ``_standalone_entity`` in ``__post_init__``. ``_forwarding`` and
    ``_key_field`` are not declared here -- presumably provided by
    ``DataClassJson``; confirm.
    """
    # True for top-level entities: save() stores them with fk = 0.
    _standalone_entity: bool = None
    # Name of the main table; the archive/list tables derive from it.
    _table_name: str = None

    def __post_init__(self):
        super().__post_init__()

    @staticmethod
    def _is_list_type(field_type) -> bool:
        """Return True if *field_type* annotates a list kept in a side table.

        Accepts bare or parametrized ``list``/``List`` (e.g. ``list[str]``)
        and ``Optional``/``Union``/``X | None`` wrappers around them.
        """
        if field_type is list or get_origin(field_type) is list:
            return True
        if get_origin(field_type) in (Union, UnionType):
            return any(arg is list or get_origin(arg) is list for arg in get_args(field_type))
        return False

    @classmethod
    def get_create_sqls(cls, table_name = None):
        """Return the DDL statements that create or extend this class's tables.

        Args:
            table_name: Override for the main table name; defaults to the
                ``_table_name`` of a default-constructed instance.

        Returns:
            list[str]: The statements, in this order:

            * ``CREATE TABLE IF NOT EXISTS`` for the main and archive tables;
            * per field, either the forwarded class's own statements or a side
              table for list fields;
            * otherwise, ``ALTER TABLE ... ADD COLUMN`` on both the main and
              archive tables.

            The ALTER statements are expected to fail once the column exists;
            :meth:`create` tolerates that.

        Raises:
            RuntimeError: If a forwarded type cannot produce its own DDL.
        """
        tmp_instance = cls()
        if not table_name:
            table_name = tmp_instance._table_name
        # The pk column takes the SQL type of the key field (str if none matches).
        pk_type = str
        for field in fields(tmp_instance):
            if field.name == tmp_instance._key_field:
                pk_type = field.type
        pk_sql_type = types.get(pk_type, 'INTEGER')
        result: list[str] = [
            f'CREATE TABLE IF NOT EXISTS "{table_name}" (fk INTEGER NOT NULL, pk {pk_sql_type} NOT NULL, PRIMARY KEY(pk, fk));',
            f'CREATE TABLE IF NOT EXISTS "{table_name}_archive" (fk INTEGER NOT NULL, pk {pk_sql_type} NOT NULL, save_date TEXT NOT NULL, PRIMARY KEY(pk, fk, save_date));',
        ]
        excluded_fields = {f.name for f in fields(DataClassDatabase)}
        all_fields = [f for f in fields(cls) if f.name not in excluded_fields and not f.name.startswith('_')]
        for field in all_fields:
            if field.name in tmp_instance._forwarding:
                inner_type: type = tmp_instance._forwarding[field.name]
                try:
                    result.extend(inner_type.get_create_sqls())
                except Exception as e:
                    raise RuntimeError('invalid forwarding type') from e
            elif cls._is_list_type(field.type):
                result.append(f'CREATE TABLE IF NOT EXISTS "{table_name}_{field.name}" (fk TEXT NOT NULL, data TEXT NOT NULL, PRIMARY KEY(data, fk));')
            else:
                column_type = types.get(field.type, 'TEXT')
                result.append(f'ALTER TABLE "{table_name}" ADD COLUMN "{field.name}" {column_type};')
                result.append(f'ALTER TABLE "{table_name}_archive" ADD COLUMN "{field.name}" {column_type};')
        return result

    @classmethod
    def create(cls, cur: Cursor):
        """Execute every statement from :meth:`get_create_sqls` on *cur*.

        Failures are printed and skipped rather than raised. This lets the
        method run against an existing database, where the ``ALTER TABLE``
        statements fail for columns that already exist.
        """
        for sql in cls.get_create_sqls():
            try:
                cur.execute(sql)
            except Exception as e:  # best-effort schema migration
                print(e)

    @classmethod
    def load(cls, cur: Cursor, pk=None, fk=None, depth = 5):
        """Load every stored instance matching *pk* and/or *fk*.

        Args:
            cur: Cursor whose ``fetchall(sql, params)`` returns rows as dicts.
            pk: Primary key to match, or None for any. 0 is a valid key.
            fk: Parent key to match (standalone entities are stored with 0),
                or None for any.
            depth: Number of levels of forwarded/list fields to populate.

        Returns:
            list: The loaded instances; empty when both *pk* and *fk* are None.
        """
        if pk is None and fk is None:
            return []
        instance = cls()
        conditions, params = [], []
        if pk is not None:
            conditions.append('pk = ?')
            params.append(pk)
        if fk is not None:
            conditions.append('fk = ?')
            params.append(fk)
        where = ' AND '.join(conditions)
        rows: list[dict] = cur.fetchall(f'SELECT pk, fk FROM "{instance._table_name}" WHERE {where}', params)
        results = []
        for row in rows:
            item = cls._load(cur, row.get('pk', None), row.get('fk', None), depth)
            if item is not None:
                results.append(item)
        return results

    @classmethod
    def _load(cls, cur: Cursor, pk, fk, depth = 5):
        """Load the single row ``(pk, fk)`` and, while *depth* > 0, its children.

        Simple fields are restored with ``from_dict``. Forwarded fields are
        loaded from the forwarded class with ``fk`` equal to this row's pk.
        List fields are read from their side tables.

        Returns:
            The instance, or None when both keys are None or no row matches.
        """
        if pk is None and fk is None:
            return None
        instance = cls()
        res: dict = cur.fetchone(f'SELECT * FROM "{instance._table_name}" WHERE pk = ? AND fk = ?', [pk, fk])
        if not res:
            return None
        rpk = res.pop('pk')
        res.pop('fk')  # key columns are not dataclass fields
        result = cls.from_dict(res)
        if depth <= 0:
            return result
        for field in fields(cls):
            if field.name in instance._forwarding:
                items = instance._forwarding[field.name].load(cur, fk=rpk, depth=depth - 1)
                # HACK (TODO: remove this crutch): a forwarded field holds a list
                # when several child rows exist, otherwise a single object.
                if len(items) > 1:
                    setattr(result, field.name, items)
                elif items:
                    setattr(result, field.name, items[0])
            elif cls._is_list_type(field.type):
                rows = cur.fetchall(f'SELECT data from "{instance._table_name}_{field.name}" WHERE fk=?', [rpk])
                setattr(result, field.name, [row['data'] for row in rows or []])
        return result

    def _archive(self, cur: Cursor, prev, pk, fk):
        """Copy *prev*'s simple fields into the archive table under a new save_date.

        Forwarded fields and list values are skipped; they have no archive columns.
        """
        save_date = str(datetime.datetime.now())
        archive = f'"{prev._table_name}_archive"'
        cur.execute(f'INSERT OR IGNORE INTO {archive} (fk, pk, save_date) VALUES (?, ?, ?)', [fk, pk, save_date])
        for field in prev.serializable_fields():
            value = getattr(prev, field.name)
            if field.name in prev._forwarding or self._is_list_type(field.type) or isinstance(value, list):
                continue
            cur.execute(f'UPDATE {archive} SET "{field.name}"=? WHERE fk=? AND pk=? AND save_date=?', [value, fk, pk, save_date])

    def save(self, cur: Cursor, fk = None):
        """Insert or update this entity, archiving the stored row if it changed.

        Fields left as None are filled in from the stored row before comparing
        and writing. Explicit falsy values (False, 0, '') are written as-is.
        List values are added with INSERT OR IGNORE and never deleted, so
        items removed from a list remain in its side table. The method does
        not commit.

        Args:
            cur: Cursor to execute statements on.
            fk: Parent key. Required for child entities; ignored (forced to 0)
                for standalone ones.

        Raises:
            RuntimeError: If a non-standalone entity is saved without *fk*.
        """
        if self._standalone_entity:
            fk = 0
        elif not fk:
            raise RuntimeError('Trying to save child entity as standalone')
        # Entities without a dedicated key field are stored with pk 0.
        pk = self.key if self._key_field != 'key' else 0
        prev = self._load(cur, pk=pk, fk=fk, depth=0)
        if prev is not None:
            for field in self.serializable_fields():
                if getattr(self, field.name) is None:
                    setattr(self, field.name, getattr(prev, field.name))
            if not self.equals_simple(prev):
                self._archive(cur, prev, pk, fk)
        cur.execute(f'INSERT OR IGNORE INTO "{self._table_name}" (fk, pk) VALUES (?, ?)', [fk, pk])
        for field in self.serializable_fields():
            value = getattr(self, field.name)
            if value is None:
                continue
            if field.name in self._forwarding:
                children = value if isinstance(value, list) else [value]
                for child in children:
                    # NOTE(review): `autosave` is not defined in this class;
                    # presumably provided by DataClassJson -- confirm.
                    child.autosave(cur, fk=pk)
            elif self._is_list_type(field.type) or isinstance(value, list):
                # NOTE: side tables are keyed by the owner's pk only.
                for item in value:
                    cur.execute(f'INSERT OR IGNORE INTO "{self._table_name}_{field.name}" VALUES (?, ?)', [pk, item])
            else:
                cur.execute(f'UPDATE "{self._table_name}" SET "{field.name}"=? WHERE fk=? AND pk=?', [value, fk, pk])

    def equals_simple(self, obj):
        """Return True if every simple field of *self* equals the same field on *obj*.

        Forwarded and list-typed fields are ignored. ``save`` compares against
        a row loaded with ``depth=0``, which never populates them.
        """
        return all(
            getattr(self, field.name) == getattr(obj, field.name)
            for field in self.serializable_fields()
            if field.name not in self._forwarding and not self._is_list_type(field.type)
        )
if __name__ == '__main__':
    # Manual smoke test. It loads model 42 from one SQLite database and copies
    # it into a second one. It then saves again with a changed description,
    # which should move the previous version into the archive table.

    @dataclass
    class ModelStats(DataClassDatabase):
        downloadCount: Optional[int] = None
        favoriteCount: Optional[int] = None
        thumbsUpCount: Optional[int] = None
        thumbsDownCount: Optional[int] = None
        commentCount: Optional[int] = None
        ratingCount: Optional[int] = None
        rating: Optional[int] = None

        def __post_init__(self):
            super().__post_init__()
            self._forwarding = {}
            self._table_name = 'model_stats'

    @dataclass
    class Model(DataClassDatabase):
        id: Optional[int] = None
        name: Optional[str] = None
        description: Optional[str] = None
        allowNoCredit: Optional[bool] = None
        allowCommercialUse: Optional[list] = None
        allowDerivatives: Optional[bool] = None
        allowDifferentLicense: Optional[bool] = None
        type: Optional[str] = None
        minor: Optional[bool] = None
        sfwOnly: Optional[bool] = None
        poi: Optional[bool] = None
        nsfw: Optional[bool] = None
        nsfwLevel: Optional[int] = None
        availability: Optional[str] = None
        cosmetic: Optional[str] = None
        supportsGeneration: Optional[bool] = None
        stats: Optional[ModelStats] = None

        def __post_init__(self):
            super().__post_init__()
            # 'stats' is stored in ModelStats' own table, linked by this model's pk.
            self._forwarding = {
                'stats': ModelStats,
            }
            self._key_field = 'id'
            self._table_name = 'model'
            self._standalone_entity = True

    for s in Model.get_create_sqls():
        print(s)

    from modules.shared.DatabaseSqlite import SQLiteDatabase

    source_db = SQLiteDatabase('gagaga', '/tmp')
    Model.create(source_db.cursor())
    source_db.commit()
    models = Model.load(source_db.cursor(), pk=42)
    if not models:
        # A freshly created source database has no rows to copy.
        raise SystemExit('Model 42 not found in the source database; nothing to copy.')

    target_db = SQLiteDatabase('gagaga_copy', '/tmp')
    Model.create(target_db.cursor())
    target_db.commit()
    m0: Model = models[0]
    m0.save(target_db.cursor())
    target_db.commit()
    m0.description = 'Abobus - avtobus'
    m0.save(target_db.cursor())
    target_db.commit()