Support LoRAs for Q8/Q5/Q4 GGUF Models

what a crazy night of math
Author: layerdiffusion
Date: 2024-08-15 05:34:33 -07:00
parent fd0d25ba8a
commit 1bd6cf0e0c
5 changed files with 149 additions and 37 deletions
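
What the change does: LoRA patching previously raised NotImplementedError for GGUF weights; this commit merges a LoRA into a quantized weight by dequantizing it to a float tensor, merging in float32, then re-quantizing the result back into its original GGUF layout via gguf_cls.quantize_pytorch. A minimal sketch of that round trip in plain PyTorch, with a toy per-tensor int8 codec standing in for Forge's blockwise GGUF codecs (quantize and dequantize below are hypothetical stand-ins, not Forge's API):

    import torch

    # Toy symmetric int8 codec; real GGUF types quantize per 32-element block.
    def quantize(w):
        scale = w.abs().max() / 127.0
        return (w / scale).round().clamp(-127, 127).to(torch.int8), scale

    def dequantize(q, scale):
        return q.to(torch.float32) * scale

    w = torch.randn(64, 64)
    q, s = quantize(w)

    # LoRA delta: up @ down, scaled by alpha / rank.
    rank, alpha = 4, 4.0
    down = torch.randn(rank, 64) * 0.01
    up = torch.randn(64, rank) * 0.01
    delta = (up @ down) * (alpha / rank)

    merged = dequantize(q, s) + delta  # merge in float32
    q2, s2 = quantize(merged)          # store back in quantized form
    print((dequantize(q2, s2) - merged).abs().max())  # requantization error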


@@ -4,10 +4,12 @@
 # are from Forge, implemented from scratch (after forge-v1.0.1), and may have
 # certain level of differences.
 
+import time
 import torch
 import copy
 import inspect
 
+from tqdm import tqdm
 from backend import memory_management, utils, operations
 from backend.patcher.lora import merge_lora_to_model_weight
@@ -237,6 +239,8 @@ class ModelPatcher:
         return sd
 
     def forge_patch_model(self, target_device=None):
+        execution_start_time = time.perf_counter()
+
         for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
@@ -245,13 +249,16 @@ class ModelPatcher:
             utils.set_attr_raw(self.model, k, item)
 
-        for key, current_patches in self.patches.items():
+        for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs to Diffusion Model') if len(self.patches) > 0 else self.patches):
             try:
                 weight = utils.get_attr(self.model, key)
                 assert isinstance(weight, torch.nn.Parameter)
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
 
+            weight_original_device = weight.device
+            lora_computation_device = weight.device
+
             if key not in self.backup:
                 self.backup[key] = weight.to(device=self.offload_device)
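
A note on the rewritten loop header: tqdm wraps the iterator only when there are patches, so no empty progress bar is printed. The bare self.patches fallback looks like it would break the (key, value) unpacking, since iterating a dict yields keys only, but an empty dict yields nothing at all, so the loop body never runs. A quick self-contained demonstration:

    from tqdm import tqdm

    patches = {}  # no LoRA keys loaded

    # Empty dict: yields nothing, body never executes, no unpacking error.
    for key, current_patches in (tqdm(patches.items(), desc='Patching LoRAs to Diffusion Model') if len(patches) > 0 else patches):
        print(key, current_patches)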
@@ -262,8 +269,6 @@ class ModelPatcher:
                     assert weight.module is not None, 'BNB bad weight without parent layer!'
                     bnb_layer = weight.module
                     if weight.bnb_quantized:
-                        weight_original_device = weight.device
-
                         if target_device is not None:
                             assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                             weight = weight.to(target_device)
@@ -272,35 +277,56 @@ class ModelPatcher:
                         from backend.operations_bnb import functional_dequantize_4bit
                         weight = functional_dequantize_4bit(weight)
-
-                        if target_device is None:
-                            weight = weight.to(device=weight_original_device)
                     else:
                         weight = weight.data
 
+            if target_device is None:
+                weight = weight.to(device=lora_computation_device, non_blocking=memory_management.device_supports_non_blocking(lora_computation_device))
+            else:
+                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
+
+            gguf_cls, gguf_type, gguf_real_shape = None, None, None
+
             if hasattr(weight, 'is_gguf'):
-                raise NotImplementedError('LoRAs for GGUF model are under construction!')
+                from backend.operations_gguf import dequantize_tensor
+                gguf_cls = weight.gguf_cls
+                gguf_type = weight.gguf_type
+                gguf_real_shape = weight.gguf_real_shape
+                weight = dequantize_tensor(weight)
 
             weight_original_dtype = weight.dtype
 
-            weight = weight.to(dtype=torch.float32, non_blocking=memory_management.device_supports_non_blocking(weight.device))
-            out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+            to_args = dict(dtype=torch.float32)
+
+            if target_device is not None:
+                to_args['device'] = target_device
+                to_args['non_blocking'] = memory_management.device_supports_non_blocking(target_device)
+
+            weight = weight.to(**to_args)
+            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+
+            if target_device is None:
+                weight = weight.to(device=weight_original_device, non_blocking=memory_management.device_supports_non_blocking(weight_original_device))
 
             if bnb_layer is not None:
-                bnb_layer.reload_weight(out_weight)
+                bnb_layer.reload_weight(weight)
                 continue
 
-            utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
+            if gguf_cls is not None:
+                from backend.utils import ParameterGGUF
+                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
+                utils.set_attr_raw(self.model, key, ParameterGGUF.make(
+                    data=weight,
+                    gguf_type=gguf_type,
+                    gguf_cls=gguf_cls,
+                    gguf_real_shape=gguf_real_shape
+                ))
+                continue
+
+            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
 
         if target_device is not None:
             self.model.to(target_device)
             self.current_device = target_device
 
+        moving_time = time.perf_counter() - execution_start_time
+
+        if moving_time > 0.1:
+            print(f'LoRA patching has taken {moving_time:.2f} seconds')
+
         return self.model
 
     def forge_unpatch_model(self, target_device=None):
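
The dequantize_tensor / quantize_pytorch pair above is what makes Q8/Q5/Q4 merging possible. A rough sketch of the Q8_0 half of that math in pure PyTorch, assuming gguf's 32-element blocks with one fp16 scale per block (a simplification: the real kernels in backend.operations_gguf follow each quant type's exact bit layout, and q8_0_quantize / q8_0_dequantize here are illustrative names, not Forge functions):

    import torch

    def q8_0_quantize(w, block=32):
        # One scale per 32-element block, values rounded to int8.
        flat = w.reshape(-1, block).to(torch.float32)
        scale = flat.abs().amax(dim=1, keepdim=True) / 127.0
        q = (flat / scale.clamp(min=1e-12)).round().clamp(-127, 127).to(torch.int8)
        return q, scale.to(torch.float16)

    def q8_0_dequantize(q, scale, shape):
        return (q.to(torch.float32) * scale.to(torch.float32)).reshape(shape)

    w = torch.randn(128, 64)
    q, s = q8_0_quantize(w)
    w2 = q8_0_dequantize(q, s, w.shape)
    print((w - w2).abs().mean())  # quantization noise the merged weight absorbs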