mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Adjust hidream lora names to work with comfy
This commit is contained in:
@@ -426,5 +426,20 @@ class HidreamModel(BaseModel):
|
|||||||
|
|
||||||
def get_transformer_block_names(self) -> Optional[List[str]]:
|
def get_transformer_block_names(self) -> Optional[List[str]]:
|
||||||
return ['double_stream_blocks', 'single_stream_blocks']
|
return ['double_stream_blocks', 'single_stream_blocks']
|
||||||
|
|
||||||
|
def convert_lora_weights_before_save(self, state_dict):
|
||||||
|
# currently starte with transformer. but needs to start with diffusion_model. for comfyui
|
||||||
|
new_sd = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
new_key = key.replace("transformer.", "diffusion_model.")
|
||||||
|
new_sd[new_key] = value
|
||||||
|
return new_sd
|
||||||
|
|
||||||
|
def convert_lora_weights_before_load(self, state_dict):
|
||||||
|
# saved as diffusion_model. but needs to be transformer. for ai-toolkit
|
||||||
|
new_sd = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
new_key = key.replace("diffusion_model.", "transformer.")
|
||||||
|
new_sd[new_key] = value
|
||||||
|
return new_sd
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user