mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-07 19:49:49 +00:00
WIP create new class to add new models more easily
This commit is contained in:
@@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict, load_vae
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.models.decorator import Decorator
|
||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
@@ -177,16 +177,17 @@ class StableDiffusion:
|
||||
self.network = None
|
||||
self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
|
||||
self.decorator: Union[Decorator, None] = None
|
||||
self.is_xl = model_config.is_xl
|
||||
self.is_v2 = model_config.is_v2
|
||||
self.is_ssd = model_config.is_ssd
|
||||
self.is_v3 = model_config.is_v3
|
||||
self.is_vega = model_config.is_vega
|
||||
self.is_pixart = model_config.is_pixart
|
||||
self.is_auraflow = model_config.is_auraflow
|
||||
self.is_flux = model_config.is_flux
|
||||
self.is_flex2 = model_config.is_flex2
|
||||
self.is_lumina2 = model_config.is_lumina2
|
||||
self.arch: ModelArch = model_config.arch
|
||||
# self.is_xl = model_config.is_xl
|
||||
# self.is_v2 = model_config.is_v2
|
||||
# self.is_ssd = model_config.is_ssd
|
||||
# self.is_v3 = model_config.is_v3
|
||||
# self.is_vega = model_config.is_vega
|
||||
# self.is_pixart = model_config.is_pixart
|
||||
# self.is_auraflow = model_config.is_auraflow
|
||||
# self.is_flux = model_config.is_flux
|
||||
# self.is_flex2 = model_config.is_flex2
|
||||
# self.is_lumina2 = model_config.is_lumina2
|
||||
|
||||
self.use_text_encoder_1 = model_config.use_text_encoder_1
|
||||
self.use_text_encoder_2 = model_config.use_text_encoder_2
|
||||
@@ -204,6 +205,47 @@ class StableDiffusion:
|
||||
self.invert_assistant_lora = False
|
||||
self._after_sample_img_hooks = []
|
||||
self._status_update_hooks = []
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
def is_xl(self):
|
||||
return self.arch == 'sdxl'
|
||||
|
||||
@property
|
||||
def is_v2(self):
|
||||
return self.arch == 'sd2'
|
||||
|
||||
@property
|
||||
def is_ssd(self):
|
||||
return self.arch == 'ssd'
|
||||
|
||||
@property
|
||||
def is_v3(self):
|
||||
return self.arch == 'sd3'
|
||||
|
||||
@property
|
||||
def is_vega(self):
|
||||
return self.arch == 'vega'
|
||||
|
||||
@property
|
||||
def is_pixart(self):
|
||||
return self.arch == 'pixart'
|
||||
|
||||
@property
|
||||
def is_auraflow(self):
|
||||
return self.arch == 'auraflow'
|
||||
|
||||
@property
|
||||
def is_flux(self):
|
||||
return self.arch == 'flux'
|
||||
|
||||
@property
|
||||
def is_flex2(self):
|
||||
return self.arch == 'flex2'
|
||||
|
||||
@property
|
||||
def is_lumina2(self):
|
||||
return self.arch == 'lumina2'
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
|
||||
Reference in New Issue
Block a user