mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added ability to use civit ai url ar model name and built a model downloader and cache manager for it
This commit is contained in:
211
test.py
211
test.py
@@ -1,211 +0,0 @@
|
|||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
job_to_run = OrderedDict({
|
|
||||||
# This is the config I use on my sliders, It is solid and tested
|
|
||||||
'job': 'train',
|
|
||||||
'config': {
|
|
||||||
# the name will be used to create a folder in the output folder
|
|
||||||
# it will also replace any [name] token in the rest of this config
|
|
||||||
'name': 'detail_slider_v1',
|
|
||||||
# folder will be created with name above in folder below
|
|
||||||
# it can be relative to the project root or absolute
|
|
||||||
'training_folder': "output/LoRA",
|
|
||||||
'device': 'cuda', # cpu, cuda:0, etc
|
|
||||||
# for tensorboard logging, we will make a subfolder for this job
|
|
||||||
'log_dir': "output/.tensorboard",
|
|
||||||
# you can stack processes for other jobs, It is not tested with sliders though
|
|
||||||
# just use one for now
|
|
||||||
'process': {
|
|
||||||
'type': 'slider', # tells runner to run the slider process
|
|
||||||
# network is the LoRA network for a slider, I recommend to leave this be
|
|
||||||
'network': {
|
|
||||||
'type': "lora",
|
|
||||||
# rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
|
|
||||||
'linear': 8, # "rank" or "dim"
|
|
||||||
'linear_alpha': 4, # Do about half of rank "alpha"
|
|
||||||
# 'conv': 4, # for convolutional layers "locon"
|
|
||||||
# 'conv_alpha': 4, # Do about half of conv "alpha"
|
|
||||||
},
|
|
||||||
# training config
|
|
||||||
'train': {
|
|
||||||
# this is also used in sampling. Stick with ddpm unless you know what you are doing
|
|
||||||
'noise_scheduler': "ddpm", # or "ddpm", "lms", "euler_a"
|
|
||||||
# how many steps to train. More is not always better. I rarely go over 1000
|
|
||||||
'steps': 100,
|
|
||||||
# I have had good results with 4e-4 to 1e-4 at 500 steps
|
|
||||||
'lr': 2e-4,
|
|
||||||
# enables gradient checkpoint, saves vram, leave it on
|
|
||||||
'gradient_checkpointing': True,
|
|
||||||
# train the unet. I recommend leaving this true
|
|
||||||
'train_unet': True,
|
|
||||||
# train the text encoder. I don't recommend this unless you have a special use case
|
|
||||||
# for sliders we are adjusting representation of the concept (unet),
|
|
||||||
# not the description of it (text encoder)
|
|
||||||
'train_text_encoder': False,
|
|
||||||
|
|
||||||
# just leave unless you know what you are doing
|
|
||||||
# also supports "dadaptation" but set lr to 1 if you use that,
|
|
||||||
# but it learns too fast and I don't recommend it
|
|
||||||
'optimizer': "adamw",
|
|
||||||
# only constant for now
|
|
||||||
'lr_scheduler': "constant",
|
|
||||||
# we randomly denoise random num of steps form 1 to this number
|
|
||||||
# while training. Just leave it
|
|
||||||
'max_denoising_steps': 40,
|
|
||||||
# works great at 1. I do 1 even with my 4090.
|
|
||||||
# higher may not work right with newer single batch stacking code anyway
|
|
||||||
'batch_size': 1,
|
|
||||||
# bf16 works best if your GPU supports it (modern)
|
|
||||||
'dtype': 'bf16', # fp32, bf16, fp16
|
|
||||||
# I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
|
|
||||||
# although, the way we train sliders is comparative, so it probably won't work anyway
|
|
||||||
'noise_offset': 0.0,
|
|
||||||
},
|
|
||||||
|
|
||||||
# the model to train the LoRA network on
|
|
||||||
'model': {
|
|
||||||
# huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
|
|
||||||
'name_or_path': "runwayml/stable-diffusion-v1-5",
|
|
||||||
'is_v2': False, # for v2 models
|
|
||||||
'is_v_pred': False, # for v-prediction models (most v2 models)
|
|
||||||
# has some issues with the dual text encoder and the way we train sliders
|
|
||||||
# it works bit weights need to probably be higher to see it.
|
|
||||||
'is_xl': False, # for SDXL models
|
|
||||||
},
|
|
||||||
|
|
||||||
# saving config
|
|
||||||
'save': {
|
|
||||||
'dtype': 'float16', # precision to save. I recommend float16
|
|
||||||
'save_every': 50, # save every this many steps
|
|
||||||
# this will remove step counts more than this number
|
|
||||||
# allows you to save more often in case of a crash without filling up your drive
|
|
||||||
'max_step_saves_to_keep': 2,
|
|
||||||
},
|
|
||||||
|
|
||||||
# sampling config
|
|
||||||
'sample': {
|
|
||||||
# must match train.noise_scheduler, this is not used here
|
|
||||||
# but may be in future and in other processes
|
|
||||||
'sampler': "ddpm",
|
|
||||||
# sample every this many steps
|
|
||||||
'sample_every': 20,
|
|
||||||
# image size
|
|
||||||
'width': 512,
|
|
||||||
'height': 512,
|
|
||||||
# prompts to use for sampling. Do as many as you want, but it slows down training
|
|
||||||
# pick ones that will best represent the concept you are trying to adjust
|
|
||||||
# allows some flags after the prompt
|
|
||||||
# --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
|
|
||||||
# slide are good tests. will inherit sample.network_multiplier if not set
|
|
||||||
# --n [string] # negative prompt, will inherit sample.neg if not set
|
|
||||||
# Only 75 tokens allowed currently
|
|
||||||
# I like to do a wide positive and negative spread so I can see a good range and stop
|
|
||||||
# early if the network is braking down
|
|
||||||
'prompts': [
|
|
||||||
"a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5",
|
|
||||||
"a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3",
|
|
||||||
"a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3",
|
|
||||||
"a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5",
|
|
||||||
"a golden retriever sitting on a leather couch, --m -5",
|
|
||||||
"a golden retriever sitting on a leather couch --m -3",
|
|
||||||
"a golden retriever sitting on a leather couch --m 3",
|
|
||||||
"a golden retriever sitting on a leather couch --m 5",
|
|
||||||
"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5",
|
|
||||||
"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3",
|
|
||||||
"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3",
|
|
||||||
"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5",
|
|
||||||
],
|
|
||||||
# negative prompt used on all prompts above as default if they don't have one
|
|
||||||
'neg': "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome",
|
|
||||||
# seed for sampling. 42 is the answer for everything
|
|
||||||
'seed': 42,
|
|
||||||
# walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
|
|
||||||
# will start over on next sample_every so s1 is always seed
|
|
||||||
# works well if you use same prompt but want different results
|
|
||||||
'walk_seed': False,
|
|
||||||
# cfg scale (4 to 10 is good)
|
|
||||||
'guidance_scale': 7,
|
|
||||||
# sampler steps (20 to 30 is good)
|
|
||||||
'sample_steps': 20,
|
|
||||||
# default network multiplier for all prompts
|
|
||||||
# since we are training a slider, I recommend overriding this with --m [number]
|
|
||||||
# in the prompts above to get both sides of the slider
|
|
||||||
'network_multiplier': 1.0,
|
|
||||||
},
|
|
||||||
|
|
||||||
# logging information
|
|
||||||
'logging': {
|
|
||||||
'log_every': 10, # log every this many steps
|
|
||||||
'use_wandb': False, # not supported yet
|
|
||||||
'verbose': False, # probably done need unless you are debugging
|
|
||||||
},
|
|
||||||
|
|
||||||
# slider training config, best for last
|
|
||||||
'slider': {
|
|
||||||
# resolutions to train on. [ width, height ]. This is less important for sliders
|
|
||||||
# as we are not teaching the model anything it doesn't already know
|
|
||||||
# but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
|
|
||||||
# and [ 1024, 1024 ] for sd_xl
|
|
||||||
# you can do as many as you want here
|
|
||||||
'resolutions': [
|
|
||||||
[512, 512],
|
|
||||||
# [ 512, 768 ]
|
|
||||||
# [ 768, 768 ]
|
|
||||||
],
|
|
||||||
# slider training uses 4 combined steps for a single round. This will do it in one gradient
|
|
||||||
# step. It is highly optimized and shouldn't take anymore vram than doing without it,
|
|
||||||
# since we break down batches for gradient accumulation now. so just leave it on.
|
|
||||||
'batch_full_slide': True,
|
|
||||||
# These are the concepts to train on. You can do as many as you want here,
|
|
||||||
# but they can conflict outweigh each other. Other than experimenting, I recommend
|
|
||||||
# just doing one for good results
|
|
||||||
'targets': [
|
|
||||||
# target_class is the base concept we are adjusting the representation of
|
|
||||||
# for example, if we are adjusting the representation of a person, we would use "person"
|
|
||||||
# if we are adjusting the representation of a cat, we would use "cat" It is not
|
|
||||||
# a keyword necessarily but what the model understands the concept to represent.
|
|
||||||
# "person" will affect men, women, children, etc but will not affect cats, dogs, etc
|
|
||||||
# it is the models base general understanding of the concept and everything it represents
|
|
||||||
# you can leave it blank to affect everything. In this example, we are adjusting
|
|
||||||
# detail, so we will leave it blank to affect everything
|
|
||||||
{
|
|
||||||
'target_class': "",
|
|
||||||
# positive is the prompt for the positive side of the slider.
|
|
||||||
# It is the concept that will be excited and amplified in the model when we slide the slider
|
|
||||||
# to the positive side and forgotten / inverted when we slide
|
|
||||||
# the slider to the negative side. It is generally best to include the target_class in
|
|
||||||
# the prompt. You want it to be the extreme of what you want to train on. For example,
|
|
||||||
# if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
|
|
||||||
# as the prompt. Not just "fat person"
|
|
||||||
# max 75 tokens for now
|
|
||||||
'positive': "high detail, 8k, intricate, detailed, high resolution, high res, high quality",
|
|
||||||
# negative is the prompt for the negative side of the slider and works the same as positive
|
|
||||||
# it does not necessarily work the same as a negative prompt when generating images
|
|
||||||
# these need to be polar opposites.
|
|
||||||
# max 76 tokens for now
|
|
||||||
'negative': "blurry, boring, fuzzy, low detail, low resolution, low res, low quality",
|
|
||||||
# the loss for this target is multiplied by this number.
|
|
||||||
# if you are doing more than one target it may be good to set less important ones
|
|
||||||
# to a lower number like 0.1 so they don't outweigh the primary target
|
|
||||||
'weight': 1.0,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
# You can put any information you want here, and it will be saved in the model.
|
|
||||||
# The below is an example, but you can put your grocery list in it if you want.
|
|
||||||
# It is saved in the model so be aware of that. The software will include this
|
|
||||||
# plus some other information for you automatically
|
|
||||||
'meta': {
|
|
||||||
# [name] gets replaced with the name above
|
|
||||||
'name': "[name]",
|
|
||||||
'version': '1.0',
|
|
||||||
# 'creator': {
|
|
||||||
# 'name': 'your name',
|
|
||||||
# 'email': 'your@gmail.com',
|
|
||||||
# 'website': 'https://your.website'
|
|
||||||
# }
|
|
||||||
}
|
|
||||||
})
|
|
||||||
217
toolkit/civitai.py
Normal file
217
toolkit/civitai.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
from toolkit.paths import MODELS_PATH
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCache:
|
||||||
|
def __init__(self):
|
||||||
|
self.raw_cache = {}
|
||||||
|
self.cache_path = os.path.join(MODELS_PATH, '.ai_toolkit_cache.json')
|
||||||
|
if os.path.exists(self.cache_path):
|
||||||
|
with open(self.cache_path, 'r') as f:
|
||||||
|
all_cache = json.load(f)
|
||||||
|
if 'models' in all_cache:
|
||||||
|
self.raw_cache = all_cache['models']
|
||||||
|
else:
|
||||||
|
self.raw_cache = all_cache
|
||||||
|
|
||||||
|
def get_model_path(self, model_id: int, model_version_id: int = None):
|
||||||
|
if str(model_id) not in self.raw_cache:
|
||||||
|
return None
|
||||||
|
if model_version_id is None:
|
||||||
|
# get latest version
|
||||||
|
model_version_id = max([int(x) for x in self.raw_cache[str(model_id)].keys()])
|
||||||
|
if model_version_id is None:
|
||||||
|
return None
|
||||||
|
model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path']
|
||||||
|
# check if model path exists
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
# remove version from cache
|
||||||
|
del self.raw_cache[str(model_id)][str(model_version_id)]
|
||||||
|
self.save()
|
||||||
|
return None
|
||||||
|
return model_path
|
||||||
|
else:
|
||||||
|
if str(model_version_id) not in self.raw_cache[str(model_id)]:
|
||||||
|
return None
|
||||||
|
model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path']
|
||||||
|
# check if model path exists
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
# remove version from cache
|
||||||
|
del self.raw_cache[str(model_id)][str(model_version_id)]
|
||||||
|
self.save()
|
||||||
|
return None
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
def update_cache(self, model_id: int, model_version_id: int, model_path: str):
|
||||||
|
if str(model_id) not in self.raw_cache:
|
||||||
|
self.raw_cache[str(model_id)] = {}
|
||||||
|
if str(model_version_id) not in self.raw_cache[str(model_id)]:
|
||||||
|
self.raw_cache[str(model_id)][str(model_version_id)] = {}
|
||||||
|
self.raw_cache[str(model_id)][str(model_version_id)] = {
|
||||||
|
'model_path': model_path
|
||||||
|
}
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
if not os.path.exists(os.path.dirname(self.cache_path)):
|
||||||
|
os.makedirs(os.path.dirname(self.cache_path), exist_ok=True)
|
||||||
|
all_cache = {'models': {}}
|
||||||
|
if os.path.exists(self.cache_path):
|
||||||
|
# load it first
|
||||||
|
with open(self.cache_path, 'r') as f:
|
||||||
|
all_cache = json.load(f)
|
||||||
|
|
||||||
|
all_cache['models'] = self.raw_cache
|
||||||
|
|
||||||
|
with open(self.cache_path, 'w') as f:
|
||||||
|
json.dump(all_cache, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_download_info(model_id: int, model_version_id: int = None):
|
||||||
|
# curl https://civitai.com/api/v1/models?limit=3&types=TextualInversion \
|
||||||
|
# -H "Content-Type: application/json" \
|
||||||
|
# -X GET
|
||||||
|
print(
|
||||||
|
f"Getting model info for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}")
|
||||||
|
endpoint = f"https://civitai.com/api/v1/models/{model_id}"
|
||||||
|
|
||||||
|
# get the json
|
||||||
|
response = requests.get(endpoint)
|
||||||
|
response.raise_for_status()
|
||||||
|
model_data = response.json()
|
||||||
|
|
||||||
|
model_version = None
|
||||||
|
|
||||||
|
# go through versions and get the top one if one is not set
|
||||||
|
for version in model_data['modelVersions']:
|
||||||
|
if model_version_id is not None:
|
||||||
|
if str(version['id']) == str(model_version_id):
|
||||||
|
model_version = version
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# get first version
|
||||||
|
model_version = version
|
||||||
|
break
|
||||||
|
|
||||||
|
if model_version is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not find a model version for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}")
|
||||||
|
|
||||||
|
model_file = None
|
||||||
|
# go through files and prefer fp16 safetensors
|
||||||
|
# "metadata": {
|
||||||
|
# "fp": "fp16",
|
||||||
|
# "size": "pruned",
|
||||||
|
# "format": "SafeTensor"
|
||||||
|
# },
|
||||||
|
# todo check pickle scans and skip if not good
|
||||||
|
# try to get fp16 safetensor
|
||||||
|
for file in model_version['files']:
|
||||||
|
if file['metadata']['fp'] == 'fp16' and file['metadata']['format'] == 'SafeTensor':
|
||||||
|
model_file = file
|
||||||
|
break
|
||||||
|
|
||||||
|
if model_file is None:
|
||||||
|
# try to get primary
|
||||||
|
for file in model_version['files']:
|
||||||
|
if file['primary']:
|
||||||
|
model_file = file
|
||||||
|
break
|
||||||
|
|
||||||
|
if model_file is None:
|
||||||
|
# try to get any safetensor
|
||||||
|
for file in model_version['files']:
|
||||||
|
if file['metadata']['format'] == 'SafeTensor':
|
||||||
|
model_file = file
|
||||||
|
break
|
||||||
|
|
||||||
|
if model_file is None:
|
||||||
|
# try to get any fp16
|
||||||
|
for file in model_version['files']:
|
||||||
|
if file['metadata']['fp'] == 'fp16':
|
||||||
|
model_file = file
|
||||||
|
break
|
||||||
|
|
||||||
|
if model_file is None:
|
||||||
|
# try to get any
|
||||||
|
for file in model_version['files']:
|
||||||
|
model_file = file
|
||||||
|
break
|
||||||
|
|
||||||
|
if model_file is None:
|
||||||
|
raise ValueError(f"Could not find a model file to download for model id: {model_id}")
|
||||||
|
|
||||||
|
return model_file, model_version['id']
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path_from_url(url: str):
|
||||||
|
# get query params form url if they are set
|
||||||
|
# https: // civitai.com / models / 25694?modelVersionId = 127742
|
||||||
|
query_params = {}
|
||||||
|
if '?' in url:
|
||||||
|
query_string = url.split('?')[1]
|
||||||
|
query_params = dict(qc.split("=") for qc in query_string.split("&"))
|
||||||
|
|
||||||
|
# get model id from url
|
||||||
|
model_id = url.split('/')[-1]
|
||||||
|
# remove query params from model id
|
||||||
|
if '?' in model_id:
|
||||||
|
model_id = model_id.split('?')[0]
|
||||||
|
if model_id.isdigit():
|
||||||
|
model_id = int(model_id)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid model id: {model_id}")
|
||||||
|
|
||||||
|
model_cache = ModelCache()
|
||||||
|
model_path = model_cache.get_model_path(model_id, query_params.get('modelVersionId', None))
|
||||||
|
if model_path is not None:
|
||||||
|
return model_path
|
||||||
|
else:
|
||||||
|
# download model
|
||||||
|
file_info, model_version_id = get_model_download_info(model_id, query_params.get('modelVersionId', None))
|
||||||
|
|
||||||
|
download_url = file_info['downloadUrl'] # url does not work directly
|
||||||
|
size_kb = file_info['sizeKB']
|
||||||
|
filename = file_info['name']
|
||||||
|
model_path = os.path.join(MODELS_PATH, filename)
|
||||||
|
|
||||||
|
# download model
|
||||||
|
print(f"Did not find model locally, downloading from model from: {download_url}")
|
||||||
|
|
||||||
|
# use tqdm to show status of downlod
|
||||||
|
response = requests.get(download_url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||||
|
block_size = 1024 # 1 Kibibyte
|
||||||
|
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||||
|
tmp_path = os.path.join(MODELS_PATH, f".download_tmp_{filename}")
|
||||||
|
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||||
|
# remove tmp file if it exists
|
||||||
|
if os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
with open(tmp_path, 'wb') as f:
|
||||||
|
for data in response.iter_content(block_size):
|
||||||
|
progress_bar.update(len(data))
|
||||||
|
f.write(data)
|
||||||
|
progress_bar.close()
|
||||||
|
# move to final path
|
||||||
|
os.rename(tmp_path, model_path)
|
||||||
|
model_cache.update_cache(model_id, model_version_id, model_path)
|
||||||
|
|
||||||
|
return model_path
|
||||||
|
except Exception as e:
|
||||||
|
# remove tmp file
|
||||||
|
os.remove(tmp_path)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
# if is main
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model_path = get_model_path_from_url("https://civitai.com/models/25694?modelVersionId=127742")
|
||||||
|
print(model_path)
|
||||||
@@ -5,6 +5,12 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
|||||||
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
||||||
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
||||||
|
|
||||||
|
# check if ENV variable is set
|
||||||
|
if 'MODELS_PATH' in os.environ:
|
||||||
|
MODELS_PATH = os.environ['MODELS_PATH']
|
||||||
|
else:
|
||||||
|
MODELS_PATH = os.path.join(TOOLKIT_ROOT, "models")
|
||||||
|
|
||||||
|
|
||||||
def get_path(path):
|
def get_path(path):
|
||||||
# we allow absolute paths, but if it is not absolute, we assume it is relative to the toolkit root
|
# we allow absolute paths, but if it is not absolute, we assume it is relative to the toolkit root
|
||||||
|
|||||||
@@ -143,6 +143,13 @@ class StableDiffusion:
|
|||||||
prediction_type=prediction_type,
|
prediction_type=prediction_type,
|
||||||
steps_offset=1
|
steps_offset=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_path = self.model_config.name_or_path
|
||||||
|
if 'civitai.com' in self.model_config.name_or_path:
|
||||||
|
# load is a civit ai model, use the loader.
|
||||||
|
from toolkit.civitai import get_model_path_from_url
|
||||||
|
model_path = get_model_path_from_url(self.model_config.name_or_path)
|
||||||
|
|
||||||
if self.model_config.is_xl:
|
if self.model_config.is_xl:
|
||||||
if self.custom_pipeline is not None:
|
if self.custom_pipeline is not None:
|
||||||
pipln = self.custom_pipeline
|
pipln = self.custom_pipeline
|
||||||
@@ -150,17 +157,17 @@ class StableDiffusion:
|
|||||||
pipln = CustomStableDiffusionXLPipeline
|
pipln = CustomStableDiffusionXLPipeline
|
||||||
|
|
||||||
# see if path exists
|
# see if path exists
|
||||||
if not os.path.exists(self.model_config.name_or_path):
|
if not os.path.exists(model_path):
|
||||||
# try to load with default diffusers
|
# try to load with default diffusers
|
||||||
pipe = pipln.from_pretrained(
|
pipe = pipln.from_pretrained(
|
||||||
self.model_config.name_or_path,
|
model_path,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
scheduler_type='ddpm',
|
scheduler_type='ddpm',
|
||||||
device=self.device_torch,
|
device=self.device_torch,
|
||||||
).to(self.device_torch)
|
).to(self.device_torch)
|
||||||
else:
|
else:
|
||||||
pipe = pipln.from_single_file(
|
pipe = pipln.from_single_file(
|
||||||
self.model_config.name_or_path,
|
model_path,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
scheduler_type='ddpm',
|
scheduler_type='ddpm',
|
||||||
device=self.device_torch,
|
device=self.device_torch,
|
||||||
@@ -180,10 +187,10 @@ class StableDiffusion:
|
|||||||
pipln = CustomStableDiffusionPipeline
|
pipln = CustomStableDiffusionPipeline
|
||||||
|
|
||||||
# see if path exists
|
# see if path exists
|
||||||
if not os.path.exists(self.model_config.name_or_path):
|
if not os.path.exists(model_path):
|
||||||
# try to load with default diffusers
|
# try to load with default diffusers
|
||||||
pipe = pipln.from_pretrained(
|
pipe = pipln.from_pretrained(
|
||||||
self.model_config.name_or_path,
|
model_path,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
scheduler_type='dpm',
|
scheduler_type='dpm',
|
||||||
device=self.device_torch,
|
device=self.device_torch,
|
||||||
@@ -193,7 +200,7 @@ class StableDiffusion:
|
|||||||
).to(self.device_torch)
|
).to(self.device_torch)
|
||||||
else:
|
else:
|
||||||
pipe = pipln.from_single_file(
|
pipe = pipln.from_single_file(
|
||||||
self.model_config.name_or_path,
|
model_path,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
scheduler_type='dpm',
|
scheduler_type='dpm',
|
||||||
device=self.device_torch,
|
device=self.device_torch,
|
||||||
|
|||||||
Reference in New Issue
Block a user