mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-21 14:59:02 +00:00
Added ability to use civit ai url ar model name and built a model downloader and cache manager for it
This commit is contained in:
217
toolkit/civitai.py
Normal file
217
toolkit/civitai.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from toolkit.paths import MODELS_PATH
|
||||
import requests
|
||||
import os
|
||||
import json
|
||||
import tqdm
|
||||
|
||||
|
||||
class ModelCache:
|
||||
def __init__(self):
|
||||
self.raw_cache = {}
|
||||
self.cache_path = os.path.join(MODELS_PATH, '.ai_toolkit_cache.json')
|
||||
if os.path.exists(self.cache_path):
|
||||
with open(self.cache_path, 'r') as f:
|
||||
all_cache = json.load(f)
|
||||
if 'models' in all_cache:
|
||||
self.raw_cache = all_cache['models']
|
||||
else:
|
||||
self.raw_cache = all_cache
|
||||
|
||||
def get_model_path(self, model_id: int, model_version_id: int = None):
|
||||
if str(model_id) not in self.raw_cache:
|
||||
return None
|
||||
if model_version_id is None:
|
||||
# get latest version
|
||||
model_version_id = max([int(x) for x in self.raw_cache[str(model_id)].keys()])
|
||||
if model_version_id is None:
|
||||
return None
|
||||
model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path']
|
||||
# check if model path exists
|
||||
if not os.path.exists(model_path):
|
||||
# remove version from cache
|
||||
del self.raw_cache[str(model_id)][str(model_version_id)]
|
||||
self.save()
|
||||
return None
|
||||
return model_path
|
||||
else:
|
||||
if str(model_version_id) not in self.raw_cache[str(model_id)]:
|
||||
return None
|
||||
model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path']
|
||||
# check if model path exists
|
||||
if not os.path.exists(model_path):
|
||||
# remove version from cache
|
||||
del self.raw_cache[str(model_id)][str(model_version_id)]
|
||||
self.save()
|
||||
return None
|
||||
return model_path
|
||||
|
||||
def update_cache(self, model_id: int, model_version_id: int, model_path: str):
|
||||
if str(model_id) not in self.raw_cache:
|
||||
self.raw_cache[str(model_id)] = {}
|
||||
if str(model_version_id) not in self.raw_cache[str(model_id)]:
|
||||
self.raw_cache[str(model_id)][str(model_version_id)] = {}
|
||||
self.raw_cache[str(model_id)][str(model_version_id)] = {
|
||||
'model_path': model_path
|
||||
}
|
||||
self.save()
|
||||
|
||||
def save(self):
|
||||
if not os.path.exists(os.path.dirname(self.cache_path)):
|
||||
os.makedirs(os.path.dirname(self.cache_path), exist_ok=True)
|
||||
all_cache = {'models': {}}
|
||||
if os.path.exists(self.cache_path):
|
||||
# load it first
|
||||
with open(self.cache_path, 'r') as f:
|
||||
all_cache = json.load(f)
|
||||
|
||||
all_cache['models'] = self.raw_cache
|
||||
|
||||
with open(self.cache_path, 'w') as f:
|
||||
json.dump(all_cache, f, indent=2)
|
||||
|
||||
|
||||
def get_model_download_info(model_id: int, model_version_id: int = None):
|
||||
# curl https://civitai.com/api/v1/models?limit=3&types=TextualInversion \
|
||||
# -H "Content-Type: application/json" \
|
||||
# -X GET
|
||||
print(
|
||||
f"Getting model info for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}")
|
||||
endpoint = f"https://civitai.com/api/v1/models/{model_id}"
|
||||
|
||||
# get the json
|
||||
response = requests.get(endpoint)
|
||||
response.raise_for_status()
|
||||
model_data = response.json()
|
||||
|
||||
model_version = None
|
||||
|
||||
# go through versions and get the top one if one is not set
|
||||
for version in model_data['modelVersions']:
|
||||
if model_version_id is not None:
|
||||
if str(version['id']) == str(model_version_id):
|
||||
model_version = version
|
||||
break
|
||||
else:
|
||||
# get first version
|
||||
model_version = version
|
||||
break
|
||||
|
||||
if model_version is None:
|
||||
raise ValueError(
|
||||
f"Could not find a model version for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}")
|
||||
|
||||
model_file = None
|
||||
# go through files and prefer fp16 safetensors
|
||||
# "metadata": {
|
||||
# "fp": "fp16",
|
||||
# "size": "pruned",
|
||||
# "format": "SafeTensor"
|
||||
# },
|
||||
# todo check pickle scans and skip if not good
|
||||
# try to get fp16 safetensor
|
||||
for file in model_version['files']:
|
||||
if file['metadata']['fp'] == 'fp16' and file['metadata']['format'] == 'SafeTensor':
|
||||
model_file = file
|
||||
break
|
||||
|
||||
if model_file is None:
|
||||
# try to get primary
|
||||
for file in model_version['files']:
|
||||
if file['primary']:
|
||||
model_file = file
|
||||
break
|
||||
|
||||
if model_file is None:
|
||||
# try to get any safetensor
|
||||
for file in model_version['files']:
|
||||
if file['metadata']['format'] == 'SafeTensor':
|
||||
model_file = file
|
||||
break
|
||||
|
||||
if model_file is None:
|
||||
# try to get any fp16
|
||||
for file in model_version['files']:
|
||||
if file['metadata']['fp'] == 'fp16':
|
||||
model_file = file
|
||||
break
|
||||
|
||||
if model_file is None:
|
||||
# try to get any
|
||||
for file in model_version['files']:
|
||||
model_file = file
|
||||
break
|
||||
|
||||
if model_file is None:
|
||||
raise ValueError(f"Could not find a model file to download for model id: {model_id}")
|
||||
|
||||
return model_file, model_version['id']
|
||||
|
||||
|
||||
def get_model_path_from_url(url: str):
|
||||
# get query params form url if they are set
|
||||
# https: // civitai.com / models / 25694?modelVersionId = 127742
|
||||
query_params = {}
|
||||
if '?' in url:
|
||||
query_string = url.split('?')[1]
|
||||
query_params = dict(qc.split("=") for qc in query_string.split("&"))
|
||||
|
||||
# get model id from url
|
||||
model_id = url.split('/')[-1]
|
||||
# remove query params from model id
|
||||
if '?' in model_id:
|
||||
model_id = model_id.split('?')[0]
|
||||
if model_id.isdigit():
|
||||
model_id = int(model_id)
|
||||
else:
|
||||
raise ValueError(f"Invalid model id: {model_id}")
|
||||
|
||||
model_cache = ModelCache()
|
||||
model_path = model_cache.get_model_path(model_id, query_params.get('modelVersionId', None))
|
||||
if model_path is not None:
|
||||
return model_path
|
||||
else:
|
||||
# download model
|
||||
file_info, model_version_id = get_model_download_info(model_id, query_params.get('modelVersionId', None))
|
||||
|
||||
download_url = file_info['downloadUrl'] # url does not work directly
|
||||
size_kb = file_info['sizeKB']
|
||||
filename = file_info['name']
|
||||
model_path = os.path.join(MODELS_PATH, filename)
|
||||
|
||||
# download model
|
||||
print(f"Did not find model locally, downloading from model from: {download_url}")
|
||||
|
||||
# use tqdm to show status of downlod
|
||||
response = requests.get(download_url, stream=True)
|
||||
response.raise_for_status()
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
tmp_path = os.path.join(MODELS_PATH, f".download_tmp_{filename}")
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
# remove tmp file if it exists
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
try:
|
||||
|
||||
with open(tmp_path, 'wb') as f:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
f.write(data)
|
||||
progress_bar.close()
|
||||
# move to final path
|
||||
os.rename(tmp_path, model_path)
|
||||
model_cache.update_cache(model_id, model_version_id, model_path)
|
||||
|
||||
return model_path
|
||||
except Exception as e:
|
||||
# remove tmp file
|
||||
os.remove(tmp_path)
|
||||
raise e
|
||||
|
||||
|
||||
# if is main
|
||||
if __name__ == '__main__':
|
||||
model_path = get_model_path_from_url("https://civitai.com/models/25694?modelVersionId=127742")
|
||||
print(model_path)
|
||||
@@ -5,6 +5,12 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
||||
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
||||
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
||||
|
||||
# check if ENV variable is set
|
||||
if 'MODELS_PATH' in os.environ:
|
||||
MODELS_PATH = os.environ['MODELS_PATH']
|
||||
else:
|
||||
MODELS_PATH = os.path.join(TOOLKIT_ROOT, "models")
|
||||
|
||||
|
||||
def get_path(path):
|
||||
# we allow absolute paths, but if it is not absolute, we assume it is relative to the toolkit root
|
||||
|
||||
@@ -143,6 +143,13 @@ class StableDiffusion:
|
||||
prediction_type=prediction_type,
|
||||
steps_offset=1
|
||||
)
|
||||
|
||||
model_path = self.model_config.name_or_path
|
||||
if 'civitai.com' in self.model_config.name_or_path:
|
||||
# load is a civit ai model, use the loader.
|
||||
from toolkit.civitai import get_model_path_from_url
|
||||
model_path = get_model_path_from_url(self.model_config.name_or_path)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
@@ -150,17 +157,17 @@ class StableDiffusion:
|
||||
pipln = CustomStableDiffusionXLPipeline
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(self.model_config.name_or_path):
|
||||
if not os.path.exists(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
self.model_config.name_or_path,
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
pipe = pipln.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
@@ -180,10 +187,10 @@ class StableDiffusion:
|
||||
pipln = CustomStableDiffusionPipeline
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(self.model_config.name_or_path):
|
||||
if not os.path.exists(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
self.model_config.name_or_path,
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch,
|
||||
@@ -193,7 +200,7 @@ class StableDiffusion:
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
pipe = pipln.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch,
|
||||
|
||||
Reference in New Issue
Block a user