From f3e211d4318fefeaab7295d06a2d5bb31d2a5c63 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Fri, 16 Aug 2024 21:09:14 -0700 Subject: [PATCH] fix bnb lora --- backend/patcher/lora.py | 5 ++--- backend/utils.py | 9 +++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py index 782a7b04..51e38494 100644 --- a/backend/patcher/lora.py +++ b/backend/patcher/lora.py @@ -305,7 +305,7 @@ class LoraLoader: for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs') if len(self.patches) > 0 else self.patches): try: - weight = utils.get_attr(self.model, key) + parent_layer, weight = utils.get_attr_with_parent(self.model, key) assert isinstance(weight, torch.nn.Parameter) except: raise ValueError(f"Wrong LoRA Key: {key}") @@ -317,8 +317,7 @@ class LoraLoader: if operations.bnb_avaliable: if hasattr(weight, 'bnb_quantized'): - assert weight.module is not None, 'BNB bad weight without parent layer!' - bnb_layer = weight.module + bnb_layer = parent_layer if weight.bnb_quantized: weight_original_device = weight.device diff --git a/backend/utils.py b/backend/utils.py index d01863e0..0a940af9 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -76,6 +76,15 @@ def get_attr(obj, attr): return obj +def get_attr_with_parent(obj, attr): + attrs = attr.split(".") + parent = obj + for name in attrs: + parent = obj + obj = getattr(obj, name) + return parent, obj + + def calculate_parameters(sd, prefix=""): params = 0 for k in sd.keys():