mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fuse flux schnell assistant adapter in pieces when doing lowvram to drastically speed ip up from minutes to seconds.
This commit is contained in:
@@ -523,11 +523,7 @@ class StableDiffusion:
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
print("Fusing in LoRA")
|
||||
# if doing low vram, do this on the gpu, painfully slow otherwise
|
||||
if self.low_vram:
|
||||
print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.")
|
||||
# need the pipe to do this unfortunately for now
|
||||
# we have to fuse in the weights before quantizing
|
||||
# need the pipe for peft
|
||||
pipe: FluxPipeline = FluxPipeline(
|
||||
scheduler=None,
|
||||
text_encoder=None,
|
||||
@@ -537,10 +533,60 @@ class StableDiffusion:
|
||||
vae=None,
|
||||
transformer=transformer,
|
||||
)
|
||||
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
||||
pipe.fuse_lora()
|
||||
# unfortunately, not an easier way with peft
|
||||
pipe.unload_lora_weights()
|
||||
if self.low_vram:
|
||||
# we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts
|
||||
# we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu
|
||||
# we are going to separate it into the two transformer blocks one at a time
|
||||
|
||||
lora_state_dict = load_file(self.model_config.lora_path)
|
||||
single_transformer_lora = {}
|
||||
single_block_key = "transformer.single_transformer_blocks."
|
||||
double_transformer_lora = {}
|
||||
double_block_key = "transformer.transformer_blocks."
|
||||
for key, value in lora_state_dict.items():
|
||||
if single_block_key in key:
|
||||
new_key = key.replace(single_block_key, "")
|
||||
single_transformer_lora[new_key] = value
|
||||
elif double_block_key in key:
|
||||
new_key = key.replace(double_block_key, "")
|
||||
double_transformer_lora[new_key] = value
|
||||
else:
|
||||
raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode")
|
||||
|
||||
# double blocks
|
||||
transformer.transformer_blocks = transformer.transformer_blocks.to(
|
||||
torch.device(self.quantize_device), dtype=dtype
|
||||
)
|
||||
pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double")
|
||||
pipe.fuse_lora()
|
||||
transformer.transformer_blocks = transformer.transformer_blocks.to(
|
||||
'cpu', dtype=dtype
|
||||
)
|
||||
|
||||
# single blocks
|
||||
transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
|
||||
torch.device(self.quantize_device), dtype=dtype
|
||||
)
|
||||
pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single")
|
||||
pipe.fuse_lora()
|
||||
transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
|
||||
'cpu', dtype=dtype
|
||||
)
|
||||
|
||||
# cleanup
|
||||
del single_transformer_lora
|
||||
del double_transformer_lora
|
||||
del lora_state_dict
|
||||
flush()
|
||||
|
||||
else:
|
||||
# need the pipe to do this unfortunately for now
|
||||
# we have to fuse in the weights before quantizing
|
||||
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
||||
pipe.fuse_lora()
|
||||
# unfortunately, not an easier way with peft
|
||||
pipe.unload_lora_weights()
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize:
|
||||
quantization_type = qfloat8
|
||||
|
||||
Reference in New Issue
Block a user