From 77001ee77f8f5b3ff9cb97946d094852e502afce Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 19 Apr 2025 10:41:27 -0600 Subject: [PATCH] Upodate model tag on loras --- .../diffusion_models/chroma/chroma_model.py | 3 +++ .../diffusion_models/hidream/hidream_model.py | 3 +++ jobs/process/BaseSDTrainProcess.py | 13 +----------- toolkit/models/base_model.py | 4 ++++ toolkit/models/wan21/wan21.py | 3 +++ toolkit/stable_diffusion_model.py | 21 +++++++++++++++++++ 6 files changed, 35 insertions(+), 12 deletions(-) diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py index d3a92049..35be1ac5 100644 --- a/extensions_built_in/diffusion_models/chroma/chroma_model.py +++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py @@ -386,3 +386,6 @@ class ChromaModel(BaseModel): new_key = key.replace("diffusion_model.", "transformer.") new_sd[new_key] = value return new_sd + + def get_base_model_version(self): + return "chroma" diff --git a/extensions_built_in/diffusion_models/hidream/hidream_model.py b/extensions_built_in/diffusion_models/hidream/hidream_model.py index 35b4a192..5d07eb6c 100644 --- a/extensions_built_in/diffusion_models/hidream/hidream_model.py +++ b/extensions_built_in/diffusion_models/hidream/hidream_model.py @@ -443,3 +443,6 @@ class HidreamModel(BaseModel): new_sd[new_key] = value return new_sd + def get_base_model_version(self): + return "hidream_i1" + diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 8056e87b..08b6760f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -339,18 +339,7 @@ class BaseSDTrainProcess(BaseTrainProcess): o_dict = OrderedDict({ "training_info": self.get_training_info() }) - if self.model_config.is_v2: - o_dict['ss_v2'] = True - o_dict['ss_base_model_version'] = 'sd_2.1' - - elif self.model_config.is_xl: - o_dict['ss_base_model_version'] = 'sdxl_1.0' - elif self.model_config.is_flux: - o_dict['ss_base_model_version'] = 'flux.1' - elif self.model_config.is_lumina2: - o_dict['ss_base_model_version'] = 'lumina2' - else: - o_dict['ss_base_model_version'] = 'sd_1.5' + o_dict['ss_base_model_version'] = self.sd.get_base_model_version() o_dict = add_base_model_info_to_meta( o_dict, diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 61c4bcf6..0d5a821f 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -1486,3 +1486,7 @@ class BaseModel: def get_transformer_block_names(self) -> Optional[List[str]]: # override in child classes to get transformer block names for lora targeting return None + + def get_base_model_version() -> str: + # override in child classes to get the base model version + return "unknown" diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index 04eca827..57b556ed 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -663,3 +663,6 @@ class Wan21(BaseModel): def convert_lora_weights_before_load(self, state_dict): return convert_to_diffusers(state_dict) + + def get_base_model_version(self): + return "wan_2.1" diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index a3f3dcec..c8c01aa2 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -3082,3 +3082,24 @@ class StableDiffusion: def get_transformer_block_names(self) -> Optional[List[str]]: # override in child classes to get transformer block names for lora targeting return None + + def get_base_model_version(self) -> str: + if self.is_pixart: + return 'pixart' + if self.is_v3: + return 'sd_3' + if self.is_auraflow: + return 'auraflow' + if self.is_flux: + return 'flux.1' + if self.is_lumina2: + return 'lumina2' + if self.is_ssd: + return 'ssd' + if self.is_vega: + return 'vega' + if self.is_xl: + return 'sdxl_1.0' + if self.is_v2: + return 'sd_2.1' + return 'sd_1.5'