mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Merge branch 'main' of github.com:ostris/ai-toolkit
This commit is contained in:
@@ -26,6 +26,10 @@ from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
|
|||||||
from diffusers.schedulers import DDPMScheduler
|
from diffusers.schedulers import DDPMScheduler
|
||||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
|
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
|
||||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||||
|
import diffusers
|
||||||
|
|
||||||
|
# tell it to shut up
|
||||||
|
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
class BlankNetwork:
|
class BlankNetwork:
|
||||||
@@ -113,6 +117,7 @@ class StableDiffusion:
|
|||||||
# sdxl stuff
|
# sdxl stuff
|
||||||
self.logit_scale = None
|
self.logit_scale = None
|
||||||
self.ckppt_info = None
|
self.ckppt_info = None
|
||||||
|
self.is_loaded = False
|
||||||
|
|
||||||
# to hold network if there is one
|
# to hold network if there is one
|
||||||
self.network = None
|
self.network = None
|
||||||
@@ -120,6 +125,8 @@ class StableDiffusion:
|
|||||||
self.is_v2 = model_config.is_v2
|
self.is_v2 = model_config.is_v2
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
|
if self.is_loaded:
|
||||||
|
return
|
||||||
dtype = get_torch_dtype(self.dtype)
|
dtype = get_torch_dtype(self.dtype)
|
||||||
|
|
||||||
# TODO handle other schedulers
|
# TODO handle other schedulers
|
||||||
@@ -217,6 +224,7 @@ class StableDiffusion:
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.text_encoder = text_encoder
|
self.text_encoder = text_encoder
|
||||||
self.pipeline = pipe
|
self.pipeline = pipe
|
||||||
|
self.is_loaded = True
|
||||||
|
|
||||||
def generate_images(self, image_configs: List[GenerateImageConfig]):
|
def generate_images(self, image_configs: List[GenerateImageConfig]):
|
||||||
# sample_folder = os.path.join(self.save_root, 'samples')
|
# sample_folder = os.path.join(self.save_root, 'samples')
|
||||||
|
|||||||
Reference in New Issue
Block a user