Merge branch 'main' of github.com:ostris/ai-toolkit

This commit is contained in:
Jaret Burkett
2023-08-12 05:59:58 -06:00

View File

@@ -26,6 +26,10 @@ from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
from diffusers.schedulers import DDPMScheduler from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import diffusers
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
class BlankNetwork: class BlankNetwork:
@@ -113,6 +117,7 @@ class StableDiffusion:
# sdxl stuff # sdxl stuff
self.logit_scale = None self.logit_scale = None
self.ckppt_info = None self.ckppt_info = None
self.is_loaded = False
# to hold network if there is one # to hold network if there is one
self.network = None self.network = None
@@ -120,6 +125,8 @@ class StableDiffusion:
self.is_v2 = model_config.is_v2 self.is_v2 = model_config.is_v2
def load_model(self): def load_model(self):
if self.is_loaded:
return
dtype = get_torch_dtype(self.dtype) dtype = get_torch_dtype(self.dtype)
# TODO handle other schedulers # TODO handle other schedulers
@@ -217,6 +224,7 @@ class StableDiffusion:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.text_encoder = text_encoder self.text_encoder = text_encoder
self.pipeline = pipe self.pipeline = pipe
self.is_loaded = True
def generate_images(self, image_configs: List[GenerateImageConfig]): def generate_images(self, image_configs: List[GenerateImageConfig]):
# sample_folder = os.path.join(self.save_root, 'samples') # sample_folder = os.path.join(self.save_root, 'samples')