374 lines
16 KiB
Python
374 lines
16 KiB
Python
import json
|
||
import os
|
||
import uuid
|
||
import warnings
|
||
from pathlib import Path
|
||
|
||
from modelspace.Essentials import SetsDict, select_elements, format_bytes
|
||
from modelspace.ModelPackage import ModelPackage, PackageInfo
|
||
from modelspace.ModelPackageCollection import ModelPackageCollection
|
||
from modules.civit.client import Client
|
||
|
||
|
||
class ModelPackageSubRepository:
    """Directory-backed set of model packages plus named package collections.

    Every sub-directory of ``path`` is loaded as a ``ModelPackage`` and every
    ``*_collection.json`` file in ``path`` is loaded as a
    ``ModelPackageCollection``. Three in-memory indexes are kept:

    * ``packages``      -- package uuid -> ``ModelPackage``
    * ``package_names`` -- set of the names of all loaded packages
    * ``resources``     -- ``SetsDict`` mapping a resource to the uuids of the
      packages that provide it
    """

    def __init__(self, path, seed):
        """Create the repository directory if needed and load its contents.

        Args:
            path: repository directory (``str`` or ``Path``).
            seed: stored as ``self.seed``; not used inside this class.
        """
        # Convert first: the original called ``path.mkdir`` on the raw
        # argument, which failed for plain string paths.
        self.path = Path(path)
        self.path.mkdir(exist_ok=True)
        self.seed = seed
        self.packages: dict[str, ModelPackage] | None = None
        self.package_names: set[str] | None = None
        self.resources: SetsDict | None = None
        self.collections: dict[str, ModelPackageCollection] | None = None
        self.reload()

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _reload_packages(self):
        """Rebuild ``packages`` and ``package_names`` from the sub-directories."""
        self.packages = {}
        try:
            dirs = [item.name for item in self.path.iterdir() if item.is_dir()]
        except OSError as e:
            print(f"Ошибка доступа к директории: {e}")
            dirs = []

        for d in dirs:
            package = ModelPackage.load(str(self.path / d))
            self.packages[package.uuid] = package

        self.package_names = {package.name for package in self.packages.values()}

    def _reload_resources(self):
        """Rebuild the resource -> package uuid index from ``packages``."""
        self.resources = SetsDict()
        for pkg_id, package in self.packages.items():
            for resource in package.provides:
                self.resources.add(resource, pkg_id)

    def _reload_collections(self):
        """Reload every ``*_collection.json`` file found in the repository.

        If the directory cannot be listed, previously loaded collections are
        kept; when nothing was loaded yet the index becomes an empty dict so
        later membership tests do not fail on ``None``.
        """
        try:
            filenames = [
                item.name
                for item in self.path.iterdir()
                if item.is_file() and item.name.endswith('_collection.json')
            ]
        except OSError as e:
            print(f"Ошибка доступа к директории: {e}")
            if self.collections is None:
                self.collections = {}
            return

        self.collections = {}
        for filename in filenames:
            collection = ModelPackageCollection(filename=str(self.path / filename))
            self.collections[collection.name] = collection

    def reload(self):
        """Reload packages, the resource index and collections from disk."""
        self._reload_packages()
        self._reload_resources()
        self._reload_collections()

    def _register_package(self, package) -> None:
        """Add one package to ``packages``, ``package_names`` and ``resources``."""
        self.packages[package.uuid] = package
        self.package_names.add(package.name)
        for resource in package.provides:
            self.resources.add(resource, package.uuid)

    # ------------------------------------------------------------------
    # Collections
    # ------------------------------------------------------------------

    def add_package_to_collection(self, pkg_name, collection_name, category: str = None, internal=True):
        """Add a package, by name, to a collection, creating the collection if absent.

        Args:
            pkg_name: name of a package known to this repository.
            collection_name: target collection; a new autosaving collection
                backed by ``<collection_name>_collection.json`` is created
                when it does not exist yet.
            category: forwarded to ``ModelPackageCollection.add_package``.
            internal: forwarded to ``ModelPackageCollection.add_package``.

        Raises:
            RuntimeWarning: ``pkg_name`` is not a package name (it is either
                a resource/collection name or unknown).
        """
        if pkg_name not in self.package_names:
            # NOTE(review): assumes ``SetsDict.keys`` is a container attribute,
            # not a method -- confirm against SetsDict.
            if pkg_name in self.resources.keys or pkg_name in self.collections:
                raise RuntimeWarning('Only packages allowed to add in collections')
            raise RuntimeWarning(f'Package {pkg_name} not found')

        if collection_name not in self.collections:
            self.collections[collection_name] = ModelPackageCollection(
                self.path / (collection_name + '_collection.json'), name=collection_name, autosave=True
            )
        self.collections[collection_name].add_package(pkg_name, category, internal)

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def resources_from_pkg_list(self, uuids: list[str]):
        """Return a ``SetsDict`` of resource -> uuids limited to the given packages.

        Raises:
            RuntimeError: one of ``uuids`` is not a known package (the
                original failed later with an ``AttributeError`` on ``None``).
        """
        resources = SetsDict()
        for pkg_id in uuids:
            package = self.package_by_id(pkg_id)
            for resource in package.provides:
                resources.add(resource, package.uuid)
        return resources

    @staticmethod
    def deps_from_pkg_list(packages: list[ModelPackage]) -> set[str]:
        """Return the union of ``dependencies`` over all given packages."""
        res = set()
        for package in packages:
            res.update(package.dependencies)
        return res

    def packages_by_resource(self, resource):
        """Return the set of packages that provide ``resource``.

        Raises:
            RuntimeError: no package provides the resource.
        """
        packages_ids = self.resources.by_key(resource)
        if not packages_ids:
            raise RuntimeError(f"{resource}: There are no packages in the repository that provide this resource")
        return {self.package_by_id(pkg_id) for pkg_id in packages_ids}

    def package_by_id(self, pkg_id):
        """Return the package with the given uuid.

        Raises:
            RuntimeError: the uuid is not present in ``packages``.
        """
        package = self.packages.get(pkg_id)
        if package is None:
            raise RuntimeError(f"{pkg_id}: Something went wrong while reading package info")
        return package

    # ------------------------------------------------------------------
    # Interactive creation
    # ------------------------------------------------------------------

    def add_package_interactive(self) -> ModelPackage:
        """Interactively create a new model package and add it to collections."""
        # Generate a fresh UUID; it is also the package directory name.
        package_uuid = str(uuid.uuid4())
        package_path = self.path / package_uuid

        # Run the interactive package creation, then load the result from disk.
        package = ModelPackage.interactive(str(package_path), package_uuid)
        loaded_package = ModelPackage.load(str(package_path))

        # Register the name and resources as well: add_package_to_collection
        # rejects any name missing from ``package_names``.
        self._register_package(loaded_package)

        self._add_package_to_collections_interactive(package)
        return package

    def _add_package_to_collections_interactive(self, package: ModelPackage):
        """Prompt for collections until a blank name is entered, adding ``package`` to each."""
        while True:
            print('Input collections, blank for stop')
            collection = input().strip()
            if collection == '':
                break

            external = input('External? (blank for no): ').strip()
            if external != '':
                self.add_package_to_collection(package.name, collection, category=None, internal=False)
                continue

            # A blank category is passed on as None.
            category = input('Category: ').strip() or None
            self.add_package_to_collection(package.name, collection, category=category, internal=True)

    # ------------------------------------------------------------------
    # Civit pulling
    # ------------------------------------------------------------------

    @staticmethod
    def _read_lines() -> list[str]:
        """Read stripped lines from stdin until a blank line is entered."""
        lines = []
        while True:
            line = input().strip()
            if not line:
                return lines
            lines.append(line)

    @staticmethod
    def _format_table_row(number, version, package_type, release_date, lineage, quant, size) -> str:
        """Format one fixed-width row of the package tables.

        ``None`` and empty cells are rendered as ``N/A`` so that a missing
        field cannot break the width format spec.
        """
        def cell(value):
            return 'N/A' if value is None or value == '' else str(value)

        return (
            f' {cell(number):<2} {cell(version):<10} {cell(package_type):<10} {cell(release_date):<25}'
            f' {cell(lineage):<10} {cell(quant):<5} {cell(size):<10} '
        )

    @classmethod
    def _table_header(cls) -> str:
        """Return the header row shared by all package tables."""
        return cls._format_table_row('N', 'version', 'type', 'release_date', 'lineage', 'quant', 'size')

    @classmethod
    def _candidate_row(cls, number, candidate: dict) -> str:
        """Format a pull-candidate dict as a table row."""
        size = candidate['size_bytes']
        return cls._format_table_row(
            number,
            candidate['version'],
            candidate['package_type'],
            candidate['release_date'],
            candidate['lineage'],
            candidate['quantisation'],
            format_bytes(size) if size is not None else None,
        )

    @staticmethod
    def _collect_civit_candidates(model_info: dict, model_id: int) -> list[dict]:
        """Build one pull-candidate dict per downloadable model file.

        Only versions whose ``availability`` is ``'Public'`` are considered,
        and files with a ``type`` other than ``'Model'`` are skipped (files
        without a ``type`` are kept).
        """
        package_type = (model_info.get('type', None) or '').lower()
        tags = model_info.get('tags', None)
        candidates = []

        for model_version in model_info.get('modelVersions', None) or []:
            if model_version.get('availability', None) != 'Public':
                continue

            version_id = model_version.get('id', None)
            images = model_version.get('images', None)
            image_urls = []
            if images and isinstance(images, list):
                image_urls = [i.get('url', None) for i in images if i.get('url', None) is not None]

            for file in model_version.get('files', None) or []:
                file_type = file.get('type', None)
                if file_type and file_type != 'Model':
                    continue

                # The value arrives under the key 'sizeKB' and is scaled by
                # 1024; integers are converted too (the original only
                # converted floats, leaving integer values unscaled).
                size_bytes = file.get('sizeKB', None)
                if size_bytes and isinstance(size_bytes, (int, float)):
                    size_bytes = int(size_bytes * 1024)

                metadata = file.get('metadata', None)
                quantisation = metadata.get('fp', None) if metadata else None

                file_id = file.get('id', None)
                name = f'civit-{model_id}-{version_id}-{file_id}'

                candidates.append({
                    'uuid': name,
                    'name': name,
                    'provides': [f'civit-{model_id}-{version_id}', f'civit-{model_id}'],
                    'version_id': version_id,
                    'file_id': file_id,
                    'package_type': package_type,
                    'tags': tags,
                    'version': model_version.get('name', None),
                    'release_date': model_version.get('publishedAt', None),
                    'lineage': model_version.get('baseModel', None),
                    'images': image_urls,
                    'size_bytes': size_bytes,
                    'quantisation': quantisation or '',
                    'url': file.get('downloadUrl', None),
                    'filename': file.get('name', None),
                    'model_info': model_info.copy(),
                })

        return candidates

    def _pull_candidate(self, client, candidate: dict, extra_resources, extra_dependencies, extra_tags):
        """Write package metadata, download model and thumbnails, return the package.

        Raises:
            RuntimeError: a ``package.json`` already exists for this candidate.
        """
        package_path = self.path / candidate['uuid']
        info_path = package_path / "package.json"
        if info_path.exists():
            raise RuntimeError("package exists!")

        package_info = PackageInfo(str(info_path))
        package_info.uuid = candidate['uuid']
        package_info.name = candidate['name']

        # TODO: list of additional resources and dependencies per lineage
        # TODO add deps and resources based on lineage (use civit lineages)
        package_info.resources = [*candidate['provides'], *extra_resources]
        package_info.dependencies = list(extra_dependencies)
        package_info.tags = [*(candidate['tags'] or []), *extra_tags]
        package_info.lineage = (candidate['lineage'] or '').lower()
        # TODO cast package types (diffusion_model or checkpoint) (use civit lineages)
        package_info.package_type = candidate['package_type'].lower()
        package_info.version = candidate['version']
        package_info.release_date = candidate['release_date']
        package_info.size_bytes = candidate['size_bytes']
        # NOTE(review): the summary in pull_civit_package originally read
        # ``info.quantization`` (with "z") -- confirm which spelling
        # PackageInfo actually declares.
        package_info.quantisation = candidate['quantisation']
        package_info.save()

        files_dir = package_path / 'files'
        os.makedirs(files_dir)
        # ensure_ascii=False emits raw non-ASCII text, so pin the encoding
        # instead of relying on the platform default.
        with open(package_path / 'model_info.json', 'w', encoding='utf-8') as f:
            json.dump(candidate['model_info'], f, indent=2, ensure_ascii=False)

        print('Pulling model...')
        client.download_file(url=candidate['url'], path=files_dir / candidate['filename'])

        images = candidate['images']
        if images:
            # The main thumbnail is stored next to the model file under the
            # model's base name with the image's extension.
            print('Pulling main thumbnail...')
            preview = images[0]
            image_file = str(preview).rsplit('/', maxsplit=1)[-1]
            image_extension = image_file.rsplit('.', maxsplit=1)[-1]
            ckpt_name = str(candidate['filename']).rsplit('.', maxsplit=1)[0]
            client.download_file(url=preview, path=files_dir / (ckpt_name + '.' + image_extension))

        images_dir = package_path / 'images'
        os.makedirs(images_dir)
        print('Pulling thumbnails...')
        for image in images:
            image_file = str(image).rsplit('/', maxsplit=1)[-1]
            client.download_file(url=image, path=images_dir / image_file)

        return ModelPackage(package_path, [], package_info)

    def pull_civit_package(self, client: Client, model_id: int, version_id: int = None, file_id: int = None):
        """Interactively pull one or more files of a Civit model as packages.

        Args:
            client: Civit client; ``get_model_raw`` and ``download_file`` are used.
            model_id: Civit model id.
            version_id: if given, only files of this model version are offered.
            file_id: if given, only this file is offered.

        Every pulled package is registered in this repository, added to the
        ``civit`` collection and, optionally, to one user-chosen collection.
        Returns ``None``; a warning is issued when the model or a matching
        file cannot be found.
        """
        model_info = client.get_model_raw(model_id)
        model_versions = model_info.get('modelVersions', None) if model_info else None
        if not model_versions:
            warnings.warn(f'Unable to find model {model_id}')
            return

        print('Model name:', model_info.get('name', None))

        candidates = self._collect_civit_candidates(model_info, model_id)
        if version_id:
            candidates = [c for c in candidates if c['version_id'] == version_id]
        if file_id:
            candidates = [c for c in candidates if c['file_id'] == file_id]

        if not candidates:
            warnings.warn(f'Pull candidate not found for model_id:{model_id} and version_id:{version_id} and file_id:{file_id}')
            return

        # Split the matches into already-present and still-pullable packages.
        already_pulled = [c for c in candidates if c['name'] in self.package_names]
        available_to_pull = [c for c in candidates if c['name'] not in self.package_names]

        if already_pulled:
            print('Already pulled packages:')
            print(self._table_header())
            for candidate in already_pulled:
                print(self._candidate_row('N', candidate))

        if not available_to_pull:
            print('All available packages already pulled')
            return

        print('Available packages:')
        print(self._table_header())
        for index, candidate in enumerate(available_to_pull):
            print(self._candidate_row(index, candidate))

        if len(available_to_pull) > 1:
            # The numbers shown above index ``available_to_pull``, so the
            # selection must be made from that same list.
            to_pull = select_elements(available_to_pull, input("Your choice: "))
        else:
            to_pull = available_to_pull

        # Extra dependencies
        print("Зависимости (введите по одной, пустая строка для завершения):")
        additional_dependencies = self._read_lines()

        # Extra resources
        print("Ресурсы (введите по одному, пустая строка для завершения):")
        additional_resources = self._read_lines()

        # Extra tags
        print("Теги (введите по одному, пустая строка для завершения):")
        additional_tags = self._read_lines()

        # One optional collection shared by all selected packages.
        category = None
        internal = True
        print('One collection for all selected packages, blank for None')
        collection = input().strip()
        if collection:
            if input('External? (blank for no): ').strip() != '':
                internal = False
            else:
                category = str(input('Category: ')).strip() or None

        pulled: list[ModelPackage] = []
        for candidate in to_pull:
            package = self._pull_candidate(
                client, candidate, additional_resources, additional_dependencies, additional_tags
            )
            self._register_package(package)

            self.add_package_to_collection(package.name, 'civit', internal=True)
            if collection:
                self.add_package_to_collection(package.name, collection, category, internal=internal)
            pulled.append(package)

        for package in pulled:
            info = package.info
            # Tolerate either attribute spelling (see NOTE in _pull_candidate).
            quant = getattr(info, 'quantization', getattr(info, 'quantisation', None))
            size = info.size_bytes
            print('Collections for package:')
            print(self._table_header())
            print(self._format_table_row(
                'N', info.version, info.package_type, info.release_date, info.lineage, quant,
                format_bytes(size) if size is not None else None,
            ))

            self._add_package_to_collections_interactive(package)
|
||
|
||
|
||
|