Files
vaiola/modelspace/ModelPackageSubRepository.py
2025-10-16 18:42:32 +07:00

374 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import uuid
import warnings
from pathlib import Path
from modelspace.Essentials import SetsDict, select_elements, format_bytes
from modelspace.ModelPackage import ModelPackage, PackageInfo
from modelspace.ModelPackageCollection import ModelPackageCollection
from modules.civit.client import Client
class ModelPackageSubRepository:
def __init__(self, path, seed):
path.mkdir(exist_ok=True)
self.path = Path(path)
self.seed = seed
self.packages: dict[str, ModelPackage] | None = None
self.package_names: set[str] | None = None
self.resources: SetsDict | None = None
self.collections: dict[str, ModelPackageCollection] | None = None
self.reload()
def _reload_packages(self):
self.packages = dict()
try:
dirs = [item.name for item in self.path.iterdir() if item.is_dir()]
except OSError as e:
print(f"Ошибка доступа к директории: {e}")
dirs = []
for d in dirs:
package = ModelPackage.load(str(self.path / d))
self.packages[package.uuid] = package
self.package_names = {p.name for id, p in self.packages.items()}
def _reload_resources(self):
self.resources = SetsDict()
for pkg_id, package in self.packages.items():
for resource in package.provides:
self.resources.add(resource, pkg_id)
def _reload_collections(self):
try:
filenames = [item.name for item in self.path.iterdir() if item.is_file() and item.name.endswith('_collection.json')]
except OSError as e:
print(f"Ошибка доступа к директории: {e}")
return
self.collections = dict()
for filename in filenames:
collection = ModelPackageCollection(filename=str(Path(self.path) / filename))
self.collections[collection.name] = collection
def reload(self):
self._reload_packages()
self._reload_resources()
self._reload_collections()
def add_package_to_collection(self, pkg_name, collection_name, category: str = None, internal=True):
if pkg_name not in self.package_names:
if pkg_name in self.resources.keys or pkg_name in self.collections: raise RuntimeWarning('Only packages allowed to add in collections')
else: raise RuntimeWarning(f'Package {pkg_name} not found')
if collection_name not in self.collections:
self.collections[collection_name] = ModelPackageCollection(
self.path / (collection_name + '_collection.json'), name=collection_name, autosave=True
)
self.collections[collection_name].add_package(pkg_name, category, internal)
def resources_from_pkg_list(self, uuids: list[str]):
selected_packages = []
for pkg_id in uuids:
package = self.packages.get(pkg_id, None)
selected_packages.append(package)
resources = SetsDict()
for package in selected_packages:
for resource in package.provides:
resources.add(resource, package.uuid)
return resources
@staticmethod
def deps_from_pkg_list(packages: list[ModelPackage]) -> set[str]:
res = set()
for package in packages: res = res | set(package.dependencies)
return res
def packages_by_resource(self, resource):
packages_ids = self.resources.by_key(resource)
if not packages_ids or len(packages_ids) == 0:
raise RuntimeError(f"{resource}: There are no packages in the repository that provide this resource")
else:
packages_ids = list(packages_ids)
packages: set[ModelPackage] = set()
for pkg_id in packages_ids: packages.add(self.package_by_id(pkg_id))
return packages
def package_by_id(self, pkg_id):
package = self.packages.get(pkg_id, None)
if not package: raise RuntimeError(f"{pkg_id}: Something went wrong while reading package info")
return package
def add_package_interactive(self) -> ModelPackage:
"""Добавляет новый пакет модели интерактивно"""
# Генерируем новый UUID
package_uuid = str(uuid.uuid4())
# Создаем путь к новому пакету
package_path = self.path / package_uuid
# Вызываем интерактивное создание пакета
package = ModelPackage.interactive(str(package_path), package_uuid)
loaded_package = ModelPackage.load(str(package_path))
self.packages[loaded_package.uuid] = loaded_package
# Добавляем пакет в коллекции
self._add_package_to_collections_interactive(package)
return package
def _add_package_to_collections_interactive(self, package: ModelPackage):
while True:
print('Input collections, blank for stop')
collection = input().strip()
if collection == '': break
external = input('External? (blank for no): ').strip()
if external != '':
self.add_package_to_collection(package.name, collection, category=None, internal=False)
continue
category = input('Category: ').strip()
if category == '':
self.add_package_to_collection(package.name, collection, category=None, internal=True)
continue
else:
self.add_package_to_collection(package.name, collection, category, internal=True)
continue
def pull_civit_package(self, client: Client, model_id: int, version_id: int = None, file_id: int = None):
model_info = client.get_model_raw(model_id)
model_versions = model_info.get('modelVersions', None)
if not model_versions:
warnings.warn(f'Unable to find model {model_id}')
return
pull_candidates = list()
print('Model name:', model_info.get('name', None))
ic_package_type = model_info.get('type', None)
ic_tags = model_info.get('tags', None)
for model_version in model_versions:
if not model_version.get('availability', None) or model_version.get('availability', None) != 'Public': continue
ic_version_id = model_version.get('id', None)
ic_version = model_version.get('name', None)
ic_release_date = model_version.get('publishedAt', None)
ic_lineage = model_version.get('baseModel', None)
ic_images = None
images = model_version.get('images', None)
if images and isinstance(images, list): ic_images = [i.get('url', None) for i in images if i.get('url', None) is not None]
ic_provides = [f'civit-{model_id}-{ic_version_id}', f'civit-{model_id}'].copy()
for file in model_version.get('files', list()):
ic_size_bytes = file.get('sizeKB', None)
if ic_size_bytes and isinstance(ic_size_bytes, float):
ic_size_bytes = int(ic_size_bytes * 1024)
metadata = file.get('metadata', None)
ic_quantisation = None
if metadata:
ic_quantisation = metadata.get('fp', None)
ic_file_id = file.get('id', None)
ic_filename = file.get('name', None)
if file.get('type', None) and file.get('type', None) != 'Model':
continue
ic_url = file.get('downloadUrl', None)
ic_model_info = model_info.copy()
ic_name = f'civit-{model_id}-{ic_version_id}-{ic_file_id}'
ic_uuid = ic_name
pull_candidates.append({
'uuid': ic_uuid,
'name': ic_name,
'provides': ic_provides,
'version_id': ic_version_id,
'file_id': ic_file_id,
'package_type': ic_package_type.lower(),
'tags': ic_tags,
'version': ic_version,
'release_date': ic_release_date,
'lineage': ic_lineage,
'images': ic_images or list(),
'size_bytes': ic_size_bytes,
'quantisation': ic_quantisation or '',
'url': ic_url,
'filename': ic_filename,
'model_info': ic_model_info,
})
try:
del file, ic_url, ic_package_type, ic_images, ic_release_date, ic_tags, ic_filename, ic_lineage, model_info
del ic_model_info, ic_quantisation, ic_size_bytes, ic_version, images, metadata, model_version, model_versions
del ic_file_id, ic_version_id, ic_uuid, ic_name, ic_provides
except Exception:
pass
# check already pulled packages
already_pulled = list()
available_to_pull = list()
for candidate in pull_candidates:
if candidate['name'] in self.package_names: already_pulled.append(candidate)
else: available_to_pull.append(candidate)
if version_id: available_to_pull = [p for p in available_to_pull if p['version_id'] == version_id]
if file_id: available_to_pull = [p for p in available_to_pull if p['file_id'] == file_id]
if len(available_to_pull) == 0:
warnings.warn(f'Pull candidate not found for model_id:{model_id} and version_id:{version_id} and file_id:{file_id}')
return
# selection output
if len(already_pulled) > 0:
print('Already pulled packages:')
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
for candidate in already_pulled:
print(
f' {'N':<{2}} {candidate['version']:<{10}} {candidate['package_type']:<{10}} {candidate['release_date']:<{25}}'
f' {candidate['lineage']:<{10}} {candidate['quantisation']:<{5}} {format_bytes(candidate['size_bytes']):<{10}} ')
if len(available_to_pull) == 0:
print('All available packages already pulled')
return
else:
print('Available packages:')
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
for i in range(len(available_to_pull)):
candidate = available_to_pull[i]
quantisation = candidate['quantisation'] or 'N/A'
print(f' {i:<{2}} {candidate['version']:<{10}} {candidate['package_type']:<{10}} {candidate['release_date']:<{25}}'
f' {candidate['lineage']:<{10}} {quantisation:<{5}} {format_bytes(candidate['size_bytes']):<{10}} ')
if len(available_to_pull) > 1: to_pull = select_elements(pull_candidates, input("Your choice: "))
else: to_pull = available_to_pull
# Ввод зависимостей
print("Зависимости (введите по одной, п1658427устая строка для завершения):")
additional_dependencies = []
while True:
dep = input().strip()
if not dep:
break
additional_dependencies.append(dep)
# Ввод ресурсов
print("Ресурсы (введите по одному, пустая строка для завершения):")
additional_resources = []
while True:
resource = input().strip()
if not resource:
break
additional_resources.append(resource)
print("Теги (введите по одному, пустая строка для завершения):")
additional_tags = []
while True:
tag = input().strip()
if not tag:
break
additional_tags.append(resource)
while True:
print('One collection for all selected packages, blank for None')
collection = input().strip()
if collection == '': break
external = input('External? (blank for no): ').strip()
if external != '':
category: str | None = None
internal = False
break
category = str(input('Category: ')).strip()
if category == '':
category: str | None = None
internal = True
break
else:
internal = True
break
pulled: list[ModelPackage] = list()
for candidate in to_pull:
package_path = self.path / candidate['uuid']
if os.path.exists(str(Path(package_path) / "package.json")): raise RuntimeError("package exists!")
package_info = PackageInfo(str(Path(package_path) / "package.json"))
package_info.uuid = candidate['uuid']
package_info.name = candidate['name']
# TODO список дополнительных ресурсов, зависимостей, по линейке
# TODO add deps and resources based on lineage (use civit lineages)
package_info.resources = candidate['provides'].copy()
package_info.resources.extend(additional_resources)
package_info.dependencies = additional_dependencies.copy()
package_info.tags = candidate['tags'].copy()
package_info.tags.extend(additional_tags)
package_info.lineage = candidate['lineage'].lower()
# TODO cast package types (diffusion_model or checkpoint) (use civit lineages)
package_info.package_type = candidate['package_type'].lower()
package_info.version = candidate['version']
package_info.release_date = candidate['release_date']
package_info.size_bytes = candidate['size_bytes']
package_info.quantisation = candidate['quantisation']
package_info.save()
os.makedirs(package_path / 'files')
with open(package_path / 'model_info.json', 'w') as f:
json.dump(candidate['model_info'], f, indent=2, ensure_ascii=False)
print('Pulling model...')
client.download_file(url=candidate['url'], path=package_path / 'files' / candidate['filename'])
print('Pulling main thumbnail...')
preview = candidate['images'][0]
dir, file = str(preview).rsplit('/', maxsplit=1)
image_name, image_extension = str(file).rsplit('.', maxsplit=1)
ckpt = candidate['filename']
ckpt_name, ckpt_extension = str(ckpt).rsplit('.', maxsplit=1)
client.download_file(url=preview, path=package_path / 'files' / (ckpt_name + '.' + image_extension))
os.makedirs(package_path / 'images')
print('Pulling thumbnails...')
for image in candidate['images']:
dir, file = str(image).rsplit('/', maxsplit=1)
client.download_file(url=image, path=package_path / 'images' / file)
package = ModelPackage(package_path, [], package_info)
self.packages[package.uuid] = package
self.package_names.add(package.name)
self.add_package_to_collection(package.name, 'civit', internal=True)
if collection: self.add_package_to_collection(package.name, collection, category, internal=internal)
pulled.append(package)
for package in pulled:
info = package.info
print('Collections for package:')
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
print(
f' {'N':<{2}} {info.version:<{10}} {info.package_type:<{10}} {info.release_date:<{25}}'
f' {info.lineage:<{10}} {info.quantization:<{5}} {format_bytes(info.size_bytes):<{10}} ')
self._add_package_to_collections_interactive(package)