Upodate model tag on loras

This commit is contained in:
Jaret Burkett
2025-04-19 10:41:27 -06:00
parent d455e76c4f
commit 77001ee77f
6 changed files with 35 additions and 12 deletions

View File

@@ -386,3 +386,6 @@ class ChromaModel(BaseModel):
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
def get_base_model_version(self):
return "chroma"

View File

@@ -443,3 +443,6 @@ class HidreamModel(BaseModel):
new_sd[new_key] = value
return new_sd
def get_base_model_version(self):
return "hidream_i1"

View File

@@ -339,18 +339,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
o_dict = OrderedDict({
"training_info": self.get_training_info()
})
if self.model_config.is_v2:
o_dict['ss_v2'] = True
o_dict['ss_base_model_version'] = 'sd_2.1'
elif self.model_config.is_xl:
o_dict['ss_base_model_version'] = 'sdxl_1.0'
elif self.model_config.is_flux:
o_dict['ss_base_model_version'] = 'flux.1'
elif self.model_config.is_lumina2:
o_dict['ss_base_model_version'] = 'lumina2'
else:
o_dict['ss_base_model_version'] = 'sd_1.5'
o_dict['ss_base_model_version'] = self.sd.get_base_model_version()
o_dict = add_base_model_info_to_meta(
o_dict,

View File

@@ -1486,3 +1486,7 @@ class BaseModel:
def get_transformer_block_names(self) -> Optional[List[str]]:
# override in child classes to get transformer block names for lora targeting
return None
def get_base_model_version() -> str:
# override in child classes to get the base model version
return "unknown"

View File

@@ -663,3 +663,6 @@ class Wan21(BaseModel):
def convert_lora_weights_before_load(self, state_dict):
return convert_to_diffusers(state_dict)
def get_base_model_version(self):
return "wan_2.1"

View File

@@ -3082,3 +3082,24 @@ class StableDiffusion:
def get_transformer_block_names(self) -> Optional[List[str]]:
# override in child classes to get transformer block names for lora targeting
return None
def get_base_model_version(self) -> str:
if self.is_pixart:
return 'pixart'
if self.is_v3:
return 'sd_3'
if self.is_auraflow:
return 'auraflow'
if self.is_flux:
return 'flux.1'
if self.is_lumina2:
return 'lumina2'
if self.is_ssd:
return 'ssd'
if self.is_vega:
return 'vega'
if self.is_xl:
return 'sdxl_1.0'
if self.is_v2:
return 'sd_2.1'
return 'sd_1.5'