mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-01 13:59:47 +00:00
fix bnb lora
This commit is contained in:
@@ -305,7 +305,7 @@ class LoraLoader:
|
||||
|
||||
for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs') if len(self.patches) > 0 else self.patches):
|
||||
try:
|
||||
weight = utils.get_attr(self.model, key)
|
||||
parent_layer, weight = utils.get_attr_with_parent(self.model, key)
|
||||
assert isinstance(weight, torch.nn.Parameter)
|
||||
except:
|
||||
raise ValueError(f"Wrong LoRA Key: {key}")
|
||||
@@ -317,8 +317,7 @@ class LoraLoader:
|
||||
|
||||
if operations.bnb_avaliable:
|
||||
if hasattr(weight, 'bnb_quantized'):
|
||||
assert weight.module is not None, 'BNB bad weight without parent layer!'
|
||||
bnb_layer = weight.module
|
||||
bnb_layer = parent_layer
|
||||
if weight.bnb_quantized:
|
||||
weight_original_device = weight.device
|
||||
|
||||
|
||||
@@ -76,6 +76,15 @@ def get_attr(obj, attr):
|
||||
return obj
|
||||
|
||||
|
||||
def get_attr_with_parent(obj, attr):
|
||||
attrs = attr.split(".")
|
||||
parent = obj
|
||||
for name in attrs:
|
||||
parent = obj
|
||||
obj = getattr(obj, name)
|
||||
return parent, obj
|
||||
|
||||
|
||||
def calculate_parameters(sd, prefix=""):
|
||||
params = 0
|
||||
for k in sd.keys():
|
||||
|
||||
Reference in New Issue
Block a user