mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
218 lines
7.7 KiB
Python
218 lines
7.7 KiB
Python
from toolkit.paths import MODELS_PATH
|
|
import requests
|
|
import os
|
|
import json
|
|
import tqdm
|
|
|
|
|
|
class ModelCache:
    """Persistent index of Civitai downloads: model id -> version id -> local file.

    The index is stored as JSON at ``<MODELS_PATH>/.ai_toolkit_cache.json``
    under the top-level ``'models'`` key. A file without a ``'models'`` key
    (the model map stored directly at the top level) is also accepted on load.
    Any other top-level keys already in the file are preserved when saving.

    Ids are stored as strings (JSON object keys are strings), so callers may
    pass them as ``int`` or ``str``.
    """

    def __init__(self):
        # {str(model_id): {str(model_version_id): {'model_path': str}}}
        self.raw_cache = {}
        self.cache_path = os.path.join(MODELS_PATH, '.ai_toolkit_cache.json')
        all_cache = self._read_cache_file()
        if all_cache is not None:
            if 'models' in all_cache:
                self.raw_cache = all_cache['models']
            else:
                # No 'models' key: the whole file is the model map.
                self.raw_cache = all_cache

    def _read_cache_file(self):
        """Return the parsed cache file as a dict, or ``None`` if missing or unreadable.

        An unreadable file is treated as absent rather than raising. The cache
        only indexes files that already exist on disk, so dropping it costs a
        re-lookup. Raising here would instead break every model lookup.
        """
        if not os.path.exists(self.cache_path):
            return None
        try:
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            data = None
        if not isinstance(data, dict):
            print(f"Warning: ignoring unreadable model cache file: {self.cache_path}")
            return None
        return data

    def get_model_path(self, model_id: int, model_version_id: int = None):
        """Return the cached local path for a model version, or ``None``.

        Args:
            model_id: Civitai model id.
            model_version_id: Civitai version id. If ``None``, the numerically
                largest cached version id for the model is used.

        Returns:
            The cached file path. Returns ``None`` if the model/version is not
            cached, or if the cached file no longer exists. In the latter case
            the stale entry is removed and the cache is saved.
        """
        versions = self.raw_cache.get(str(model_id))
        if versions is None:
            return None

        if model_version_id is None:
            # ``default=None`` covers a model whose versions have all been
            # pruned; a bare max() would raise ValueError on the empty dict.
            model_version_id = max((int(v) for v in versions), default=None)
            if model_version_id is None:
                return None

        version_key = str(model_version_id)
        entry = versions.get(version_key)
        if entry is None:
            return None

        model_path = entry['model_path']
        if not os.path.exists(model_path):
            # The file was deleted or moved since it was cached; forget it.
            del versions[version_key]
            self.save()
            return None
        return model_path

    def update_cache(self, model_id: int, model_version_id: int, model_path: str):
        """Record ``model_path`` for the given model/version and persist the cache."""
        versions = self.raw_cache.setdefault(str(model_id), {})
        versions[str(model_version_id)] = {'model_path': model_path}
        self.save()

    def save(self):
        """Write ``raw_cache`` to disk under ``'models'``, keeping other top-level keys.

        The JSON is written to a sibling ``.tmp`` file and moved into place
        with ``os.replace``. An interrupted write therefore cannot leave a
        truncated cache file behind.
        """
        cache_dir = os.path.dirname(self.cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        # Re-read the file so keys written by other code are not lost.
        all_cache = self._read_cache_file()
        if all_cache is None:
            all_cache = {}
        all_cache['models'] = self.raw_cache

        tmp_path = self.cache_path + '.tmp'
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(all_cache, f, indent=2)
            os.replace(tmp_path, self.cache_path)
        finally:
            # Only still present if writing or replacing failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
|
|
|
|
|
|
def _select_model_file(files):
    """Pick the preferred file from a Civitai model version's ``files`` list.

    Preference order:
      1. fp16 SafeTensor
      2. the file flagged ``primary``
      3. any SafeTensor
      4. any fp16
      5. the first file

    Returns ``None`` if ``files`` is empty.
    """
    def meta(file):
        # Defensive: a file is not guaranteed to carry ``metadata`` (or its
        # ``fp``/``format`` fields). A missing field should just fail the
        # match instead of raising KeyError.
        return file.get('metadata') or {}

    preferences = (
        lambda f: meta(f).get('fp') == 'fp16' and meta(f).get('format') == 'SafeTensor',
        lambda f: bool(f.get('primary')),
        lambda f: meta(f).get('format') == 'SafeTensor',
        lambda f: meta(f).get('fp') == 'fp16',
        lambda f: True,
    )
    for matches in preferences:
        for file in files:
            if matches(file):
                return file
    return None


def get_model_download_info(model_id: int, model_version_id: int = None):
    """Look up which file to download for a Civitai model (version).

    Queries ``https://civitai.com/api/v1/models/{model_id}``. If
    ``model_version_id`` is given, the version with that id is used.
    Otherwise the first entry of ``modelVersions`` is used.

    Returns:
        ``(file_info, version_id)``: the chosen entry of the version's
        ``files`` list (see ``_select_model_file``) and that version's ``id``.

    Raises:
        requests.HTTPError: if the API request fails.
        ValueError: if no matching version or no file is found.
    """
    version_suffix = f' and version id: {model_version_id}' if model_version_id is not None else ''
    print(f"Getting model info for model id: {model_id}{version_suffix}")
    endpoint = f"https://civitai.com/api/v1/models/{model_id}"

    # Without a timeout an unresponsive server would hang the run forever.
    response = requests.get(endpoint, timeout=60)
    response.raise_for_status()
    model_data = response.json()

    versions = model_data.get('modelVersions') or []
    if model_version_id is None:
        model_version = versions[0] if versions else None
    else:
        # Compare as strings: the id may arrive as int or as a URL query string.
        model_version = next(
            (v for v in versions if str(v['id']) == str(model_version_id)),
            None,
        )

    if model_version is None:
        raise ValueError(
            f"Could not find a model version for model id: {model_id}{version_suffix}")

    # TODO: check pickle scans and skip files that did not pass
    model_file = _select_model_file(model_version.get('files') or [])
    if model_file is None:
        raise ValueError(f"Could not find a model file to download for model id: {model_id}")

    return model_file, model_version['id']
|
|
|
|
|
|
def _parse_civitai_url(url: str):
    """Extract ``(model_id, model_version_id)`` from a Civitai model page URL.

    Handles ``https://civitai.com/models/25694?modelVersionId=127742``,
    slugged URLs such as ``https://civitai.com/models/25694/some-name``, and
    trailing slashes. ``model_version_id`` is the raw ``modelVersionId``
    query value (a string), or ``None`` when absent.

    Raises:
        ValueError: if no numeric model id can be found.
    """
    from urllib.parse import parse_qs, urlsplit

    parts = urlsplit(url)
    # parse_qs maps each key to a list of values; the last one wins, matching
    # the previous dict()-based parsing.
    model_version_id = parse_qs(parts.query).get('modelVersionId', [None])[-1]

    segments = [s for s in parts.path.split('/') if s]
    candidate = segments[-1] if segments else ''
    # Page URLs may carry a slug after the id (/models/<id>/<slug>), so prefer
    # the segment that follows 'models' when it is numeric.
    if 'models' in segments:
        idx = segments.index('models')
        if idx + 1 < len(segments) and segments[idx + 1].isdecimal():
            candidate = segments[idx + 1]

    if not candidate.isdecimal():
        raise ValueError(f"Invalid model id: {candidate}")
    return int(candidate), model_version_id


def _download_file(download_url: str, dest_path: str):
    """Stream ``download_url`` to ``dest_path`` while showing a tqdm progress bar.

    Data goes to a hidden temp file in the same directory and is moved into
    place only after the download completes. A failed or interrupted download
    therefore never leaves a partial file at ``dest_path``.
    """
    dest_dir = os.path.dirname(dest_path)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    tmp_path = os.path.join(dest_dir, f".download_tmp_{os.path.basename(dest_path)}")
    # Remove a leftover temp file from an earlier interrupted attempt.
    if os.path.exists(tmp_path):
        os.remove(tmp_path)

    block_size = 1024 * 1024  # 1 MiB: far fewer loop iterations than 1 KiB on big files
    try:
        with requests.get(download_url, stream=True, timeout=60) as response:
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar, \
                    open(tmp_path, 'wb') as f:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    f.write(data)
        # os.replace (unlike os.rename) also overwrites an existing file on Windows.
        os.replace(tmp_path, dest_path)
    finally:
        # Only still present if the download or the move failed (including
        # KeyboardInterrupt). The existence check avoids masking the original
        # error with FileNotFoundError.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def get_model_path_from_url(url: str):
    """Return a local path for the Civitai model at ``url``, downloading it if needed.

    The local ``ModelCache`` is consulted first. On a miss, the file chosen by
    ``get_model_download_info`` is downloaded into ``MODELS_PATH`` and
    recorded in the cache.

    Raises:
        ValueError: if the URL has no numeric model id, or the API returns no
            usable version/file or an unsafe file name.
        requests.HTTPError: if an API or download request fails.
    """
    model_id, requested_version_id = _parse_civitai_url(url)

    model_cache = ModelCache()
    model_path = model_cache.get_model_path(model_id, requested_version_id)
    if model_path is not None:
        return model_path

    file_info, model_version_id = get_model_download_info(model_id, requested_version_id)

    # NOTE(review): an earlier comment said this URL "does not work directly",
    # yet it is downloaded as-is; confirm whether some models need auth.
    download_url = file_info['downloadUrl']
    # The name comes from the remote API. Keep only the final path component
    # so it cannot point outside MODELS_PATH.
    filename = os.path.basename(file_info['name'])
    if filename in ('', '.', '..'):
        raise ValueError(f"Invalid file name from API: {file_info['name']!r}")
    model_path = os.path.join(MODELS_PATH, filename)

    print(f"Did not find model locally, downloading model from: {download_url}")
    _download_file(download_url, model_path)

    model_cache.update_cache(model_id, model_version_id, model_path)
    return model_path
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: resolve one known Civitai model (downloading it on
    # the first run) and print where it lives on disk.
    _example_url = "https://civitai.com/models/25694?modelVersionId=127742"
    model_path = get_model_path_from_url(_example_url)
    print(model_path)
|