revise lora patching

This commit is contained in:
layerdiffusion
2024-08-22 11:59:43 -07:00
parent 28ad046447
commit 2ab19f7f1c
2 changed files with 28 additions and 9 deletions

View File

@@ -376,7 +376,7 @@ def sampling_prepare(unet, x):
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
additional_model_patchers += unet.controlnet_linked_list.get_models()
if dynamic_args.get('online_lora', False):
if unet.lora_loader.online_mode:
lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
additional_inference_memory += lora_memory
@@ -384,11 +384,9 @@ def sampling_prepare(unet, x):
models=[unet] + additional_model_patchers,
memory_required=unet_inference_memory + additional_inference_memory)
if dynamic_args.get('online_lora', False):
if unet.lora_loader.online_mode:
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
unet.lora_loader.patches = {}
real_model = unet.model
percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
@@ -400,6 +398,8 @@ def sampling_prepare(unet, x):
def sampling_cleanup(unet):
if unet.lora_loader.online_mode:
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
for cnet in unet.list_controlnets():
cnet.cleanup()
cleanup_cache()