mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-02 09:09:48 +00:00
Added ability to use civit ai url ar model name and built a model downloader and cache manager for it
This commit is contained in:
@@ -143,6 +143,13 @@ class StableDiffusion:
|
||||
prediction_type=prediction_type,
|
||||
steps_offset=1
|
||||
)
|
||||
|
||||
model_path = self.model_config.name_or_path
|
||||
if 'civitai.com' in self.model_config.name_or_path:
|
||||
# load is a civit ai model, use the loader.
|
||||
from toolkit.civitai import get_model_path_from_url
|
||||
model_path = get_model_path_from_url(self.model_config.name_or_path)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
@@ -150,17 +157,17 @@ class StableDiffusion:
|
||||
pipln = CustomStableDiffusionXLPipeline
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(self.model_config.name_or_path):
|
||||
if not os.path.exists(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
self.model_config.name_or_path,
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
pipe = pipln.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
@@ -180,10 +187,10 @@ class StableDiffusion:
|
||||
pipln = CustomStableDiffusionPipeline
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(self.model_config.name_or_path):
|
||||
if not os.path.exists(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
self.model_config.name_or_path,
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch,
|
||||
@@ -193,7 +200,7 @@ class StableDiffusion:
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
pipe = pipln.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch,
|
||||
|
||||
Reference in New Issue
Block a user