mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Upodate model tag on loras
This commit is contained in:
@@ -1486,3 +1486,7 @@ class BaseModel:
|
||||
def get_transformer_block_names(self) -> Optional[List[str]]:
|
||||
# override in child classes to get transformer block names for lora targeting
|
||||
return None
|
||||
|
||||
def get_base_model_version() -> str:
|
||||
# override in child classes to get the base model version
|
||||
return "unknown"
|
||||
|
||||
@@ -663,3 +663,6 @@ class Wan21(BaseModel):
|
||||
|
||||
def convert_lora_weights_before_load(self, state_dict):
|
||||
return convert_to_diffusers(state_dict)
|
||||
|
||||
def get_base_model_version(self):
|
||||
return "wan_2.1"
|
||||
|
||||
@@ -3082,3 +3082,24 @@ class StableDiffusion:
|
||||
def get_transformer_block_names(self) -> Optional[List[str]]:
|
||||
# override in child classes to get transformer block names for lora targeting
|
||||
return None
|
||||
|
||||
def get_base_model_version(self) -> str:
|
||||
if self.is_pixart:
|
||||
return 'pixart'
|
||||
if self.is_v3:
|
||||
return 'sd_3'
|
||||
if self.is_auraflow:
|
||||
return 'auraflow'
|
||||
if self.is_flux:
|
||||
return 'flux.1'
|
||||
if self.is_lumina2:
|
||||
return 'lumina2'
|
||||
if self.is_ssd:
|
||||
return 'ssd'
|
||||
if self.is_vega:
|
||||
return 'vega'
|
||||
if self.is_xl:
|
||||
return 'sdxl_1.0'
|
||||
if self.is_v2:
|
||||
return 'sd_2.1'
|
||||
return 'sd_1.5'
|
||||
|
||||
Reference in New Issue
Block a user