Add civit model pull ability
This commit is contained in:
1
main.py
1
main.py
@@ -1,4 +1,5 @@
|
|||||||
from shell.Interactive import Interactive
|
from shell.Interactive import Interactive
|
||||||
|
|
||||||
|
|
||||||
Interactive().start()
|
Interactive().start()
|
||||||
|
|
||||||
|
|||||||
@@ -50,3 +50,60 @@ class SetsDict:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self): return self._data.keys()
|
def keys(self): return self._data.keys()
|
||||||
|
|
||||||
|
def select_elements(lst, selection_string):
|
||||||
|
"""
|
||||||
|
Выбирает элементы из списка согласно строке выбора
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lst: Исходный список
|
||||||
|
selection_string: Строка вида "1 2 4-6 all"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Новый список с выбранными элементами, отсортированными по номерам
|
||||||
|
"""
|
||||||
|
selection_string = selection_string.strip()
|
||||||
|
if not selection_string.strip():
|
||||||
|
return []
|
||||||
|
|
||||||
|
if selection_string == "all":
|
||||||
|
return lst.copy()
|
||||||
|
|
||||||
|
selected_indices = set()
|
||||||
|
parts = selection_string.split()
|
||||||
|
|
||||||
|
for part in parts:
|
||||||
|
if '-' in part:
|
||||||
|
# Обработка диапазона
|
||||||
|
start, end = map(int, part.split('-'))
|
||||||
|
# Обработка диапазона в любом направлении
|
||||||
|
if start <= end:
|
||||||
|
selected_indices.update(range(start, end + 1))
|
||||||
|
else:
|
||||||
|
selected_indices.update(range(start, end - 1, -1))
|
||||||
|
else:
|
||||||
|
# Обработка отдельного элемента
|
||||||
|
selected_indices.add(int(part))
|
||||||
|
|
||||||
|
# Преобразуем в список и сортируем по номерам
|
||||||
|
sorted_indices = sorted(selected_indices)
|
||||||
|
|
||||||
|
# Выбираем элементы
|
||||||
|
result = []
|
||||||
|
for idx in sorted_indices:
|
||||||
|
if 0 <= idx < len(lst):
|
||||||
|
result.append(lst[idx])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def format_bytes(bytes_size):
|
||||||
|
"""Convert bytes to human readable format"""
|
||||||
|
if bytes_size < 1024:
|
||||||
|
return f"{bytes_size} B"
|
||||||
|
|
||||||
|
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
|
||||||
|
if bytes_size < 1024.0:
|
||||||
|
return f"{bytes_size:.1f} {unit}"
|
||||||
|
bytes_size /= 1024.0
|
||||||
|
|
||||||
|
return f"{bytes_size:.1f} PB"
|
||||||
|
|||||||
@@ -22,12 +22,15 @@ class PackageInfo(Config):
|
|||||||
quantization: str = "" # fp8, bf16
|
quantization: str = "" # fp8, bf16
|
||||||
dependencies: List[str] = None
|
dependencies: List[str] = None
|
||||||
resources: List[str] = None
|
resources: List[str] = None
|
||||||
|
tags: List[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dependencies is None:
|
if self.dependencies is None:
|
||||||
self.dependencies = []
|
self.dependencies = []
|
||||||
if self.resources is None:
|
if self.resources is None:
|
||||||
self.resources = []
|
self.resources = []
|
||||||
|
if self.tags is None:
|
||||||
|
self.tags = []
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
|
|
||||||
@@ -131,6 +134,15 @@ class ModelPackage:
|
|||||||
resources.append(resource)
|
resources.append(resource)
|
||||||
package_info.resources = resources
|
package_info.resources = resources
|
||||||
|
|
||||||
|
print("Теги (введите по одному, пустая строка для завершения):")
|
||||||
|
tags = []
|
||||||
|
while True:
|
||||||
|
tag = input().strip()
|
||||||
|
if not tag:
|
||||||
|
break
|
||||||
|
tags.append(resource)
|
||||||
|
package_info.tags = tags
|
||||||
|
|
||||||
# Генерируем UUID случайным образом (не запрашиваем у пользователя)
|
# Генерируем UUID случайным образом (не запрашиваем у пользователя)
|
||||||
package_info.uuid = pkg_uuid
|
package_info.uuid = pkg_uuid
|
||||||
if not package_info.uuid:
|
if not package_info.uuid:
|
||||||
@@ -206,6 +218,7 @@ class ModelPackage:
|
|||||||
provides_list = self.info.resources.copy() # Возвращаем копию
|
provides_list = self.info.resources.copy() # Возвращаем копию
|
||||||
if self.info.name: # Добавляем имя пакета, если оно есть
|
if self.info.name: # Добавляем имя пакета, если оно есть
|
||||||
provides_list.append(self.info.name)
|
provides_list.append(self.info.name)
|
||||||
|
provides_list.extend(self.info.tags)
|
||||||
return provides_list
|
return provides_list
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
50
modelspace/ModelPackageCollection.py
Normal file
50
modelspace/ModelPackageCollection.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pythonapp.Libs.ConfigDataClass import Config
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelPackageCollection(Config):
|
||||||
|
name: str = None
|
||||||
|
external_packages: list[str] = None
|
||||||
|
unsorted_packages: list[str] = None
|
||||||
|
categorized_packages: dict[str, list[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.external_packages: self.external_packages = list()
|
||||||
|
if not self.unsorted_packages: self.unsorted_packages = list()
|
||||||
|
if not self.categorized_packages: self.categorized_packages = dict()
|
||||||
|
super().__post_init__()
|
||||||
|
if not self.name: raise ValueError('ModelPackageCollection(): name must be specified')
|
||||||
|
|
||||||
|
def add_package(self, pkg_name, category: str = None, internal=True):
|
||||||
|
if not internal and pkg_name not in self.external_packages: self.external_packages.append(pkg_name)
|
||||||
|
elif not category and pkg_name not in self.unsorted_packages: self.unsorted_packages.append(pkg_name)
|
||||||
|
else:
|
||||||
|
if category not in self.categorized_packages: self.categorized_packages[category] = list()
|
||||||
|
if pkg_name not in self.categorized_packages[category]: self.categorized_packages[category].append(pkg_name)
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
def get_path(self, pkg_name) -> Path:
|
||||||
|
if pkg_name in self.external_packages: return Path('')
|
||||||
|
elif pkg_name in self.unsorted_packages: return Path(self.name)
|
||||||
|
else:
|
||||||
|
for category in self.categorized_packages:
|
||||||
|
if pkg_name in self.categorized_packages[category]: return Path(self.name) / category
|
||||||
|
|
||||||
|
raise FileNotFoundError(f'package {pkg_name} not in collection {self.name}')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def paths_dict(self) -> dict[str, Path]:
|
||||||
|
result: dict[str, Path] = dict()
|
||||||
|
for category in self.categorized_packages:
|
||||||
|
for pkg_name in list(self.categorized_packages[category]):
|
||||||
|
result[pkg_name] = Path(self.name) / category
|
||||||
|
|
||||||
|
for pkg_name in list(self.unsorted_packages): result[pkg_name] = Path(self.name)
|
||||||
|
for pkg_name in list(self.external_packages): result[pkg_name] = Path('')
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
@@ -1,8 +1,13 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from modelspace.Essentials import SetsDict
|
from modelspace.Essentials import SetsDict, select_elements, format_bytes
|
||||||
from modelspace.ModelPackage import ModelPackage
|
from modelspace.ModelPackage import ModelPackage, PackageInfo
|
||||||
|
from modelspace.ModelPackageCollection import ModelPackageCollection
|
||||||
|
from modules.civit.client import Client
|
||||||
|
|
||||||
|
|
||||||
class ModelPackageSubRepository:
|
class ModelPackageSubRepository:
|
||||||
@@ -11,10 +16,11 @@ class ModelPackageSubRepository:
|
|||||||
self.path = Path(path)
|
self.path = Path(path)
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.packages: dict[str, ModelPackage] | None = None
|
self.packages: dict[str, ModelPackage] | None = None
|
||||||
|
self.package_names: set[str] | None = None
|
||||||
self.resources: SetsDict | None = None
|
self.resources: SetsDict | None = None
|
||||||
|
self.collections: dict[str, ModelPackageCollection] | None = None
|
||||||
self.reload()
|
self.reload()
|
||||||
|
|
||||||
# Completed
|
|
||||||
def _reload_packages(self):
|
def _reload_packages(self):
|
||||||
self.packages = dict()
|
self.packages = dict()
|
||||||
try:
|
try:
|
||||||
@@ -27,7 +33,8 @@ class ModelPackageSubRepository:
|
|||||||
package = ModelPackage.load(str(self.path / d))
|
package = ModelPackage.load(str(self.path / d))
|
||||||
self.packages[package.uuid] = package
|
self.packages[package.uuid] = package
|
||||||
|
|
||||||
# Completed
|
self.package_names = {p.name for id, p in self.packages.items()}
|
||||||
|
|
||||||
def _reload_resources(self):
|
def _reload_resources(self):
|
||||||
self.resources = SetsDict()
|
self.resources = SetsDict()
|
||||||
|
|
||||||
@@ -35,12 +42,35 @@ class ModelPackageSubRepository:
|
|||||||
for resource in package.provides:
|
for resource in package.provides:
|
||||||
self.resources.add(resource, pkg_id)
|
self.resources.add(resource, pkg_id)
|
||||||
|
|
||||||
|
def _reload_collections(self):
|
||||||
|
try:
|
||||||
|
filenames = [item.name for item in self.path.iterdir() if item.is_file() and item.name.endswith('_collection.json')]
|
||||||
|
except OSError as e:
|
||||||
|
print(f"Ошибка доступа к директории: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.collections = dict()
|
||||||
|
for filename in filenames:
|
||||||
|
collection = ModelPackageCollection(filename=str(Path(self.path) / filename))
|
||||||
|
self.collections[collection.name] = collection
|
||||||
|
|
||||||
def reload(self):
|
def reload(self):
|
||||||
self._reload_packages()
|
self._reload_packages()
|
||||||
self._reload_resources()
|
self._reload_resources()
|
||||||
|
self._reload_collections()
|
||||||
|
|
||||||
|
def add_package_to_collection(self, pkg_name, collection_name, category: str = None, internal=True):
|
||||||
|
if pkg_name not in self.package_names:
|
||||||
|
if pkg_name in self.resources.keys or pkg_name in self.collections: raise RuntimeWarning('Only packages allowed to add in collections')
|
||||||
|
else: raise RuntimeWarning(f'Package {pkg_name} not found')
|
||||||
|
if collection_name not in self.collections:
|
||||||
|
self.collections[collection_name] = ModelPackageCollection(
|
||||||
|
self.path / (collection_name + '_collection.json'), name=collection_name, autosave=True
|
||||||
|
)
|
||||||
|
self.collections[collection_name].add_package(pkg_name, category, internal)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# debugged
|
|
||||||
def resources_from_pkg_list(self, uuids: list[str]):
|
def resources_from_pkg_list(self, uuids: list[str]):
|
||||||
selected_packages = []
|
selected_packages = []
|
||||||
for pkg_id in uuids:
|
for pkg_id in uuids:
|
||||||
@@ -60,7 +90,6 @@ class ModelPackageSubRepository:
|
|||||||
for package in packages: res = res | set(package.dependencies)
|
for package in packages: res = res | set(package.dependencies)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# debugged
|
|
||||||
def packages_by_resource(self, resource):
|
def packages_by_resource(self, resource):
|
||||||
packages_ids = self.resources.by_key(resource)
|
packages_ids = self.resources.by_key(resource)
|
||||||
|
|
||||||
@@ -73,7 +102,6 @@ class ModelPackageSubRepository:
|
|||||||
for pkg_id in packages_ids: packages.add(self.package_by_id(pkg_id))
|
for pkg_id in packages_ids: packages.add(self.package_by_id(pkg_id))
|
||||||
return packages
|
return packages
|
||||||
|
|
||||||
# debugged
|
|
||||||
def package_by_id(self, pkg_id):
|
def package_by_id(self, pkg_id):
|
||||||
package = self.packages.get(pkg_id, None)
|
package = self.packages.get(pkg_id, None)
|
||||||
if not package: raise RuntimeError(f"{pkg_id}: Something went wrong while reading package info")
|
if not package: raise RuntimeError(f"{pkg_id}: Something went wrong while reading package info")
|
||||||
@@ -92,4 +120,247 @@ class ModelPackageSubRepository:
|
|||||||
package = ModelPackage.interactive(str(package_path), package_uuid)
|
package = ModelPackage.interactive(str(package_path), package_uuid)
|
||||||
loaded_package = ModelPackage.load(str(package_path))
|
loaded_package = ModelPackage.load(str(package_path))
|
||||||
self.packages[loaded_package.uuid] = loaded_package
|
self.packages[loaded_package.uuid] = loaded_package
|
||||||
|
|
||||||
|
# Добавляем пакет в коллекции
|
||||||
|
self._add_package_to_collections_interactive(package)
|
||||||
|
|
||||||
return package
|
return package
|
||||||
|
|
||||||
|
def _add_package_to_collections_interactive(self, package: ModelPackage):
|
||||||
|
while True:
|
||||||
|
print('Input collections, blank for stop')
|
||||||
|
collection = input().strip()
|
||||||
|
if collection == '': break
|
||||||
|
external = input('External? (blank for no): ').strip()
|
||||||
|
if external != '':
|
||||||
|
self.add_package_to_collection(package.name, collection, category=None, internal=False)
|
||||||
|
continue
|
||||||
|
category = input('Category: ').strip()
|
||||||
|
if category == '':
|
||||||
|
self.add_package_to_collection(package.name, collection, category=None, internal=True)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
self.add_package_to_collection(package.name, collection, category, internal=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
def pull_civit_package(self, client: Client, model_id: int, version_id: int = None, file_id: int = None):
|
||||||
|
model_info = client.get_model_raw(model_id)
|
||||||
|
model_versions = model_info.get('modelVersions', None)
|
||||||
|
if not model_versions:
|
||||||
|
warnings.warn(f'Unable to find model {model_id}')
|
||||||
|
return
|
||||||
|
|
||||||
|
pull_candidates = list()
|
||||||
|
print('Model name:', model_info.get('name', None))
|
||||||
|
|
||||||
|
ic_package_type = model_info.get('type', None)
|
||||||
|
ic_tags = model_info.get('tags', None)
|
||||||
|
|
||||||
|
for model_version in model_versions:
|
||||||
|
if not model_version.get('availability', None) or model_version.get('availability', None) != 'Public': continue
|
||||||
|
ic_version_id = model_version.get('id', None)
|
||||||
|
ic_version = model_version.get('name', None)
|
||||||
|
ic_release_date = model_version.get('publishedAt', None)
|
||||||
|
ic_lineage = model_version.get('baseModel', None)
|
||||||
|
|
||||||
|
ic_images = None
|
||||||
|
|
||||||
|
images = model_version.get('images', None)
|
||||||
|
if images and isinstance(images, list): ic_images = [i.get('url', None) for i in images if i.get('url', None) is not None]
|
||||||
|
|
||||||
|
ic_provides = [f'civit-{model_id}-{ic_version_id}', f'civit-{model_id}'].copy()
|
||||||
|
|
||||||
|
for file in model_version.get('files', list()):
|
||||||
|
ic_size_bytes = file.get('sizeKB', None)
|
||||||
|
if ic_size_bytes and isinstance(ic_size_bytes, float):
|
||||||
|
ic_size_bytes = int(ic_size_bytes * 1024)
|
||||||
|
metadata = file.get('metadata', None)
|
||||||
|
ic_quantisation = None
|
||||||
|
if metadata:
|
||||||
|
ic_quantisation = metadata.get('fp', None)
|
||||||
|
|
||||||
|
ic_file_id = file.get('id', None)
|
||||||
|
ic_filename = file.get('name', None)
|
||||||
|
if file.get('type', None) and file.get('type', None) != 'Model':
|
||||||
|
continue
|
||||||
|
ic_url = file.get('downloadUrl', None)
|
||||||
|
ic_model_info = model_info.copy()
|
||||||
|
ic_name = f'civit-{model_id}-{ic_version_id}-{ic_file_id}'
|
||||||
|
ic_uuid = ic_name
|
||||||
|
|
||||||
|
|
||||||
|
pull_candidates.append({
|
||||||
|
'uuid': ic_uuid,
|
||||||
|
'name': ic_name,
|
||||||
|
'provides': ic_provides,
|
||||||
|
'version_id': ic_version_id,
|
||||||
|
'file_id': ic_file_id,
|
||||||
|
'package_type': ic_package_type,
|
||||||
|
'tags': ic_tags,
|
||||||
|
'version': ic_version,
|
||||||
|
'release_date': ic_release_date,
|
||||||
|
'lineage': ic_lineage,
|
||||||
|
'images': ic_images,
|
||||||
|
'size_bytes': ic_size_bytes,
|
||||||
|
'quantisation': ic_quantisation,
|
||||||
|
'url': ic_url,
|
||||||
|
'filename': ic_filename,
|
||||||
|
'model_info': ic_model_info,
|
||||||
|
})
|
||||||
|
|
||||||
|
try:
|
||||||
|
del file, ic_url, ic_package_type, ic_images, ic_release_date, ic_tags, ic_filename, ic_lineage, model_info
|
||||||
|
del ic_model_info, ic_quantisation, ic_size_bytes, ic_version, images, metadata, model_version, model_versions
|
||||||
|
del ic_file_id, ic_version_id, ic_uuid, ic_name, ic_provides
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# check already pulled packages
|
||||||
|
already_pulled = list()
|
||||||
|
available_to_pull = list()
|
||||||
|
|
||||||
|
for candidate in pull_candidates:
|
||||||
|
if candidate['name'] in self.package_names: already_pulled.append(candidate)
|
||||||
|
else: available_to_pull.append(candidate)
|
||||||
|
|
||||||
|
if version_id: available_to_pull = [p for p in available_to_pull if p['version_id'] == version_id]
|
||||||
|
if file_id: available_to_pull = [p for p in available_to_pull if p['file_id'] == file_id]
|
||||||
|
|
||||||
|
if len(available_to_pull) == 0:
|
||||||
|
warnings.warn(f'Pull candidate not found for model_id:{model_id} and version_id:{version_id} and file_id:{file_id}')
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# selection output
|
||||||
|
if len(already_pulled) > 0:
|
||||||
|
print('Already pulled packages:')
|
||||||
|
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
|
||||||
|
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
|
||||||
|
for candidate in already_pulled:
|
||||||
|
print(
|
||||||
|
f' {'N':<{2}} {candidate['version']:<{10}} {candidate['package_type']:<{10}} {candidate['release_date']:<{25}}'
|
||||||
|
f' {candidate['lineage']:<{10}} {candidate['quantisation']:<{5}} {format_bytes(candidate['size_bytes']):<{10}} ')
|
||||||
|
|
||||||
|
if len(available_to_pull) == 0:
|
||||||
|
print('All available packages already pulled')
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print('Available packages:')
|
||||||
|
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
|
||||||
|
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
|
||||||
|
|
||||||
|
for i in range(len(available_to_pull)):
|
||||||
|
candidate = available_to_pull[i]
|
||||||
|
print(f' {i:<{2}} {candidate['version']:<{10}} {candidate['package_type']:<{10}} {candidate['release_date']:<{25}}'
|
||||||
|
f' {candidate['lineage']:<{10}} {candidate['quantisation']:<{5}} {format_bytes(candidate['size_bytes']):<{10}} ')
|
||||||
|
|
||||||
|
if len(available_to_pull) > 1: to_pull = select_elements(pull_candidates, input("Your choice: "))
|
||||||
|
else: to_pull = available_to_pull
|
||||||
|
|
||||||
|
# Ввод зависимостей
|
||||||
|
print("Зависимости (введите по одной, пустая строка для завершения):")
|
||||||
|
additional_dependencies = []
|
||||||
|
while True:
|
||||||
|
dep = input().strip()
|
||||||
|
if not dep:
|
||||||
|
break
|
||||||
|
additional_dependencies.append(dep)
|
||||||
|
|
||||||
|
# Ввод ресурсов
|
||||||
|
print("Ресурсы (введите по одному, пустая строка для завершения):")
|
||||||
|
additional_resources = []
|
||||||
|
while True:
|
||||||
|
resource = input().strip()
|
||||||
|
if not resource:
|
||||||
|
break
|
||||||
|
additional_resources.append(resource)
|
||||||
|
|
||||||
|
print("Теги (введите по одному, пустая строка для завершения):")
|
||||||
|
additional_tags = []
|
||||||
|
while True:
|
||||||
|
tag = input().strip()
|
||||||
|
if not tag:
|
||||||
|
break
|
||||||
|
additional_tags.append(resource)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
print('One collection for all selected packages, blank for None')
|
||||||
|
collection = input().strip()
|
||||||
|
if collection == '': break
|
||||||
|
external = input('External? (blank for no): ').strip()
|
||||||
|
if external != '':
|
||||||
|
category: str | None = None
|
||||||
|
internal = False
|
||||||
|
break
|
||||||
|
category = str(input('Category: ')).strip()
|
||||||
|
if category == '':
|
||||||
|
category: str | None = None
|
||||||
|
internal = True
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
internal = True
|
||||||
|
break
|
||||||
|
|
||||||
|
pulled = list()
|
||||||
|
for candidate in to_pull:
|
||||||
|
package_path = self.path / candidate['uuid']
|
||||||
|
if os.path.exists(str(Path(package_path) / "package.json")): raise RuntimeError("package exists!")
|
||||||
|
package_info = PackageInfo(str(Path(package_path) / "package.json"))
|
||||||
|
package_info.uuid = candidate['uuid']
|
||||||
|
package_info.name = candidate['name']
|
||||||
|
|
||||||
|
# TODO список дополнительных ресурсов, зависимостей, по линейке
|
||||||
|
# TODO add deps and resources based on lineage (use civit lineages)
|
||||||
|
package_info.resources = candidate['provides'].copy()
|
||||||
|
package_info.resources.extend(additional_resources)
|
||||||
|
package_info.dependencies = additional_dependencies.copy()
|
||||||
|
|
||||||
|
package_info.tags = candidate['tags'].copy()
|
||||||
|
package_info.tags.extend(additional_tags)
|
||||||
|
|
||||||
|
package_info.lineage = candidate['lineage'].lower()
|
||||||
|
# TODO cast package types (diffusion_model or checkpoint) (use civit lineages)
|
||||||
|
package_info.package_type = candidate['package_type'].lower()
|
||||||
|
package_info.version = candidate['version']
|
||||||
|
package_info.release_date = candidate['release_date']
|
||||||
|
package_info.size_bytes = candidate['size_bytes']
|
||||||
|
package_info.quantisation = candidate['quantisation']
|
||||||
|
package_info.save()
|
||||||
|
|
||||||
|
os.makedirs(package_path / 'files')
|
||||||
|
with open(package_path / 'model_info.json', 'w') as f:
|
||||||
|
json.dump(candidate['model_info'], f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
print('Pulling model...')
|
||||||
|
client.download_file(url=candidate['url'], path=package_path / 'files' / candidate['filename'])
|
||||||
|
|
||||||
|
print('Pulling main thumbnail...')
|
||||||
|
preview = candidate['images'][0]
|
||||||
|
dir, file = str(preview).rsplit('/', maxsplit=1)
|
||||||
|
image_name, image_extension = str(file).rsplit('.', maxsplit=1)
|
||||||
|
ckpt = candidate['filename']
|
||||||
|
ckpt_name, ckpt_extension = str(ckpt).rsplit('.', maxsplit=1)
|
||||||
|
client.download_file(url=preview, path=package_path / 'files' / (ckpt_name + image_extension))
|
||||||
|
|
||||||
|
os.makedirs(package_path / 'images')
|
||||||
|
print('Pulling thumbnails...')
|
||||||
|
for image in candidate['images']:
|
||||||
|
dir, file = str(image).rsplit('/', maxsplit=1)
|
||||||
|
client.download_file(url=image, path=package_path / 'images' / file)
|
||||||
|
|
||||||
|
package = ModelPackage(package_path, [], package_info)
|
||||||
|
print('Collections for package:')
|
||||||
|
print(f' {'N':<{2}} {'version':<{10}} {'type':<{10}} {'release_date':<{25}}'
|
||||||
|
f' {'lineage':<{10}} {'quant':<{5}} {'size':<{10}} ')
|
||||||
|
print(
|
||||||
|
f' {'N':<{2}} {candidate['version']:<{10}} {candidate['package_type']:<{10}} {candidate['release_date']:<{25}}'
|
||||||
|
f' {candidate['lineage']:<{10}} {candidate['quantisation']:<{5}} {format_bytes(candidate['size_bytes']):<{10}} ')
|
||||||
|
|
||||||
|
self.packages[package.uuid] = package
|
||||||
|
|
||||||
|
self.add_package_to_collection(package.name, 'civit', internal=True)
|
||||||
|
if collection: self.add_package_to_collection(package.name, collection, category, internal=internal)
|
||||||
|
pulled.append(package)
|
||||||
|
|
||||||
|
for package in pulled: self._add_package_to_collections_interactive(package)
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import time
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from requests import Session
|
from requests import Session
|
||||||
|
|
||||||
from modelspace.Repository import Repository
|
|
||||||
from pythonapp.Libs.ConfigDataClass import Config
|
from pythonapp.Libs.ConfigDataClass import Config
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -108,4 +107,70 @@ class Client:
|
|||||||
def get_creators_raw(self, page=None, limit = 200, query = None): return self.get_creators_tags_raw('creators', page, limit, query)
|
def get_creators_raw(self, page=None, limit = 200, query = None): return self.get_creators_tags_raw('creators', page, limit, query)
|
||||||
def get_tags_raw(self, page=None, limit = 200, query = None): return self.get_creators_tags_raw('tags', page, limit, query)
|
def get_tags_raw(self, page=None, limit = 200, query = None): return self.get_creators_tags_raw('tags', page, limit, query)
|
||||||
|
|
||||||
|
def get_model_raw(self, model_id: int):
|
||||||
|
try:
|
||||||
|
return self.make_get_request(f'{self.config.base_url}/api/v1/models/{model_id}').json()
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
print(e)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def download_file(self, url: str, path: str, chill_time: int = 3, max_retries: int = 3):
|
||||||
|
"""
|
||||||
|
Загружает файл по URL в указанный путь
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): URL файла для загрузки
|
||||||
|
path (str): Путь для сохранения файла
|
||||||
|
chill_time (int): Время ожидания в секундах при ошибке (по умолчанию 3)
|
||||||
|
max_retries (int): Максимальное количество попыток загрузки (по умолчанию 3)
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
# Создаем запрос с прогресс-баром
|
||||||
|
response = self.session.get(url, stream=True, timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Получаем размер файла
|
||||||
|
total_size = int(response.headers.get('content-length', 0))
|
||||||
|
|
||||||
|
# Загружаем файл по частями
|
||||||
|
with open(path, 'wb') as file:
|
||||||
|
downloaded = 0
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
file.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
|
||||||
|
# Отображаем прогресс
|
||||||
|
if total_size > 0:
|
||||||
|
progress = (downloaded / total_size) * 100
|
||||||
|
print(f"\rDownloading: {progress:.1f}% ({downloaded}/{total_size} bytes)", end='',
|
||||||
|
flush=True)
|
||||||
|
|
||||||
|
print(f"\nFile downloaded successfully to {path}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(f"Attempt {attempt + 1} failed: {e}")
|
||||||
|
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
print(f"Waiting {chill_time} seconds before retry...")
|
||||||
|
time.sleep(chill_time)
|
||||||
|
else:
|
||||||
|
print(f"Failed to download file after {max_retries} attempts")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except IOError as e:
|
||||||
|
print(f"IO Error while saving file: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unexpected error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .datamodel_base import ForwardingBase
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Creator(ForwardingBase):
|
||||||
|
username: Optional[str] = None
|
||||||
|
modelCount: Optional[int] = None
|
||||||
|
link: Optional[str] = None
|
||||||
|
image: Optional[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self._forwarding = {}
|
||||||
|
self._key_field = 'username'
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Tag(ForwardingBase):
|
||||||
|
name: Optional[str] = None
|
||||||
|
link: Optional[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self._forwarding = {}
|
||||||
|
self._key_field = 'name'
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelVersionStats(ForwardingBase):
|
||||||
|
downloadCount: Optional[int] = None
|
||||||
|
ratingCount: Optional[int] = None
|
||||||
|
rating: Optional[int] = None
|
||||||
|
thumbsUpCount: Optional[int] = None
|
||||||
|
thumbsDownCount: Optional[int] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self._forwarding = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelVersion(ForwardingBase):
|
||||||
|
id: Optional[int] = None
|
||||||
|
index: Optional[int] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
baseModel: Optional[str] = None
|
||||||
|
baseModelType: Optional[str] = None
|
||||||
|
publishedAt: Optional[str] = None
|
||||||
|
availability: Optional[str] = None
|
||||||
|
nsfwLevel: Optional[int] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
trainedWords: Optional[list[str]] = None
|
||||||
|
stats: Optional[ModelVersionStats] = None
|
||||||
|
supportsGeneration: Optional[bool] = None
|
||||||
|
downloadUrl: Optional[str] = None
|
||||||
|
# FILES
|
||||||
|
# IMAGES
|
||||||
|
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self._forwarding = {
|
||||||
|
'stats': ModelVersionStats,
|
||||||
|
}
|
||||||
|
self._key_field = 'id'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelStats(ForwardingBase):
|
||||||
|
downloadCount: Optional[int] = None
|
||||||
|
favoriteCount: Optional[int] = None
|
||||||
|
thumbsUpCount: Optional[int] = None
|
||||||
|
thumbsDownCount: Optional[int] = None
|
||||||
|
commentCount: Optional[int] = None
|
||||||
|
ratingCount: Optional[int] = None
|
||||||
|
rating: Optional[int] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self._forwarding = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Model(ForwardingBase):
|
||||||
|
id: Optional[int] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
allowNoCredit: Optional[bool] = None
|
||||||
|
allowCommercialUse: Optional[list] = None
|
||||||
|
allowDerivatives: Optional[bool] = None
|
||||||
|
allowDifferentLicense: Optional[bool] = None
|
||||||
|
type: Optional[str] = None
|
||||||
|
minor: Optional[bool] = None
|
||||||
|
sfwOnly: Optional[bool] = None
|
||||||
|
poi: Optional[bool] = None
|
||||||
|
nsfw: Optional[bool] = None
|
||||||
|
nsfwLevel: Optional[int] = None
|
||||||
|
availability: Optional[str] = None
|
||||||
|
cosmetic: Optional[str] = None
|
||||||
|
supportsGeneration: Optional[bool] = None
|
||||||
|
stats: Optional[ModelStats] = None
|
||||||
|
creator: Optional[Creator] = None
|
||||||
|
tags: Optional[list[Tag]] = None
|
||||||
|
modelVersions: Optional[list[ModelVersion]] = None
|
||||||
|
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self._forwarding = {
|
||||||
|
'stats': ModelStats,
|
||||||
|
'creator': Creator,
|
||||||
|
'tags': Tag,
|
||||||
|
'modelVersions': ModelVersion,
|
||||||
|
}
|
||||||
|
self._key_field = 'id'
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -352,7 +352,12 @@ class Fetch:
|
|||||||
print(f"Fetching {entity}...")
|
print(f"Fetching {entity}...")
|
||||||
path = Path(client.path) / ('fetch_' + entity)
|
path = Path(client.path) / ('fetch_' + entity)
|
||||||
items = list()
|
items = list()
|
||||||
first_page = client.make_get_request(url=f'{client.config.base_url}/api/v1/{entity}{client.build_query_string(params)}').json()
|
url = f'{client.config.base_url}/api/v1/{entity}{client.build_query_string(params)}'
|
||||||
|
first_page = client.make_get_request(url)
|
||||||
|
if not first_page:
|
||||||
|
with open(Path(client.path) / 'bugs.log', 'a') as f: f.write(url + '\n')
|
||||||
|
return items
|
||||||
|
first_page = first_page.json()
|
||||||
if first_page.get('items', None): items.extend(first_page.get('items', None))
|
if first_page.get('items', None): items.extend(first_page.get('items', None))
|
||||||
if save:
|
if save:
|
||||||
path.mkdir(exist_ok=True)
|
path.mkdir(exist_ok=True)
|
||||||
@@ -419,7 +424,7 @@ class Fetch:
|
|||||||
for item in page_items: page_items_dict[item['id']] = item
|
for item in page_items: page_items_dict[item['id']] = item
|
||||||
print(f'Added {len(page_items_dict) - l} images by {sort} sort crawl. {len(page_items_dict)} images total')
|
print(f'Added {len(page_items_dict) - l} images by {sort} sort crawl. {len(page_items_dict)} images total')
|
||||||
|
|
||||||
page_items = [key for key, value in page_items_dict.items()]
|
page_items = [value for key, value in page_items_dict.items()]
|
||||||
|
|
||||||
|
|
||||||
l = len(items)
|
l = len(items)
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
from unicodedata import category
|
||||||
|
|
||||||
from modelspace.ModelPackageSelector import format_bytes
|
from modelspace.ModelPackageSelector import format_bytes
|
||||||
from modelspace.ModelSpace import ModelSpace
|
from modelspace.ModelSpace import ModelSpace
|
||||||
|
from modules.civit.client import Client
|
||||||
from shell.Handlers.ABS import Handler
|
from shell.Handlers.ABS import Handler
|
||||||
from modelspace.Repository import global_repo
|
from modelspace.Repository import global_repo
|
||||||
|
|
||||||
@@ -17,18 +20,50 @@ class ModelSpaceHandler(Handler):
|
|||||||
'load': self._load,
|
'load': self._load,
|
||||||
'list': self._list,
|
'list': self._list,
|
||||||
'debug': self._debug,
|
'debug': self._debug,
|
||||||
|
'add-to-collection': self._add_to_collection,
|
||||||
|
'init-civit': self._init_civit,
|
||||||
|
'pull-civit': self._pull_civit,
|
||||||
# 'show': self._show,
|
# 'show': self._show,
|
||||||
# 'activate': self._activate,
|
# 'activate': self._activate,
|
||||||
|
|
||||||
}
|
}
|
||||||
self._loaded_instances: dict[str, ModelSpace] = {}
|
self._loaded_instances: dict[str, ModelSpace] = {}
|
||||||
self._active_instance: ModelSpace | None = None
|
self._active_instance: ModelSpace | None = None
|
||||||
pass
|
self.client: Client | None = None
|
||||||
|
|
||||||
|
def _init_civit(self, command: list[str], pos=0):
|
||||||
|
keys, args = self.parse_arguments(command[pos:], ['path', 'key'])
|
||||||
|
self._check_arg(keys, 'path')
|
||||||
|
self.client = Client(keys['path'], keys['key'])
|
||||||
|
self.succeed = True
|
||||||
|
|
||||||
def _create_inter(self, command: list[str], pos=0):
|
def _create_inter(self, command: list[str], pos=0):
|
||||||
global_repo.add_model_package_interactive()
|
global_repo.add_model_package_interactive()
|
||||||
self.succeed = True
|
self.succeed = True
|
||||||
|
|
||||||
|
def _add_to_collection(self, command: list[str], pos=0):
|
||||||
|
keys, args = self.parse_arguments(command[pos:], ['pkg', 'collection', 'category', 'ext'])
|
||||||
|
self._check_arg(keys, 'pkg')
|
||||||
|
self._check_arg(keys, 'collection')
|
||||||
|
if keys['ext']:
|
||||||
|
internal = False
|
||||||
|
category = None
|
||||||
|
else:
|
||||||
|
internal = True
|
||||||
|
category = keys['category']
|
||||||
|
global_repo.model_sub_repo.add_package_to_collection(keys['pkg'], keys['collection'], category, internal)
|
||||||
|
self.succeed = True
|
||||||
|
|
||||||
|
def _pull_civit(self, command: list[str], pos=0):
|
||||||
|
keys, args = self.parse_arguments(command[pos:], ['model', 'version', 'file'])
|
||||||
|
self._check_arg(keys, 'model')
|
||||||
|
global_repo.model_sub_repo.pull_civit_package(self.client, keys['model'], keys['version'], keys['file'])
|
||||||
|
self.succeed = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _load(self, command: list[str], pos = 0):
|
def _load(self, command: list[str], pos = 0):
|
||||||
keys, args = self.parse_arguments(command[pos:], ['path', 'layout', 'name'])
|
keys, args = self.parse_arguments(command[pos:], ['path', 'layout', 'name'])
|
||||||
self._check_arg(keys, 'path')
|
self._check_arg(keys, 'path')
|
||||||
|
|||||||
Reference in New Issue
Block a user