mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Merge branch 'main' of github.com:ostris/ai-toolkit
This commit is contained in:
@@ -26,6 +26,10 @@ from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
|
||||
from diffusers.schedulers import DDPMScheduler
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
import diffusers
|
||||
|
||||
# tell it to shut up
|
||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||
|
||||
|
||||
class BlankNetwork:
|
||||
@@ -113,6 +117,7 @@ class StableDiffusion:
|
||||
# sdxl stuff
|
||||
self.logit_scale = None
|
||||
self.ckppt_info = None
|
||||
self.is_loaded = False
|
||||
|
||||
# to hold network if there is one
|
||||
self.network = None
|
||||
@@ -120,6 +125,8 @@ class StableDiffusion:
|
||||
self.is_v2 = model_config.is_v2
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
return
|
||||
dtype = get_torch_dtype(self.dtype)
|
||||
|
||||
# TODO handle other schedulers
|
||||
@@ -217,6 +224,7 @@ class StableDiffusion:
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.pipeline = pipe
|
||||
self.is_loaded = True
|
||||
|
||||
def generate_images(self, image_configs: List[GenerateImageConfig]):
|
||||
# sample_folder = os.path.join(self.save_root, 'samples')
|
||||
|
||||
Reference in New Issue
Block a user