mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Adjust hidream lora names to work with comfy
This commit is contained in:
@@ -426,5 +426,20 @@ class HidreamModel(BaseModel):
|
||||
|
||||
def get_transformer_block_names(self) -> Optional[List[str]]:
|
||||
return ['double_stream_blocks', 'single_stream_blocks']
|
||||
|
||||
def convert_lora_weights_before_save(self, state_dict):
|
||||
# currently starte with transformer. but needs to start with diffusion_model. for comfyui
|
||||
new_sd = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key.replace("transformer.", "diffusion_model.")
|
||||
new_sd[new_key] = value
|
||||
return new_sd
|
||||
|
||||
def convert_lora_weights_before_load(self, state_dict):
|
||||
# saved as diffusion_model. but needs to be transformer. for ai-toolkit
|
||||
new_sd = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key.replace("diffusion_model.", "transformer.")
|
||||
new_sd[new_key] = value
|
||||
return new_sd
|
||||
|
||||
|
||||
Reference in New Issue
Block a user