Add civit model pull ability

This commit is contained in:
2025-09-20 12:12:56 +07:00
parent 65bffc38eb
commit 991d655756
9 changed files with 632 additions and 13 deletions

View File

@@ -49,4 +49,61 @@ class SetsDict:
return self._data.get(key, None)
@property
def keys(self): return self._data.keys()
def keys(self): return self._data.keys()
def select_elements(lst, selection_string):
"""
Выбирает элементы из списка согласно строке выбора
Args:
lst: Исходный список
selection_string: Строка вида "1 2 4-6 all"
Returns:
Новый список с выбранными элементами, отсортированными по номерам
"""
selection_string = selection_string.strip()
if not selection_string.strip():
return []
if selection_string == "all":
return lst.copy()
selected_indices = set()
parts = selection_string.split()
for part in parts:
if '-' in part:
# Обработка диапазона
start, end = map(int, part.split('-'))
# Обработка диапазона в любом направлении
if start <= end:
selected_indices.update(range(start, end + 1))
else:
selected_indices.update(range(start, end - 1, -1))
else:
# Обработка отдельного элемента
selected_indices.add(int(part))
# Преобразуем в список и сортируем по номерам
sorted_indices = sorted(selected_indices)
# Выбираем элементы
result = []
for idx in sorted_indices:
if 0 <= idx < len(lst):
result.append(lst[idx])
return result
def format_bytes(bytes_size):
"""Convert bytes to human readable format"""
if bytes_size < 1024:
return f"{bytes_size} B"
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if bytes_size < 1024.0:
return f"{bytes_size:.1f} {unit}"
bytes_size /= 1024.0
return f"{bytes_size:.1f} PB"

View File

@@ -22,12 +22,15 @@ class PackageInfo(Config):
quantization: str = "" # fp8, bf16
dependencies: List[str] = None
resources: List[str] = None
tags: List[str] = None
def __post_init__(self):
if self.dependencies is None:
self.dependencies = []
if self.resources is None:
self.resources = []
if self.tags is None:
self.tags = []
super().__post_init__()
@@ -131,6 +134,15 @@ class ModelPackage:
resources.append(resource)
package_info.resources = resources
print("Теги (введите по одному, пустая строка для завершения):")
tags = []
while True:
tag = input().strip()
if not tag:
break
tags.append(resource)
package_info.tags = tags
# Генерируем UUID случайным образом (не запрашиваем у пользователя)
package_info.uuid = pkg_uuid
if not package_info.uuid:
@@ -206,6 +218,7 @@ class ModelPackage:
provides_list = self.info.resources.copy() # Возвращаем копию
if self.info.name: # Добавляем имя пакета, если оно есть
provides_list.append(self.info.name)
provides_list.extend(self.info.tags)
return provides_list
@classmethod

View File

@@ -0,0 +1,50 @@
from dataclasses import dataclass
from pathlib import Path
from pythonapp.Libs.ConfigDataClass import Config
@dataclass
class ModelPackageCollection(Config):
name: str = None
external_packages: list[str] = None
unsorted_packages: list[str] = None
categorized_packages: dict[str, list[str]] = None
def __post_init__(self):
if not self.external_packages: self.external_packages = list()
if not self.unsorted_packages: self.unsorted_packages = list()
if not self.categorized_packages: self.categorized_packages = dict()
super().__post_init__()
if not self.name: raise ValueError('ModelPackageCollection(): name must be specified')
def add_package(self, pkg_name, category: str = None, internal=True):
if not internal and pkg_name not in self.external_packages: self.external_packages.append(pkg_name)
elif not category and pkg_name not in self.unsorted_packages: self.unsorted_packages.append(pkg_name)
else:
if category not in self.categorized_packages: self.categorized_packages[category] = list()
if pkg_name not in self.categorized_packages[category]: self.categorized_packages[category].append(pkg_name)
self.save()
def get_path(self, pkg_name) -> Path:
if pkg_name in self.external_packages: return Path('')
elif pkg_name in self.unsorted_packages: return Path(self.name)
else:
for category in self.categorized_packages:
if pkg_name in self.categorized_packages[category]: return Path(self.name) / category
raise FileNotFoundError(f'package {pkg_name} not in collection {self.name}')
@property
def paths_dict(self) -> dict[str, Path]:
result: dict[str, Path] = dict()
for category in self.categorized_packages:
for pkg_name in list(self.categorized_packages[category]):
result[pkg_name] = Path(self.name) / category
for pkg_name in list(self.unsorted_packages): result[pkg_name] = Path(self.name)
for pkg_name in list(self.external_packages): result[pkg_name] = Path('')
return result

