diff --git a/modules/civit/Civit.py b/modules/civit/Civit.py new file mode 100644 index 0000000..aeed5c1 --- /dev/null +++ b/modules/civit/Civit.py @@ -0,0 +1,15 @@ +from modules.civit.fetch import Fetch +from modules.civit.datamodel import * +from modules.shared.DatabaseAbstraction import Database + + +class Civit: + def __init__(self, db: Database, path): + self._db = db + self.path = path + self.fetcher = Fetch() + Creator.create(self._db.cursor()) + Tag.create(self._db.cursor()) + + def creator_save(self, c: Creator): return c.save(self._db.cursor()) + def tag_save(self, t: Tag): return t.save(self._db.cursor()) diff --git a/modules/civit/datamodel.py b/modules/civit/datamodel.py index 0684473..6363f2f 100644 --- a/modules/civit/datamodel.py +++ b/modules/civit/datamodel.py @@ -1,10 +1,11 @@ from dataclasses import dataclass from typing import Optional -from .datamodel_base import ForwardingBase +from modules.shared.DataClassDatabase import DataClassDatabase + @dataclass -class Creator(ForwardingBase): +class Creator(DataClassDatabase): username: Optional[str] = None modelCount: Optional[int] = None link: Optional[str] = None @@ -14,9 +15,11 @@ class Creator(ForwardingBase): super().__post_init__() self._forwarding = {} self._key_field = 'username' + self._table_name = 'creators' + self._standalone_entity = True @dataclass -class Tag(ForwardingBase): +class Tag(DataClassDatabase): name: Optional[str] = None link: Optional[str] = None @@ -24,99 +27,102 @@ class Tag(ForwardingBase): super().__post_init__() self._forwarding = {} self._key_field = 'name' - -@dataclass -class ModelVersionStats(ForwardingBase): - downloadCount: Optional[int] = None - ratingCount: Optional[int] = None - rating: Optional[int] = None - thumbsUpCount: Optional[int] = None - thumbsDownCount: Optional[int] = None - - def __post_init__(self): - super().__post_init__() - self._forwarding = {} + self._table_name = 'tags' + self._standalone_entity = True - -@dataclass -class ModelVersion(ForwardingBase): - id: Optional[int] = None - index: Optional[int] = None - name: Optional[str] = None - baseModel: Optional[str] = None - baseModelType: Optional[str] = None - publishedAt: Optional[str] = None - availability: Optional[str] = None - nsfwLevel: Optional[int] = None - description: Optional[str] = None - trainedWords: Optional[list[str]] = None - stats: Optional[ModelVersionStats] = None - supportsGeneration: Optional[bool] = None - downloadUrl: Optional[str] = None - # FILES - # IMAGES - - - def __post_init__(self): - super().__post_init__() - self._forwarding = { - 'stats': ModelVersionStats, - } - self._key_field = 'id' - - - - -@dataclass -class ModelStats(ForwardingBase): - downloadCount: Optional[int] = None - favoriteCount: Optional[int] = None - thumbsUpCount: Optional[int] = None - thumbsDownCount: Optional[int] = None - commentCount: Optional[int] = None - ratingCount: Optional[int] = None - rating: Optional[int] = None - - def __post_init__(self): - super().__post_init__() - self._forwarding = {} - - - - -@dataclass -class Model(ForwardingBase): - id: Optional[int] = None - name: Optional[str] = None - description: Optional[str] = None - allowNoCredit: Optional[bool] = None - allowCommercialUse: Optional[list] = None - allowDerivatives: Optional[bool] = None - allowDifferentLicense: Optional[bool] = None - type: Optional[str] = None - minor: Optional[bool] = None - sfwOnly: Optional[bool] = None - poi: Optional[bool] = None - nsfw: Optional[bool] = None - nsfwLevel: Optional[int] = None - availability: Optional[str] = None - cosmetic: Optional[str] = None - supportsGeneration: Optional[bool] = None - stats: Optional[ModelStats] = None - creator: Optional[Creator] = None - tags: Optional[list[Tag]] = None - modelVersions: Optional[list[ModelVersion]] = None - - - def __post_init__(self): - super().__post_init__() - self._forwarding = { - 'stats': ModelStats, - 'creator': Creator, - 'tags': Tag, - 'modelVersions': ModelVersion, - } - self._key_field = 'id' - +# @dataclass +# class ModelVersionStats(ForwardingBase): +# downloadCount: Optional[int] = None +# ratingCount: Optional[int] = None +# rating: Optional[int] = None +# thumbsUpCount: Optional[int] = None +# thumbsDownCount: Optional[int] = None +# +# def __post_init__(self): +# super().__post_init__() +# self._forwarding = {} +# +# +# +# @dataclass +# class ModelVersion(ForwardingBase): +# id: Optional[int] = None +# index: Optional[int] = None +# name: Optional[str] = None +# baseModel: Optional[str] = None +# baseModelType: Optional[str] = None +# publishedAt: Optional[str] = None +# availability: Optional[str] = None +# nsfwLevel: Optional[int] = None +# description: Optional[str] = None +# trainedWords: Optional[list[str]] = None +# stats: Optional[ModelVersionStats] = None +# supportsGeneration: Optional[bool] = None +# downloadUrl: Optional[str] = None +# # FILES +# # IMAGES +# +# +# def __post_init__(self): +# super().__post_init__() +# self._forwarding = { +# 'stats': ModelVersionStats, +# } +# self._key_field = 'id' +# +# +# +# +# @dataclass +# class ModelStats(ForwardingBase): +# downloadCount: Optional[int] = None +# favoriteCount: Optional[int] = None +# thumbsUpCount: Optional[int] = None +# thumbsDownCount: Optional[int] = None +# commentCount: Optional[int] = None +# ratingCount: Optional[int] = None +# rating: Optional[int] = None +# +# def __post_init__(self): +# super().__post_init__() +# self._forwarding = {} +# +# +# +# +# @dataclass +# class Model(ForwardingBase): +# id: Optional[int] = None +# name: Optional[str] = None +# description: Optional[str] = None +# allowNoCredit: Optional[bool] = None +# allowCommercialUse: Optional[list] = None +# allowDerivatives: Optional[bool] = None +# allowDifferentLicense: Optional[bool] = None +# type: Optional[str] = None +# minor: Optional[bool] = None +# sfwOnly: Optional[bool] = None +# poi: Optional[bool] = None +# nsfw: Optional[bool] = None +# nsfwLevel: Optional[int] = None +# availability: Optional[str] = None +# cosmetic: Optional[str] = None +# supportsGeneration: Optional[bool] = None +# stats: Optional[ModelStats] = None +# creator: Optional[Creator] = None +# tags: Optional[list[Tag]] = None +# modelVersions: Optional[list[ModelVersion]] = None +# +# +# def __post_init__(self): +# super().__post_init__() +# self._forwarding = { +# 'stats': ModelStats, +# 'creator': Creator, +# 'tags': Tag, +# 'modelVersions': ModelVersion, +# } +# self._key_field = 'id' +# diff --git a/modules/shared/DataClassDatabase.py b/modules/shared/DataClassDatabase.py index 1ac2b23..7c97c38 100644 --- a/modules/shared/DataClassDatabase.py +++ b/modules/shared/DataClassDatabase.py @@ -1,5 +1,6 @@ +import datetime from dataclasses import dataclass, fields -from typing import Optional, List +from typing import Optional, List, get_origin from DataClassJson import DataClassJson from modules.shared.DatabaseAbstraction import Cursor @@ -9,7 +10,7 @@ types = {bool: 'INTEGER', int: 'INTEGER', float: 'REAL', str: "TEXT", @dataclass class DataClassDatabase(DataClassJson): - _main_entity: bool = None + _standalone_entity: bool = None _table_name: str = None pass @@ -18,12 +19,13 @@ class DataClassDatabase(DataClassJson): @classmethod def get_create_sqls(cls, table_name = None): - result: list[str] = list() - result.append(f'CREATE TABLE IF NOT EXISTS {table_name} (fk TEXT NOT NULL, pk TEXT NOT NULL, PRIMARY KEY(pk, fk));') - tmp_instance = cls() if not table_name: table_name = tmp_instance._table_name + result: list[str] = list() + result.append(f'CREATE TABLE IF NOT EXISTS {table_name} (fk TEXT NOT NULL, pk TEXT NOT NULL, PRIMARY KEY(pk, fk));') + result.append(f'CREATE TABLE IF NOT EXISTS {table_name}_archive (fk TEXT NOT NULL, pk TEXT NOT NULL, save_date TEXT NOT NULL, PRIMARY KEY(pk, fk, save_date));') + excluded_fields = {f.name for f in fields(DataClassDatabase)} all_fields = [f for f in fields(cls) if f.name not in excluded_fields and not f.name.startswith('_')] @@ -33,14 +35,133 @@ class DataClassDatabase(DataClassJson): try: result.extend(inner_type.get_create_sqls()) except Exception as e: raise RuntimeError('invalid forwarding type') from e elif field.type in { list, Optional[list], Optional[List] }: - result.append(f'CREATE TABLE IF NOT EXISTS {table_name}_{field.name} (fk TEXT NOT NULL, data TEXT NOT NULL);') + result.append(f'CREATE TABLE IF NOT EXISTS {table_name}_{field.name} (fk TEXT NOT NULL, data TEXT NOT NULL, PRIMARY KEY(data, fk));') else: result.append(f'ALTER TABLE {table_name} ADD COLUMN {field.name} {types.get(field.type, 'TEXT')};') + result.append(f'ALTER TABLE {table_name}_archive ADD COLUMN {field.name} {types.get(field.type, 'TEXT')};') return result @classmethod def create(cls, cur: Cursor): - for sql in cls.get_create_sqls(): cur.execute(sql) + for sql in cls.get_create_sqls(): + try: cur.execute(sql) + except Exception as e: print(e) + + @classmethod + def load(cls, cur: Cursor, pk=None, fk=None, depth = 5): + if not pk and not fk: return list() + params = list() + instance = cls() + + sql = f'SELECT pk, fk FROM {instance._table_name}' + if pk or fk: sql += ' WHERE' + if pk: + params.append(pk) + sql += ' pk = ?' + if pk and fk: sql += ' AND' + if fk: + params.append(fk) + sql += ' fk = ?' + res: list[dict] = cur.fetchall(sql, params) + del pk, fk, sql, params + results = list() + for r in res: + item = cls._load(cur, r.get('pk', None), r.get('fk', None), depth) + if item: results.append(item) + + return results + + + + @classmethod + def _load(cls, cur: Cursor, pk, fk, depth = 5): + if not pk and not fk: return None + instance = cls() + res: dict = cur.fetchone(f'SELECT * FROM {instance._table_name} WHERE pk = ? AND fk = ?', [pk, fk]) + if not res: return None + rpk = res.pop('pk') + rfk = res.pop('fk') + result = cls.from_dict(res) + + if depth == 0: return result + + for field in fields(cls): + print(field.name, field.type, get_origin(field.type)) + if field.name in instance._forwarding: + items = instance._forwarding[field.name].load(cur, fk=rpk, depth=depth - 1) + if len(items) > 1: setattr(result, field.name, items) # TODO Убрать костыль + elif len(items) > 0: setattr(result, field.name, items[0]) + + elif field.type in {list, List, Optional[list], Optional[List]}: + items = cur.fetchall(f'SELECT data from {instance._table_name}_{field.name} WHERE fk=?', [rpk]) + if items: + items = [row['data'] for row in items] + else: + items = list() + setattr(result, field.name, items) + + return result + + def save(self, cur: Cursor, fk = None): + if self._standalone_entity: fk = 0 + elif not fk: raise RuntimeError('Trying to save child entity as standalone') + + + pk = self.key if self._key_field != 'key' else 0 + prev = self._load(cur, pk=pk, fk=fk, depth=0) + + + if prev: + for field in self.serializable_fields(): + setattr(self, field.name, getattr(self, field.name) or getattr(prev, field.name)) + + if prev and not self.equals_simple(prev): + d = str(datetime.datetime.now()) + cur.execute(f'INSERT OR IGNORE INTO {prev._table_name}_archive (fk, pk, save_date) VALUES (?, ?, ?)', [fk, pk, d]) + for field in prev.serializable_fields(): + attr = getattr(prev, field.name) + if field.name in prev._forwarding: continue + elif field.type in {list, List, Optional[list], Optional[List]} or isinstance(attr, list): continue + else: + cur.execute(f'UPDATE {prev._table_name}_archive SET {field.name}=? WHERE fk=? AND pk=? AND save_date=?', [attr, fk, pk, d]) + + cur.execute(f'INSERT OR IGNORE INTO {self._table_name} (fk, pk) VALUES (?, ?)', [fk, pk]) + + for field in self.serializable_fields(): + attr = getattr(self, field.name) + if not attr: continue + + if field.name in self._forwarding: + if not isinstance(getattr(self, field.name), list): attr = [attr] + for val in attr: + val.save(cur, fk=pk) + continue + elif field.type in {list, List, Optional[list], Optional[List]} or isinstance(attr, list): + for val in attr: cur.execute(f'INSERT OR IGNORE INTO {self._table_name}_{field.name} VALUES (?, ?)', [pk, val]) + continue + else: + cur.execute(f'UPDATE {self._table_name} SET {field.name}=? WHERE fk=? AND pk=?', [attr, fk, pk]) + continue + + + + + + def equals_simple(self, obj): + for field in self.serializable_fields(): + if field.name in self._forwarding: continue + elif field.type in {list, List, Optional[list], Optional[List]}: continue + if getattr(self, field.name) != getattr(obj, field.name): + return False + return True + + + + + + + + if __name__ == '__main__': @dataclass @@ -56,6 +177,8 @@ if __name__ == '__main__': def __post_init__(self): super().__post_init__() self._forwarding = {} + self._table_name = 'model_stats' + @dataclass @@ -84,7 +207,27 @@ if __name__ == '__main__': 'stats': ModelStats, } self._key_field = 'id' - self._table_name = 'gagaga' + self._table_name = 'model' + self._standalone_entity = True for s in Model.get_create_sqls(): print(s) + + from modules.shared.DatabaseAbstraction import Database, Cursor + from modules.shared.DatabaseSqlite import SQLiteDatabase, SQLiteCursor + + db = SQLiteDatabase('gagaga', '/tmp') + Model.create(db.cursor()) + db.commit() + m = Model.load(db.cursor(), pk=42) + pdb = SQLiteDatabase('pidoras', '/tmp') + Model.create(pdb.cursor()) + pdb.commit() + m0: Model = m[0] + m0.save(pdb.cursor()) + pdb.commit() + m0.description = 'Abobus - avtobus' + m0.save(pdb.cursor()) + pdb.commit() + pass + diff --git a/modules/shared/DataClassJson.py b/modules/shared/DataClassJson.py index f04d0a8..a94ffab 100644 --- a/modules/shared/DataClassJson.py +++ b/modules/shared/DataClassJson.py @@ -102,6 +102,11 @@ class DataClassJson: return instance + @classmethod + def serializable_fields(cls): + excluded_fields = {f.name for f in fields(DataClassJson)} + return {f for f in fields(cls) if f.name not in excluded_fields and not f.name.startswith('_')} + def to_dict(self) -> Dict[str, Any]: result = {} excluded_fields = {f.name for f in fields(DataClassJson)}