mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP: create a new base model class so that new models can be added more easily
This commit is contained in:
@@ -68,6 +68,8 @@ import transformers
|
||||
import diffusers
|
||||
import hashlib
|
||||
|
||||
from toolkit.util.get_model import get_model_class
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
@@ -1423,7 +1425,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
model_config_to_load.refiner_name_or_path = previous_refiner_save
|
||||
self.load_training_state_from_metadata(previous_refiner_save)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
ModelClass = get_model_class(self.model_config)
|
||||
self.sd = ModelClass(
|
||||
device=self.device,
|
||||
model_config=model_config_to_load,
|
||||
dtype=self.train_config.dtype,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
torch==2.5.1
|
||||
torchvision==0.20.1
|
||||
safetensors
|
||||
git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
|
||||
# https://github.com/huggingface/diffusers/pull/10921
|
||||
git+https://github.com/huggingface/diffusers@refs/pull/10921/head
|
||||
transformers
|
||||
lycoris-lora==1.8.3
|
||||
flatten_json
|
||||
|
||||
@@ -423,6 +423,9 @@ class TrainConfig:
|
||||
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
|
||||
|
||||
|
||||
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.name_or_path: str = kwargs.get('name_or_path', None)
|
||||
@@ -500,6 +503,36 @@ class ModelConfig:
|
||||
self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
|
||||
|
||||
self.te_name_or_path = kwargs.get("te_name_or_path", None)
|
||||
|
||||
self.arch: ModelArch = kwargs.get("model_arch", None)
|
||||
|
||||
# handle migrating to new model arch
|
||||
if self.arch is None:
|
||||
if kwargs.get('is_v2', False):
|
||||
self.arch = 'sd2'
|
||||
elif kwargs.get('is_v3', False):
|
||||
self.arch = 'sd3'
|
||||
elif kwargs.get('is_xl', False):
|
||||
self.arch = 'sdxl'
|
||||
elif kwargs.get('is_pixart', False):
|
||||
self.arch = 'pixart'
|
||||
elif kwargs.get('is_pixart_sigma', False):
|
||||
self.arch = 'pixart_sigma'
|
||||
elif kwargs.get('is_auraflow', False):
|
||||
self.arch = 'auraflow'
|
||||
elif kwargs.get('is_flux', False):
|
||||
self.arch = 'flux'
|
||||
elif kwargs.get('is_flex2', False):
|
||||
self.arch = 'flex2'
|
||||
elif kwargs.get('is_lumina2', False):
|
||||
self.arch = 'lumina2'
|
||||
elif kwargs.get('is_vega', False):
|
||||
self.arch = 'vega'
|
||||
elif kwargs.get('is_ssd', False):
|
||||
self.arch = 'ssd'
|
||||
else:
|
||||
self.arch = 'sd1'
|
||||
|
||||
|
||||
|
||||
class EMAConfig:
|
||||
|
||||
1467
toolkit/models/base_model.py
Normal file
1467
toolkit/models/base_model.py
Normal file
File diff suppressed because it is too large
Load Diff
56
toolkit/models/wan21.py
Normal file
56
toolkit/models/wan21.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
import torch
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline
|
||||
|
||||
class Wan21(BaseModel):
    """Skeleton model class for the 'wan21' architecture.

    Construction is delegated to ``BaseModel``; every model-specific hook
    below is still a stub that raises ``NotImplementedError``.
    """

    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='bf16',
            custom_pipeline=None,
            noise_scheduler=None,
            **kwargs
    ):
        """Forward all arguments to ``BaseModel`` and flag the model as flow matching."""
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True

    # these must be implemented in child classes

    def load_model(self):
        """Stub: always raises ``NotImplementedError``."""
        # override this in child classes
        raise NotImplementedError(
            "load_model must be implemented in child classes")

    def get_generation_pipeline(self):
        """Stub: always raises ``NotImplementedError``."""
        # override this in child classes
        raise NotImplementedError(
            "get_generation_pipeline must be implemented in child classes")

    def generate_single_image(
        self,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Stub: always raises ``NotImplementedError``."""
        # override this in child classes
        raise NotImplementedError(
            "generate_single_image must be implemented in child classes")

    def get_noise_prediction(
        self,  # was missing: instance calls raised TypeError instead of NotImplementedError
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        """Stub: always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "get_noise_prediction must be implemented in child classes")

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Stub: always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "get_prompt_embeds must be implemented in child classes")
|
||||
@@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict, load_vae
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.models.decorator import Decorator
|
||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
@@ -177,16 +177,17 @@ class StableDiffusion:
|
||||
self.network = None
|
||||
self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
|
||||
self.decorator: Union[Decorator, None] = None
|
||||
self.is_xl = model_config.is_xl
|
||||
self.is_v2 = model_config.is_v2
|
||||
self.is_ssd = model_config.is_ssd
|
||||
self.is_v3 = model_config.is_v3
|
||||
self.is_vega = model_config.is_vega
|
||||
self.is_pixart = model_config.is_pixart
|
||||
self.is_auraflow = model_config.is_auraflow
|
||||
self.is_flux = model_config.is_flux
|
||||
self.is_flex2 = model_config.is_flex2
|
||||
self.is_lumina2 = model_config.is_lumina2
|
||||
self.arch: ModelArch = model_config.arch
|
||||
# self.is_xl = model_config.is_xl
|
||||
# self.is_v2 = model_config.is_v2
|
||||
# self.is_ssd = model_config.is_ssd
|
||||
# self.is_v3 = model_config.is_v3
|
||||
# self.is_vega = model_config.is_vega
|
||||
# self.is_pixart = model_config.is_pixart
|
||||
# self.is_auraflow = model_config.is_auraflow
|
||||
# self.is_flux = model_config.is_flux
|
||||
# self.is_flex2 = model_config.is_flex2
|
||||
# self.is_lumina2 = model_config.is_lumina2
|
||||
|
||||
self.use_text_encoder_1 = model_config.use_text_encoder_1
|
||||
self.use_text_encoder_2 = model_config.use_text_encoder_2
|
||||
@@ -204,6 +205,47 @@ class StableDiffusion:
|
||||
self.invert_assistant_lora = False
|
||||
self._after_sample_img_hooks = []
|
||||
self._status_update_hooks = []
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
def is_xl(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'sdxl'."""
    return self.arch == 'sdxl'
|
||||
|
||||
@property
|
||||
def is_v2(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'sd2'."""
    return self.arch == 'sd2'
|
||||
|
||||
@property
|
||||
def is_ssd(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'ssd'."""
    return self.arch == 'ssd'
|
||||
|
||||
@property
|
||||
def is_v3(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'sd3'."""
    return self.arch == 'sd3'
|
||||
|
||||
@property
|
||||
def is_vega(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'vega'."""
    return self.arch == 'vega'
|
||||
|
||||
@property
|
||||
def is_pixart(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'pixart'."""
    return self.arch == 'pixart'
|
||||
|
||||
@property
|
||||
def is_auraflow(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'auraflow'."""
    return self.arch == 'auraflow'
|
||||
|
||||
@property
|
||||
def is_flux(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'flux'."""
    return self.arch == 'flux'
|
||||
|
||||
@property
|
||||
def is_flex2(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'flex2'."""
    return self.arch == 'flex2'
|
||||
|
||||
@property
|
||||
def is_lumina2(self) -> bool:
    """Backward-compatibility flag: True when ``self.arch`` is 'lumina2'."""
    return self.arch == 'lumina2'
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
|
||||
9
toolkit/util/get_model.py
Normal file
9
toolkit/util/get_model.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.config_modules import ModelConfig
|
||||
|
||||
def get_model_class(config: ModelConfig):
    """Return the model class to instantiate for ``config``.

    ``config.arch == "wan21"`` selects ``Wan21``; every other value
    (including ``None``) falls back to ``StableDiffusion``.
    """
    if config.arch != "wan21":
        return StableDiffusion

    # Wan21 is imported only when this architecture is actually requested.
    from toolkit.models.wan21 import Wan21
    return Wan21
|
||||
Reference in New Issue
Block a user