diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index f06f2929..1841d496 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -24,6 +24,10 @@ from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl from diffusers.schedulers import DDPMScheduler from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +import diffusers + +# tell it to shut up +diffusers.logging.set_verbosity(diffusers.logging.ERROR) class BlankNetwork: @@ -111,6 +115,7 @@ class StableDiffusion: # sdxl stuff self.logit_scale = None self.ckppt_info = None + self.is_loaded = False # to hold network if there is one self.network = None @@ -118,6 +123,8 @@ class StableDiffusion: self.is_v2 = model_config.is_v2 def load_model(self): + if self.is_loaded: + return dtype = get_torch_dtype(self.dtype) # TODO handle other schedulers @@ -212,6 +219,7 @@ class StableDiffusion: self.tokenizer = tokenizer self.text_encoder = text_encoder self.pipeline = pipe + self.is_loaded = True def generate_images(self, image_configs: List[GenerateImageConfig]): # sample_folder = os.path.join(self.save_root, 'samples')