diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index 782a7b04..51e38494 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -305,7 +305,7 @@ class LoraLoader:
 
         for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs') if len(self.patches) > 0 else self.patches):
             try:
-                weight = utils.get_attr(self.model, key)
+                parent_layer, weight = utils.get_attr_with_parent(self.model, key)
                 assert isinstance(weight, torch.nn.Parameter)
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
@@ -317,8 +317,7 @@ class LoraLoader:
 
             if operations.bnb_avaliable:
                 if hasattr(weight, 'bnb_quantized'):
-                    assert weight.module is not None, 'BNB bad weight without parent layer!'
-                    bnb_layer = weight.module
+                    bnb_layer = parent_layer
 
                     if weight.bnb_quantized:
                         weight_original_device = weight.device
diff --git a/backend/utils.py b/backend/utils.py
index d01863e0..0a940af9 100644
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -76,6 +76,15 @@ def get_attr(obj, attr):
     return obj
 
 
+def get_attr_with_parent(obj, attr):
+    attrs = attr.split(".")
+    parent = obj
+    for name in attrs:
+        parent = obj
+        obj = getattr(obj, name)
+    return parent, obj
+
+
 def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
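
As a quick sanity check of the new helper, here is a minimal, self-contained sketch (not part of the patch) of how get_attr_with_parent resolves a dotted parameter key. The toy Sequential model and the key "0.weight" are illustrative assumptions, not code from this repository; the helper body is copied verbatim from the backend/utils.py hunk above.

import torch

def get_attr_with_parent(obj, attr):
    # Walk the dotted path, remembering the second-to-last object visited,
    # so the caller gets both the module that owns the final attribute and
    # the attribute itself.
    attrs = attr.split(".")
    parent = obj
    for name in attrs:
        parent = obj
        obj = getattr(obj, name)
    return parent, obj

model = torch.nn.Sequential(torch.nn.Linear(4, 4))

# "0.weight" resolves model -> model[0] -> model[0].weight.
parent_layer, weight = get_attr_with_parent(model, "0.weight")

assert parent_layer is model[0]                # the Linear layer owning the weight
assert isinstance(weight, torch.nn.Parameter)  # same check LoraLoader performs

This is what lets the LoRA patching path drop the weight.module back-reference: the owning layer is recovered from the key itself, so quantized bitsandbytes layers no longer need a module attribute stitched onto their parameters.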