diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index 77ec89ed..9c8348f2 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -68,8 +68,13 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation
 def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
     # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
-    weight_original_dtype = weight.dtype
-    weight = weight.to(dtype=computation_dtype)
+    weight_dtype_backup = None
+
+    if computation_dtype == weight.dtype:
+        weight = weight.clone()
+    else:
+        weight_dtype_backup = weight.dtype
+        weight = weight.to(dtype=computation_dtype)
 
     for p in patches:
         strength = p[0]
 
@@ -253,7 +258,9 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
         if old_weight is not None:
             weight = old_weight
 
-    weight = weight.to(dtype=weight_original_dtype)
+    if weight_dtype_backup is not None:
+        weight = weight.to(dtype=weight_dtype_backup)
+
     return weight
 
 
@@ -267,6 +274,7 @@ class LoraLoader:
         self.backup = {}
         self.online_backup = []
         self.dirty = False
+        self.online_mode = False
 
     def clear_patches(self):
         self.patches.clear()
@@ -296,6 +304,13 @@ class LoraLoader:
             self.patches[key] = current_patches
 
         self.dirty = True
+
+        self.online_mode = dynamic_args.get('online_lora', False)
+
+        if hasattr(self.model, 'storage_dtype'):
+            if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                self.online_mode = False
+
         return list(p)
 
     @torch.inference_mode()
@@ -323,7 +338,11 @@ class LoraLoader:
 
         self.backup = {}
 
-        online_mode = dynamic_args.get('online_lora', False)
+        if len(self.patches) > 0:
+            if self.online_mode:
+                print('Patching LoRA on-the-fly.')
+            else:
+                print('Patching LoRA by precomputing model weights.')
 
         # Patch
 
@@ -337,7 +356,7 @@ class LoraLoader:
             if key not in self.backup:
                 self.backup[key] = weight.to(device=offload_device)
 
-            if online_mode:
+            if self.online_mode:
                 if not hasattr(parent_layer, 'forge_online_loras'):
                     parent_layer.forge_online_loras = {}
 
diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py
index 1c32a09b..dd8b3088 100644
--- a/backend/sampling/sampling_function.py
+++ b/backend/sampling/sampling_function.py
@@ -376,7 +376,7 @@ def sampling_prepare(unet, x):
         additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
         additional_model_patchers += unet.controlnet_linked_list.get_models()
 
-    if dynamic_args.get('online_lora', False):
+    if unet.lora_loader.online_mode:
         lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
         additional_inference_memory += lora_memory
 
@@ -384,11 +384,9 @@ def sampling_prepare(unet, x):
         models=[unet] + additional_model_patchers,
         memory_required=unet_inference_memory + additional_inference_memory)
 
-    if dynamic_args.get('online_lora', False):
+    if unet.lora_loader.online_mode:
         utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
 
-    unet.lora_loader.patches = {}
-
     real_model = unet.model
 
     percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
@@ -400,6 +398,8 @@ def sampling_prepare(unet, x):
 
 
 def sampling_cleanup(unet):
+    if unet.lora_loader.online_mode:
+        utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
     for cnet in unet.list_controlnets():
         cnet.cleanup()
     cleanup_cache()
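
Note on the first two lora.py hunks: torch.Tensor.to() returns self unchanged when the tensor already has the requested dtype, so the old unconditional cast handed the in-place patch arithmetic an alias of the module's stored weight whenever computation_dtype already matched. The clone() branch restores copy semantics in that case, and the cast back at the end now runs only when a conversion actually happened. A minimal sketch of the pattern, with apply_patches and the random update as hypothetical stand-ins for the real merge loop:

import torch

def apply_patches(weight: torch.Tensor, computation_dtype: torch.dtype) -> torch.Tensor:
    weight_dtype_backup = None

    if computation_dtype == weight.dtype:
        # .to() would be a no-op alias here; clone so the in-place math
        # below cannot mutate the module's stored weight.
        weight = weight.clone()
    else:
        # .to() with a different dtype already returns a fresh tensor;
        # just remember the original dtype for the cast back.
        weight_dtype_backup = weight.dtype
        weight = weight.to(dtype=computation_dtype)

    weight += 0.01 * torch.randn_like(weight)  # stand-in for the LoRA merge

    if weight_dtype_backup is not None:
        weight = weight.to(dtype=weight_dtype_backup)
    return weight

w = torch.ones(3, 3, dtype=torch.float16)
out = apply_patches(w, computation_dtype=torch.float32)
assert out.dtype == torch.float16 and float(w.sum()) == 9.0  # original untouched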
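
Note on the LoraLoader hunks: the mode decision is now made once, when patches are registered, and cached on the loader as self.online_mode; the sampling code then reads the cached flag instead of consulting dynamic_args again. The storage_dtype guard forces online mode off for plain float weights, presumably because merged weights can be precomputed there without a quantization round trip, while on-the-fly application mainly pays off for quantized storage. A sketch of that decision; decide_online_mode is a hypothetical free-function restatement of what the diff does inline:

import torch

PLAIN_FLOAT_DTYPES = (torch.float32, torch.float16, torch.bfloat16)

def decide_online_mode(dynamic_args: dict, storage_dtype: torch.dtype) -> bool:
    online_mode = dynamic_args.get('online_lora', False)
    # For unquantized weights, precomputing merged weights is cheap and
    # exact, so the on-the-fly path is disabled even if requested.
    if storage_dtype in PLAIN_FLOAT_DTYPES:
        online_mode = False
    return online_mode

assert decide_online_mode({'online_lora': True}, torch.float16) is False
assert decide_online_mode({'online_lora': True}, torch.float8_e4m3fn) is True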
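
Note on the sampling_function.py hunks: previously the patch dict was moved to the GPU and then discarded (unet.lora_loader.patches = {}), so online LoRA tensors could never be offloaded again; now the dict is kept alive, and sampling_cleanup parks it back on the offload device so VRAM is released between runs without re-reading the LoRA files. The helper below is an assumed stand-in for utils.nested_move_to_device, not the actual Forge implementation; in the diff the helper's return value is unused, which suggests the real one mutates its containers in place, while this functional version just keeps the sketch simple:

import torch

def nested_move_to_device(obj, device):
    # Recursively move every tensor in a nested dict/list/tuple structure.
    if isinstance(obj, torch.Tensor):
        return obj.to(device=device)
    if isinstance(obj, dict):
        return {k: nested_move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(nested_move_to_device(v, device) for v in obj)
    return obj

patches = {'layer.weight': [(1.0, torch.randn(4, 4))]}  # hypothetical patch entry

# sampling_prepare: bring patches onto the compute device for the online hooks.
patches = nested_move_to_device(patches, device='cpu')  # 'cuda' when available
# ... sampling runs; the on-the-fly forward path reads these patches ...
# sampling_cleanup: park the patches on the offload device instead of deleting.
patches = nested_move_to_device(patches, device='cpu')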