diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 852dcebc..d71a5f01 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -523,11 +523,7 @@ class StableDiffusion:
 
             if self.model_config.lora_path is not None:
                 print("Fusing in LoRA")
-                # if doing low vram, do this on the gpu, painfully slow otherwise
-                if self.low_vram:
-                    print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.")
-                # need the pipe to do this unfortunately for now
-                # we have to fuse in the weights before quantizing
+                # need the pipe for peft
                 pipe: FluxPipeline = FluxPipeline(
                     scheduler=None,
                     text_encoder=None,
@@ -537,10 +533,60 @@ class StableDiffusion:
                     vae=None,
                     transformer=transformer,
                 )
-                pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
-                pipe.fuse_lora()
-                # unfortunately, not an easier way with peft
-                pipe.unload_lora_weights()
+                if self.low_vram:
+                    # we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts
+                    # we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu
+                    # we are going to separate it into the two transformer blocks one at a time
+
+                    lora_state_dict = load_file(self.model_config.lora_path)
+                    single_transformer_lora = {}
+                    single_block_key = "transformer.single_transformer_blocks."
+                    double_transformer_lora = {}
+                    double_block_key = "transformer.transformer_blocks."
+                    for key, value in lora_state_dict.items():
+                        if single_block_key in key:
+                            new_key = key.replace(single_block_key, "")
+                            single_transformer_lora[new_key] = value
+                        elif double_block_key in key:
+                            new_key = key.replace(double_block_key, "")
+                            double_transformer_lora[new_key] = value
+                        else:
+                            raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode")
+
+                    # double blocks
+                    transformer.transformer_blocks = transformer.transformer_blocks.to(
+                        torch.device(self.quantize_device), dtype=dtype
+                    )
+                    pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double")
+                    pipe.fuse_lora()
+                    transformer.transformer_blocks = transformer.transformer_blocks.to(
+                        'cpu', dtype=dtype
+                    )
+
+                    # single blocks
+                    transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
+                        torch.device(self.quantize_device), dtype=dtype
+                    )
+                    pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single")
+                    pipe.fuse_lora()
+                    transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
+                        'cpu', dtype=dtype
+                    )
+
+                    # cleanup
+                    del single_transformer_lora
+                    del double_transformer_lora
+                    del lora_state_dict
+                    flush()
+
+                else:
+                    # need the pipe to do this unfortunately for now
+                    # we have to fuse in the weights before quantizing
+                    pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+                    pipe.fuse_lora()
+                    # unfortunately, not an easier way with peft
+                    pipe.unload_lora_weights()
+                flush()
 
             if self.model_config.quantize:
                 quantization_type = qfloat8
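
For context, the core of the low-VRAM path above is (1) partitioning the LoRA state dict by block prefix and (2) fusing each subset while only the matching block group is resident on the GPU, then offloading it back to the CPU. The sketch below restates that pattern as standalone Python; the function names and the file path in the usage comment are illustrative, not part of the toolkit's API, and the diffusers calls mirror the ones used in the diff.

import torch
from safetensors.torch import load_file

SINGLE_PREFIX = "transformer.single_transformer_blocks."
DOUBLE_PREFIX = "transformer.transformer_blocks."


def split_lora_by_block(path: str):
    """Partition a Flux-style LoRA state dict into single-block and double-block subsets."""
    state_dict = load_file(path)
    single, double = {}, {}
    for key, value in state_dict.items():
        if key.startswith(SINGLE_PREFIX):
            single[key.removeprefix(SINGLE_PREFIX)] = value
        elif key.startswith(DOUBLE_PREFIX):
            double[key.removeprefix(DOUBLE_PREFIX)] = value
        else:
            raise ValueError(f"Unexpected LoRA key: {key}")
    return single, double


def fuse_on_gpu(pipe, blocks: torch.nn.Module, lora_subset: dict,
                adapter_name: str, device: str, dtype: torch.dtype) -> torch.nn.Module:
    """Move one block group to the GPU, fuse its LoRA weights, then offload it to the CPU."""
    blocks = blocks.to(torch.device(device), dtype=dtype)
    pipe.load_lora_weights(lora_subset, adapter_name=adapter_name)  # diffusers accepts a state dict here
    pipe.fuse_lora()
    return blocks.to("cpu", dtype=dtype)


# Hypothetical usage (pipe construction and path are placeholders):
# single_sd, double_sd = split_lora_by_block("my_flux_lora.safetensors")
# transformer.transformer_blocks = fuse_on_gpu(
#     pipe, transformer.transformer_blocks, double_sd, "lora1_double", "cuda:0", torch.bfloat16)
# transformer.single_transformer_blocks = fuse_on_gpu(
#     pipe, transformer.single_transformer_blocks, single_sd, "lora1_single", "cuda:0", torch.bfloat16)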