Use peft format for flux loras so they are compatible with diffusers. Allow loading an assistant lora.

This commit is contained in:
Jaret Burkett
2024-08-05 14:34:37 -06:00
parent edb7e827ee
commit 187663ab55
4 changed files with 87 additions and 6 deletions

View File

@@ -616,7 +616,17 @@ class StableDiffusion:
if self.model_config.lora_path is not None:
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
pipe.fuse_lora()
self.unet.fuse_lora()
# unfortunately, there is not an easier way to do this with peft
pipe.unload_lora_weights()
if self.model_config.assistant_lora_path is not None:
if self.model_config.lora_path is not None:
raise ValueError("Cannot have both lora and assistant lora")
print("Loading assistant lora")
pipe.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
pipe.fuse_lora(lora_scale=1.0)
# unfortunately, there is not an easier way to do this with peft
pipe.unload_lora_weights()
self.tokenizer = tokenizer
self.text_encoder = text_encoder
@@ -690,7 +700,15 @@ class StableDiffusion:
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
):
merge_multiplier = 1.0
# sample_folder = os.path.join(self.save_root, 'samples')
# if using assistant, unfuse it
if self.model_config.assistant_lora_path is not None:
print("Unloading asistant lora")
# unfortunately, there is not an easier way to do this with peft
self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
self.pipeline.fuse_lora(lora_scale=-1.0)
self.pipeline.unload_lora_weights()
if self.network is not None:
self.network.eval()
network = self.network
@@ -1162,6 +1180,14 @@ class StableDiffusion:
network.merge_out(merge_multiplier)
# self.tokenizer.to(original_device_dict['tokenizer'])
# re-fuse loras
if self.model_config.assistant_lora_path is not None:
print("Loading asistant lora")
# unfortunately, there is not an easier way to do this with peft
self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
self.pipeline.fuse_lora(lora_scale=1.0)
self.pipeline.unload_lora_weights()
def get_latent_noise(
self,
height=None,