WIP create new class to add new models more easily

This commit is contained in:
Jaret Burkett
2025-03-01 13:49:02 -07:00
parent 60539c0b0f
commit acc79956aa
7 changed files with 1624 additions and 13 deletions

View File

@@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.decorator import Decorator
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
@@ -177,16 +177,17 @@ class StableDiffusion:
self.network = None
self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
self.decorator: Union[Decorator, None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.is_ssd = model_config.is_ssd
self.is_v3 = model_config.is_v3
self.is_vega = model_config.is_vega
self.is_pixart = model_config.is_pixart
self.is_auraflow = model_config.is_auraflow
self.is_flux = model_config.is_flux
self.is_flex2 = model_config.is_flex2
self.is_lumina2 = model_config.is_lumina2
self.arch: ModelArch = model_config.arch
# self.is_xl = model_config.is_xl
# self.is_v2 = model_config.is_v2
# self.is_ssd = model_config.is_ssd
# self.is_v3 = model_config.is_v3
# self.is_vega = model_config.is_vega
# self.is_pixart = model_config.is_pixart
# self.is_auraflow = model_config.is_auraflow
# self.is_flux = model_config.is_flux
# self.is_flex2 = model_config.is_flex2
# self.is_lumina2 = model_config.is_lumina2
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -204,6 +205,47 @@ class StableDiffusion:
self.invert_assistant_lora = False
self._after_sample_img_hooks = []
self._status_update_hooks = []
# properties for old arch for backwards compatibility
@property
def is_xl(self):
return self.arch == 'sdxl'
@property
def is_v2(self):
return self.arch == 'sd2'
@property
def is_ssd(self):
return self.arch == 'ssd'
@property
def is_v3(self):
return self.arch == 'sd3'
@property
def is_vega(self):
return self.arch == 'vega'
@property
def is_pixart(self):
return self.arch == 'pixart'
@property
def is_auraflow(self):
return self.arch == 'auraflow'
@property
def is_flux(self):
return self.arch == 'flux'
@property
def is_flex2(self):
return self.arch == 'flex2'
@property
def is_lumina2(self):
return self.arch == 'lumina2'
def load_model(self):
if self.is_loaded: