Fuse flux schnell assistant adapter in pieces when doing lowvram to drastically speed ip up from minutes to seconds.

This commit is contained in:
Jaret Burkett
2024-08-18 09:09:11 -06:00
parent 81899310f8
commit f944eeaa4d

View File

@@ -523,11 +523,7 @@ class StableDiffusion:
if self.model_config.lora_path is not None:
print("Fusing in LoRA")
# if doing low vram, do this on the gpu, painfully slow otherwise
if self.low_vram:
print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.")
# need the pipe to do this unfortunately for now
# we have to fuse in the weights before quantizing
# need the pipe for peft
pipe: FluxPipeline = FluxPipeline(
scheduler=None,
text_encoder=None,
@@ -537,10 +533,60 @@ class StableDiffusion:
vae=None,
transformer=transformer,
)
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
pipe.fuse_lora()
# unfortunately, not an easier way with peft
pipe.unload_lora_weights()
if self.low_vram:
# we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts
# we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu
# we are going to separate it into the two transformer blocks one at a time
lora_state_dict = load_file(self.model_config.lora_path)
single_transformer_lora = {}
single_block_key = "transformer.single_transformer_blocks."
double_transformer_lora = {}
double_block_key = "transformer.transformer_blocks."
for key, value in lora_state_dict.items():
if single_block_key in key:
new_key = key.replace(single_block_key, "")
single_transformer_lora[new_key] = value
elif double_block_key in key:
new_key = key.replace(double_block_key, "")
double_transformer_lora[new_key] = value
else:
raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode")
# double blocks
transformer.transformer_blocks = transformer.transformer_blocks.to(
torch.device(self.quantize_device), dtype=dtype
)
pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double")
pipe.fuse_lora()
transformer.transformer_blocks = transformer.transformer_blocks.to(
'cpu', dtype=dtype
)
# single blocks
transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
torch.device(self.quantize_device), dtype=dtype
)
pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single")
pipe.fuse_lora()
transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
'cpu', dtype=dtype
)
# cleanup
del single_transformer_lora
del double_transformer_lora
del lora_state_dict
flush()
else:
# need the pipe to do this unfortunately for now
# we have to fuse in the weights before quantizing
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
pipe.fuse_lora()
# unfortunately, not an easier way with peft
pipe.unload_lora_weights()
flush()
if self.model_config.quantize:
quantization_type = qfloat8