Merge branch 'main' of github.com:ostris/ai-toolkit

This commit is contained in:
Jaret Burkett
2023-08-12 05:59:58 -06:00

View File

@@ -26,6 +26,10 @@ from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import diffusers
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
class BlankNetwork:
@@ -113,6 +117,7 @@ class StableDiffusion:
# sdxl stuff
self.logit_scale = None
self.ckppt_info = None
self.is_loaded = False
# to hold network if there is one
self.network = None
@@ -120,6 +125,8 @@ class StableDiffusion:
self.is_v2 = model_config.is_v2
def load_model(self):
if self.is_loaded:
return
dtype = get_torch_dtype(self.dtype)
# TODO handle other schedulers
@@ -217,6 +224,7 @@ class StableDiffusion:
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.pipeline = pipe
self.is_loaded = True
def generate_images(self, image_configs: List[GenerateImageConfig]):
# sample_folder = os.path.join(self.save_root, 'samples')