Implement some rethinking about LoRA system

1. Add an option to allow users to use UNet in fp8/gguf but lora in fp16.
2. All FP16 loras do not need patch. Others will only patch again when lora weight change.
3. FP8 unet + fp16 lora are available (somewhat only available) in Forge now. This also solves some “LoRA too subtle” problems.
4. Significantly speed up all gguf models (in Async mode) by using independent thread (CUDA stream) to compute and dequant at the same time, even when low-bit weights are already on GPU.
5. View “online lora” as a module similar to ControlLoRA so that it is moved to GPU together with model when sampling, achieving significant speedup and perfect low VRAM management simultaneously.
This commit is contained in:
layerdiffusion
2024-08-19 04:31:00 -07:00
parent e5f213c21e
commit d38e560e42
11 changed files with 200 additions and 159 deletions

View File

@@ -5,6 +5,7 @@ import gradio as gr
from gradio.context import Context
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths
from backend import memory_management, stream
from backend.args import dynamic_args
total_vram = int(memory_management.total_vram)
@@ -21,11 +22,14 @@ ui_forge_pin_shared_memory: gr.Radio = None
ui_forge_inference_memory: gr.Slider = None
forge_unet_storage_dtype_options = {
'Automatic': None,
'bnb-nf4': 'nf4',
'float8-e4m3fn': torch.float8_e4m3fn,
'bnb-fp4': 'fp4',
'float8-e5m2': torch.float8_e5m2,
'Automatic': (None, False),
'Automatic (fp16 LoRA)': (None, True),
'bnb-nf4': ('nf4', False),
'float8-e4m3fn': (torch.float8_e4m3fn, False),
'float8-e4m3fn (fp16 LoRA)': (torch.float8_e4m3fn, True),
'bnb-fp4': ('fp4', False),
'float8-e5m2': (torch.float8_e5m2, False),
'float8-e5m2 (fp16 LoRA)': (torch.float8_e5m2, True),
}
module_list = {}
@@ -180,10 +184,14 @@ def refresh_model_loading_parameters():
checkpoint_info = select_checkpoint()
unet_storage_dtype, lora_fp16 = forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, (None, False))
dynamic_args['online_lora'] = lora_fp16
model_data.forge_loading_parameters = dict(
checkpoint_info=checkpoint_info,
additional_modules=shared.opts.forge_additional_modules,
unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
unet_storage_dtype=unet_storage_dtype
)
print(f'Model selected: {model_data.forge_loading_parameters}')