diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 65085f67..2dbe830f 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -26,6 +26,10 @@ from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl from diffusers.schedulers import DDPMScheduler from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +import diffusers + +# tell it to shut up +diffusers.logging.set_verbosity(diffusers.logging.ERROR) class BlankNetwork: @@ -113,6 +117,7 @@ class StableDiffusion: # sdxl stuff self.logit_scale = None self.ckppt_info = None + self.is_loaded = False # to hold network if there is one self.network = None @@ -120,6 +125,8 @@ class StableDiffusion: self.is_v2 = model_config.is_v2 def load_model(self): + if self.is_loaded: + return dtype = get_torch_dtype(self.dtype) # TODO handle other schedulers @@ -217,6 +224,7 @@ class StableDiffusion: self.tokenizer = tokenizer self.text_encoder = text_encoder self.pipeline = pipe + self.is_loaded = True def generate_images(self, image_configs: List[GenerateImageConfig]): # sample_folder = os.path.join(self.save_root, 'samples')