WIP create new class to add new models more easily

This commit is contained in:
Jaret Burkett
2025-03-01 13:49:02 -07:00
parent 60539c0b0f
commit acc79956aa
7 changed files with 1624 additions and 13 deletions

View File

@@ -68,6 +68,8 @@ import transformers
import diffusers
import hashlib
from toolkit.util.get_model import get_model_class
def flush():
    """Run Python garbage collection, then release cached CUDA memory.

    The collector runs first so that any memory it frees is already back in
    PyTorch's caching allocator when ``torch.cuda.empty_cache()`` is called.
    """
    gc.collect()
    torch.cuda.empty_cache()
@@ -1423,7 +1425,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
model_config_to_load.refiner_name_or_path = previous_refiner_save
self.load_training_state_from_metadata(previous_refiner_save)
self.sd = StableDiffusion(
ModelClass = get_model_class(self.model_config)
self.sd = ModelClass(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,

View File

@@ -1,7 +1,8 @@
torch==2.5.1
torchvision==0.20.1
safetensors
git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
# https://github.com/huggingface/diffusers/pull/10921
git+https://github.com/huggingface/diffusers@refs/pull/10921/head
transformers
lycoris-lora==1.8.3
flatten_json

View File

@@ -423,6 +423,9 @@ class TrainConfig:
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
# Architecture identifiers accepted for a model config's `arch` value.
ModelArch = Literal[
    'sd1',
    'sd2',
    'sd3',
    'sdxl',
    'pixart',
    'pixart_sigma',
    'auraflow',
    'flux',
    'flex2',
    'lumina2',
    'vega',
    'ssd',
    'wan21',
]
class ModelConfig:
def __init__(self, **kwargs):
self.name_or_path: str = kwargs.get('name_or_path', None)
@@ -500,6 +503,36 @@ class ModelConfig:
self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
self.te_name_or_path = kwargs.get("te_name_or_path", None)
self.arch: ModelArch = kwargs.get("model_arch", None)
# handle migrating to new model arch
if self.arch is None:
if kwargs.get('is_v2', False):
self.arch = 'sd2'
elif kwargs.get('is_v3', False):
self.arch = 'sd3'
elif kwargs.get('is_xl', False):
self.arch = 'sdxl'
elif kwargs.get('is_pixart', False):
self.arch = 'pixart'
elif kwargs.get('is_pixart_sigma', False):
self.arch = 'pixart_sigma'
elif kwargs.get('is_auraflow', False):
self.arch = 'auraflow'
elif kwargs.get('is_flux', False):
self.arch = 'flux'
elif kwargs.get('is_flex2', False):
self.arch = 'flex2'
elif kwargs.get('is_lumina2', False):
self.arch = 'lumina2'
elif kwargs.get('is_vega', False):
self.arch = 'vega'
elif kwargs.get('is_ssd', False):
self.arch = 'ssd'
else:
self.arch = 'sd1'
class EMAConfig:

1467
toolkit/models/base_model.py Normal file

File diff suppressed because it is too large Load Diff

56
toolkit/models/wan21.py Normal file
View File

@@ -0,0 +1,56 @@
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline
class Wan21(BaseModel):
    """Work-in-progress ``BaseModel`` subclass for the Wan 2.1 model.

    The constructor forwards its arguments to ``BaseModel`` and marks the
    model as flow matching. Every model-specific hook below is still a stub
    that raises ``NotImplementedError``.
    """

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        """Initialise the base model and flag it as flow matching.

        All arguments are passed through to ``BaseModel.__init__`` unchanged.
        """
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True

    # these must be implemented in child classes

    def load_model(self):
        """Stub: always raises ``NotImplementedError``."""
        # override this in child classes
        raise NotImplementedError(
            "load_model must be implemented in child classes")

    def get_generation_pipeline(self):
        """Stub: always raises ``NotImplementedError``."""
        # override this in child classes
        raise NotImplementedError(
            "get_generation_pipeline must be implemented in child classes")

    def generate_single_image(
        self,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Stub: always raises ``NotImplementedError``."""
        # override this in child classes
        raise NotImplementedError(
            "generate_single_image must be implemented in child classes")

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        """Stub: always raises ``NotImplementedError``.

        Takes ``self`` like the other hooks so that a bound call reaches the
        ``NotImplementedError`` instead of failing on argument binding.
        """
        raise NotImplementedError(
            "get_noise_prediction must be implemented in child classes")

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Stub: always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "get_prompt_embeds must be implemented in child classes")

View File

@@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.decorator import Decorator
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
@@ -177,16 +177,17 @@ class StableDiffusion:
self.network = None
self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
self.decorator: Union[Decorator, None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.is_ssd = model_config.is_ssd
self.is_v3 = model_config.is_v3
self.is_vega = model_config.is_vega
self.is_pixart = model_config.is_pixart
self.is_auraflow = model_config.is_auraflow
self.is_flux = model_config.is_flux
self.is_flex2 = model_config.is_flex2
self.is_lumina2 = model_config.is_lumina2
self.arch: ModelArch = model_config.arch
# self.is_xl = model_config.is_xl
# self.is_v2 = model_config.is_v2
# self.is_ssd = model_config.is_ssd
# self.is_v3 = model_config.is_v3
# self.is_vega = model_config.is_vega
# self.is_pixart = model_config.is_pixart
# self.is_auraflow = model_config.is_auraflow
# self.is_flux = model_config.is_flux
# self.is_flex2 = model_config.is_flex2
# self.is_lumina2 = model_config.is_lumina2
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -204,6 +205,47 @@ class StableDiffusion:
self.invert_assistant_lora = False
self._after_sample_img_hooks = []
self._status_update_hooks = []
# properties for old arch for backwards compatibility
@property
def is_xl(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'sdxl'``."""
    arch = self.arch
    return arch == 'sdxl'
@property
def is_v2(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'sd2'``."""
    arch = self.arch
    return arch == 'sd2'
@property
def is_ssd(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'ssd'``."""
    arch = self.arch
    return arch == 'ssd'
@property
def is_v3(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'sd3'``."""
    arch = self.arch
    return arch == 'sd3'
@property
def is_vega(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'vega'``."""
    arch = self.arch
    return arch == 'vega'
@property
def is_pixart(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'pixart'``."""
    arch = self.arch
    return arch == 'pixart'
@property
def is_auraflow(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'auraflow'``."""
    arch = self.arch
    return arch == 'auraflow'
@property
def is_flux(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'flux'``."""
    arch = self.arch
    return arch == 'flux'
@property
def is_flex2(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'flex2'``."""
    arch = self.arch
    return arch == 'flex2'
@property
def is_lumina2(self) -> bool:
    """Backwards-compatible flag: whether ``self.arch`` equals ``'lumina2'``."""
    arch = self.arch
    return arch == 'lumina2'
def load_model(self):
if self.is_loaded:

View File

@@ -0,0 +1,9 @@
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig
def get_model_class(config: ModelConfig):
    """Return the model class to instantiate for ``config.arch``.

    ``"wan21"`` selects ``Wan21``; every other value (including ``None``)
    falls back to ``StableDiffusion``.
    """
    if config.arch != "wan21":
        return StableDiffusion

    # Wan21 is imported only on the branch that actually selects it.
    from toolkit.models.wan21 import Wan21
    return Wan21