mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Fuse flux schnell assistant adapter in pieces when doing lowvram to drastically speed it up from minutes to seconds.
This commit is contained in:
@@ -523,11 +523,7 @@ class StableDiffusion:
|
|||||||
|
|
||||||
if self.model_config.lora_path is not None:
|
if self.model_config.lora_path is not None:
|
||||||
print("Fusing in LoRA")
|
print("Fusing in LoRA")
|
||||||
# if doing low vram, do this on the gpu, painfully slow otherwise
|
# need the pipe for peft
|
||||||
if self.low_vram:
|
|
||||||
print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.")
|
|
||||||
# need the pipe to do this unfortunately for now
|
|
||||||
# we have to fuse in the weights before quantizing
|
|
||||||
pipe: FluxPipeline = FluxPipeline(
|
pipe: FluxPipeline = FluxPipeline(
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
text_encoder=None,
|
text_encoder=None,
|
||||||
@@ -537,10 +533,60 @@ class StableDiffusion:
|
|||||||
vae=None,
|
vae=None,
|
||||||
transformer=transformer,
|
transformer=transformer,
|
||||||
)
|
)
|
||||||
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
if self.low_vram:
|
||||||
pipe.fuse_lora()
|
# we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts
|
||||||
# unfortunately, not an easier way with peft
|
# we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu
|
||||||
pipe.unload_lora_weights()
|
# we are going to separate it into the two transformer blocks one at a time
|
||||||
|
|
||||||
|
lora_state_dict = load_file(self.model_config.lora_path)
|
||||||
|
single_transformer_lora = {}
|
||||||
|
single_block_key = "transformer.single_transformer_blocks."
|
||||||
|
double_transformer_lora = {}
|
||||||
|
double_block_key = "transformer.transformer_blocks."
|
||||||
|
for key, value in lora_state_dict.items():
|
||||||
|
if single_block_key in key:
|
||||||
|
new_key = key.replace(single_block_key, "")
|
||||||
|
single_transformer_lora[new_key] = value
|
||||||
|
elif double_block_key in key:
|
||||||
|
new_key = key.replace(double_block_key, "")
|
||||||
|
double_transformer_lora[new_key] = value
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode")
|
||||||
|
|
||||||
|
# double blocks
|
||||||
|
transformer.transformer_blocks = transformer.transformer_blocks.to(
|
||||||
|
torch.device(self.quantize_device), dtype=dtype
|
||||||
|
)
|
||||||
|
pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double")
|
||||||
|
pipe.fuse_lora()
|
||||||
|
transformer.transformer_blocks = transformer.transformer_blocks.to(
|
||||||
|
'cpu', dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# single blocks
|
||||||
|
transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
|
||||||
|
torch.device(self.quantize_device), dtype=dtype
|
||||||
|
)
|
||||||
|
pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single")
|
||||||
|
pipe.fuse_lora()
|
||||||
|
transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
|
||||||
|
'cpu', dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
del single_transformer_lora
|
||||||
|
del double_transformer_lora
|
||||||
|
del lora_state_dict
|
||||||
|
flush()
|
||||||
|
|
||||||
|
else:
|
||||||
|
# need the pipe to do this unfortunately for now
|
||||||
|
# we have to fuse in the weights before quantizing
|
||||||
|
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
||||||
|
pipe.fuse_lora()
|
||||||
|
# unfortunately, not an easier way with peft
|
||||||
|
pipe.unload_lora_weights()
|
||||||
|
flush()
|
||||||
|
|
||||||
if self.model_config.quantize:
|
if self.model_config.quantize:
|
||||||
quantization_type = qfloat8
|
quantization_type = qfloat8
|
||||||
|
|||||||
Reference in New Issue
Block a user