only load lora one time

layerdiffusion
2024-08-16 02:02:22 -07:00
parent 243952f364
commit 12369669cf
5 changed files with 162 additions and 140 deletions


@@ -1,9 +1,11 @@
import torch
+import time
import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
-from backend import memory_management
+from tqdm import tqdm
+from backend import memory_management, utils, operations
class ForgeLoraCollection:
@@ -77,7 +79,7 @@ def merge_lora_to_model_weight(patches, weight, key):
            weight *= strength_model
        if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0].clone(), key),)
+            v = (merge_lora_to_model_weight(v[1:], v[0].clone(), key),)
        patch_type = ''
@@ -238,3 +240,140 @@ def merge_lora_to_model_weight(patches, weight, key):
            weight = old_weight
    return weight
class LoraLoader:
    def __init__(self, model):
        self.model = model
        self.patches = {}
        self.backup = {}
        self.dirty = False

    def clear_patches(self):
        self.patches.clear()
        self.dirty = True
        return

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        model_sd = self.model.state_dict()

        for k in patches:
            offset = None
            function = None
            if isinstance(k, str):
                key = k
            else:
                offset = k[1]
                key = k[0]
                if len(k) > 2:
                    function = k[2]

            if key in model_sd:
                p.add(k)
                current_patches = self.patches.get(key, [])
                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                self.patches[key] = current_patches

        self.dirty = True
        return list(p)

    def refresh(self, target_device=None, offload_device=torch.device('cpu')):
        if not self.dirty:
            return

        self.dirty = False
        execution_start_time = time.perf_counter()

        # Restore
        for k, w in self.backup.items():
            if target_device is not None:
                w = w.to(device=target_device)
            if not isinstance(w, torch.nn.Parameter):
                # In very few cases
                w = torch.nn.Parameter(w, requires_grad=False)
            utils.set_attr_raw(self.model, k, w)

        self.backup = {}

        # Patch
        for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs') if len(self.patches) > 0 else self.patches):
            try:
                weight = utils.get_attr(self.model, key)
                assert isinstance(weight, torch.nn.Parameter)
            except:
                raise ValueError(f"Wrong LoRA Key: {key}")

            if key not in self.backup:
                self.backup[key] = weight.to(device=offload_device)

            bnb_layer = None

            if operations.bnb_avaliable:
                if hasattr(weight, 'bnb_quantized'):
                    assert weight.module is not None, 'BNB bad weight without parent layer!'
                    bnb_layer = weight.module
                    if weight.bnb_quantized:
                        weight_original_device = weight.device

                        if target_device is not None:
                            assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                            weight = weight.to(target_device)
                        else:
                            weight = weight.cuda()

                        from backend.operations_bnb import functional_dequantize_4bit
                        weight = functional_dequantize_4bit(weight)

                        if target_device is None:
                            weight = weight.to(device=weight_original_device)
                    else:
                        weight = weight.data

            if target_device is not None:
                weight = weight.to(device=target_device)

            gguf_cls, gguf_type, gguf_real_shape = None, None, None

            if hasattr(weight, 'is_gguf'):
                from backend.operations_gguf import dequantize_tensor
                gguf_cls = weight.gguf_cls
                gguf_type = weight.gguf_type
                gguf_real_shape = weight.gguf_real_shape
                weight = dequantize_tensor(weight)

            weight_original_dtype = weight.dtype
            weight = weight.to(dtype=torch.float32)
            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)

            if bnb_layer is not None:
                bnb_layer.reload_weight(weight)
                continue

            if gguf_cls is not None:
                from backend.operations_gguf import ParameterGGUF
                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
                utils.set_attr_raw(self.model, key, ParameterGGUF.make(
                    data=weight,
                    gguf_type=gguf_type,
                    gguf_cls=gguf_cls,
                    gguf_real_shape=gguf_real_shape
                ))
                continue

            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))

        # Time
        moving_time = time.perf_counter() - execution_start_time

        if moving_time > 0.1:
            print(f'LoRA patching has taken {moving_time:.2f} seconds')

        return
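
For orientation, a minimal usage sketch of the new LoraLoader follows. It assumes the class and merge_lora_to_model_weight from the diff above are in scope (the commit view does not show the file path), and the patch payload is a hypothetical ComfyUI-style 1-tuple 'diff' tensor; only clear_patches, add_patches, and refresh come from the commit itself. The point of the dirty flag is that refresh can be called before every generation yet only re-merges weights after the patch set actually changes, which is the "only load lora one time" behavior the commit title refers to.

import torch

# Hypothetical toy target; real use patches the UNet / text encoder of the loaded checkpoint.
model = torch.nn.Linear(8, 8)

loader = LoraLoader(model)  # LoraLoader as added in the diff above

# Hypothetical patch payload: an additive 'diff' tensor for an existing state_dict key.
loader.add_patches({'weight': (torch.zeros(8, 8),)}, strength_patch=0.8)

loader.refresh()   # dirty -> merges the patch into model.weight and backs up the original
loader.refresh()   # not dirty -> returns immediately; the LoRA is not merged a second time

loader.clear_patches()
loader.refresh()   # dirty again -> restores the backed-up original weight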