View File

@@ -1,8 +1,13 @@
import json
import os
import uuid
import warnings
from pathlib import Path
from modelspace.Essentials import SetsDict
from modelspace.ModelPackage import ModelPackage
from modelspace.Essentials import SetsDict, select_elements, format_bytes
from modelspace.ModelPackage import ModelPackage, PackageInfo
from modelspace.ModelPackageCollection import ModelPackageCollection
from modules.civit.client import Client
class ModelPackageSubRepository:
@@ -11,10 +16,11 @@ class ModelPackageSubRepository:
self.path = Path(path)
self.seed = seed
self.packages: dict[str, ModelPackage] | None = None
self.package_names: set[str] | None = None
self.resources: SetsDict | None = None
self.collections: dict[str, ModelPackageCollection] | None = None
self.reload()
# Completed
def _reload_packages(self):
self.packages = dict()
try:
@@ -27,7 +33,8 @@ class ModelPackageSubRepository:
package = ModelPackage.load(str(self.path / d))
self.packages[package.uuid] = package
# Completed
self.package_names = {p.name for id, p in self.packages.items()}
def _reload_resources(self):
self.resources = SetsDict()
@@ -35,12 +42,35 @@ class ModelPackageSubRepository:
for resource in package.provides:
self.resources.add(resource, pkg_id)
def _reload_collections(self):
try:
filenames = [item.name for item in self.path.iterdir() if item.is_file() and item.name.endswith('_collection.json')]
except OSError as e:
print(f"Ошибка доступа к директории: {e}")
return
self.collections = dict()
for filename in filenames:
collection = ModelPackageCollection(filename=str(Path(self.path) / filename))
self.collections[collection.name] = collection
def reload(self):
self._reload_packages()
self._reload_resources()
self._reload_collections()
def add_package_to_collection(self, pkg_name, collection_name, category: str = None, internal=True):
if pkg_name not in self.package_names:
if pkg_name in self.resources.keys or pkg_name in self.collections: raise RuntimeWarning('Only packages allowed to add in collections')
else: raise RuntimeWarning(f'Package {pkg_name} not found')
if collection_name not in self.collections:
self.collections[collection_name] = ModelPackageCollection(
self.path / (collection_name + '_collection.json'), name=collection_name, autosave=True
)
self.collections[collection_name].add_package(pkg_name, category, internal)
# debugged
def resources_from_pkg_list(self, uuids: list[str]):
selected_packages = []
for pkg_id in uuids:
@@ -60,7 +90,6 @@ class ModelPackageSubRepository:
for package in packages: res = res | set(package.dependencies)
return res
# debugged
def packages_by_resource(self, resource):
packages_ids = self.resources.by_key(resource)
@@ -73,7 +102,6 @@ class ModelPackageSubRepository:
for pkg_id in packages_ids: packages.add(self.package_by_id(pkg_id))
return packages
# debugged
def package_by_id(self, pkg_id):
package = self.packages.get(pkg_id, None)
if not package: raise RuntimeError(f"{pkg_id}: Something went wrong while reading package info")
@@ -92,4 +120,247 @@ class ModelPackageSubRepository:
package = ModelPackage.interactive(str(package_path), package_uuid)
loaded_package = ModelPackage.load(str(package_path))
self.packages[loaded_package.uuid] = loaded_package
return package
# Добавляем пакет в коллекции
self._add_package_to_collections_interactive(package)
return package
def _add_package_to_collections_interactive(self, package: ModelPackage):
while True:
print('Input collections, blank for stop')
collection = input().strip()
if collection == '': break
external = input('External? (blank for no): ').strip()
if external != '':
self.add_package_to_collection(package.name, collection, category=None, internal=False)
continue
category = input('Category: ').strip()
if category == '':
self.add_package_to_collection(package.name, collection, category=None, internal=True)
continue
else:
self.add_package_to_collection(package.name, collection, category, internal=True)
continue
def pull_civit_package(self, client: Client, model_id: int, version_id: int = None, file_id: int = None):
model_info = client.get_model_raw(model_id)
model_versions = model_info.get('modelVersions', None)
if not model_versions:
warnings.warn(f'Unable to find model {model_id}')
return
pull_candidates = list()
print('Model name:', model_info.get('name', None))
ic_package_type = model_info.get('type', None)
ic_tags = model_info.get('tags', None)
for model_version in model_versions:
if not model_version.get('availability', None) or model_version.get('availability', None) != 'Public': continue
ic_version_id = model_version.get('id', None)
ic_version = model_version.get('name', None)
ic_release_date = model_version.get('publishedAt', None)
ic_lineage = model_version.get('baseModel', None)
ic_images = None
images = model_version.get('images', None)
if images and isinstance(images, list): ic_images = [i.get('url', None) for i in images if i.get('url', None) is not None]
ic_provides = [f'civit-{model_id}-{ic_version_id}', f'civit-{model_id}'].copy()
for file in model_version.get('files', list()):
ic_size_bytes = file.get('sizeKB', None)
if ic_size_bytes and isinstance(ic_size_bytes, float):
ic_size_bytes = int(ic_size_bytes * 1024)
metadata = file.get('metadata', None)
ic_quantisation = None
if metadata:
ic_quantisation = metadata.get('fp', None)
ic_file_id = file.get('id', None)
ic_filename = file.get('name', None)
if file.get('type', None) and file.get('type', None) != 'Model':
continue
ic_url = file.get('downloadUrl', None)
ic_model_info = model_info.copy()
ic_name = f'civit-{model_id}-{ic_version_id}-{ic_file_id}'
ic_uuid = ic_name
pull_candidates.append({
'uuid': ic_uuid,
'name': ic_name,
'provides': ic_provides,
'version_id': ic_version_id,
'file_id': ic_file_id,
'package_type': ic_package_type,
'tags': ic_tags,
'version': ic_version,
'release_date': ic_release_date,
'lineage': ic_lineage,
'images': ic_images,
'size_bytes': ic_size_bytes,
'quantisation': ic_quantisation,
'url': ic_url,
'filename': ic_filename,
'model_info': ic_model_info,
})
try:
del file, ic_url, ic_package_type, ic_images, ic_release_date, ic_tags, ic_filename, ic_lineage, model_info
del ic_model_info, ic_quantisation, ic_size_bytes, ic_version, images, metadata, model_version, model_versions
del ic_file_id, ic_version_id, ic_uuid, ic_name, ic_provides
except Exception:
pass
# check already pulled packages
already_pulled = list()
available_to_pull = list()
for candidate in pull_candidates:
if candidate['name'] in self.package_names: already_pulled.append(candidate)
else: available_to_pull.append(candidate)
if version_id: available_to_pull = [p for p in available_to_pull if p['version_id'] == version_id]
if file_id: available_to_pull = [p for p in available_to_pull if p['file_id'] == file_id]
if len(available_to_pull) == 0:
warnings.warn(f'Pull candidate not found for model_id:{model_id} and version_id:{version_id} and file_id:{file_id}')
return
# selection output
if len(already_pulled) > 0:
print('Already pulled packages:')
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
for candidate in already_pulled:
print(
f' {'N':<{2}} {candidate['version']:<{10}} {candidate['package_type']:<{10}} {candidate['release_date']:<{25}}'
f' {candidate['lineage']:<{10}} {candidate['quantisation']:<{5}} {format_bytes(candidate['size_bytes']):<{10}} ')
if len(available_to_pull) == 0:
print('All available packages already pulled')
return
else:
print('Available packages:')
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
for i in range(len(available_to_pull)):
candidate = available_to_pull[i]
print(f' {i:<{2}} {candidate['version']:<{10}} {candidate['package_type']:<{10}} {candidate['release_date']:<{25}}'
f' {candidate['lineage']:<{10}} {candidate['quantisation']:<{5}} {format_bytes(candidate['size_bytes']):<{10}} ')
if len(available_to_pull) > 1: to_pull = select_elements(pull_candidates, input("Your choice: "))
else: to_pull = available_to_pull
# Ввод зависимостей
print("Зависимости (введите по одной, пустая строка для завершения):")
additional_dependencies = []
while True:
dep = input().strip()
if not dep:
break
additional_dependencies.append(dep)
# Ввод ресурсов
print("Ресурсы (введите по одному, пустая строка для завершения):")
additional_resources = []
while True:
resource = input().strip()
if not resource:
break
additional_resources.append(resource)
print("Теги (введите по одному, пустая строка для завершения):")
additional_tags = []
while True:
tag = input().strip()
if not tag:
break
additional_tags.append(resource)
while True:
print('One collection for all selected packages, blank for None')
collection = input().strip()
if collection == '': break
external = input('External? (blank for no): ').strip()
if external != '':
category: str | None = None
internal = False
break
category = str(input('Category: ')).strip()
if category == '':
category: str | None = None
internal = True
break
else:
internal = True
break
pulled = list()
for candidate in to_pull:
package_path = self.path / candidate['uuid']
if os.path.exists(str(Path(package_path) / "package.json")): raise RuntimeError("package exists!")
package_info = PackageInfo(str(Path(package_path) / "package.json"))
package_info.uuid = candidate['uuid']
package_info.name = candidate['name']
# TODO список дополнительных ресурсов, зависимостей, по линейке
# TODO add deps and resources based on lineage (use civit lineages)
package_info.resources = candidate['provides'].copy()
package_info.resources.extend(additional_resources)
package_info.dependencies = additional_dependencies.copy()
package_info.tags = candidate['tags'].copy()
package_info.tags.extend(additional_tags)
package_info.lineage = candidate['lineage'].lower()
# TODO cast package types (diffusion_model or checkpoint) (use civit lineages)
package_info.package_type = candidate['package_type'].lower()
package_info.version = candidate['version']
package_info.release_date = candidate['release_date']
package_info.size_bytes = candidate['size_bytes']
package_info.quantisation = candidate['quantisation']
package_info.save()
os.makedirs(package_path / 'files')
with open(package_path / 'model_info.json', 'w') as f:
json.dump(candidate['model_info'], f, indent=2, ensure_ascii=False)
print('Pulling model...')
client.download_file(url=candidate['url'], path=package_path / 'files' / candidate['filename'])
print('Pulling main thumbnail...')
preview = candidate['images'][0]
dir, file = str(preview).rsplit('/', maxsplit=1)
image_name, image_extension = str(file).rsplit('.', maxsplit=1)
ckpt = candidate['filename']
ckpt_name, ckpt_extension = str(ckpt).rsplit('.', maxsplit=1)
client.download_file(url=preview, path=package_path / 'files' / (ckpt_name + image_extension))
os.makedirs(package_path / 'images')
print('Pulling thumbnails...')
for image in candidate['images']:
dir, file = str(image).rsplit('/', maxsplit=1)
client.download_file(url=image, path=package_path / 'images' / file)
package = ModelPackage(package_path, [], package_info)
print('Collections for package:')
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
print(
f' {'N':<{2}} {candidate['version']:<{10}} {candidate['package_type']:<{10}} {candidate['release_date']:<{25}}'
f' {candidate['lineage']:<{10}} {candidate['quantisation']:<{5}} {format_bytes(candidate['size_bytes']):<{10}} ')
self.packages[package.uuid] = package
self.add_package_to_collection(package.name, 'civit', internal=True)
if collection: self.add_package_to_collection(package.name, collection, category, internal=internal)
pulled.append(package)
for package in pulled: self._add_package_to_collections_interactive(package)