diff --git a/ldm_patched/modules/model_patcher.py b/ldm_patched/modules/model_patcher.py
index 4ae462cd..d1093dc6 100644
--- a/ldm_patched/modules/model_patcher.py
+++ b/ldm_patched/modules/model_patcher.py
@@ -190,10 +190,10 @@ class ModelPatcher:

     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
-            old = getattr(self.model, k)
+            old = ldm_patched.modules.utils.get_attr(self.model, k)
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
-            setattr(self.model, k, self.object_patches[k])
+            ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches[k])

         if patch_weights:
             model_sd = self.model_state_dict()
@@ -378,6 +378,6 @@ class ModelPatcher:
         keys = list(self.object_patches_backup.keys())

         for k in keys:
-            setattr(self.model, k, self.object_patches_backup[k])
+            ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches_backup[k])

         self.object_patches_backup = {}
diff --git a/ldm_patched/modules/utils.py b/ldm_patched/modules/utils.py
index 5d00dcc0..06a3c826 100644
--- a/ldm_patched/modules/utils.py
+++ b/ldm_patched/modules/utils.py
@@ -286,6 +286,12 @@ def set_attr(obj, attr, value):
     setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
     del prev

+def set_attr_raw(obj, attr, value):
+    attrs = attr.split(".")
+    for name in attrs[:-1]:
+        obj = getattr(obj, name)
+    setattr(obj, attrs[-1], value)
+
 def copy_to_param(obj, attr, value):
     # inplace update tensor instead of replacing it
     attrs = attr.split(".")
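
Note on the change to patch_model: plain setattr/getattr treat a dotted key such as "diffusion_model.input_blocks" as one literal attribute name, so an object patch registered under a dotted path would be stored under (and restored from) a bogus top-level attribute instead of replacing the nested module. set_attr_raw walks the path segment by segment. A minimal sketch of the failure mode and the fix, using toy classes rather than the real ModelPatcher:

    import torch

    def set_attr_raw(obj, attr, value):
        attrs = attr.split(".")
        for name in attrs[:-1]:
            obj = getattr(obj, name)    # descend to the parent of the target
        setattr(obj, attrs[-1], value)  # assign the value unchanged

    class Inner:
        def __init__(self):
            self.block = torch.nn.Identity()

    class Model:
        def __init__(self):
            self.inner = Inner()

    m = Model()
    setattr(m, "inner.block", torch.nn.Linear(1, 1))
    print("inner.block" in vars(m))      # True: a literal "inner.block" entry was created
    print(type(m.inner.block).__name__)  # Identity: the nested module is untouched

    set_attr_raw(m, "inner.block", torch.nn.Linear(1, 1))
    print(type(m.inner.block).__name__)  # Linear: the nested module was replaced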
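
As for why a second setter is needed at all: the existing set_attr (context lines above the inserted helper) wraps the value in torch.nn.Parameter(value, requires_grad=False), which is right for tensor weight patches but would mangle object patches that swap in whole modules, functions, or other non-tensor objects; set_attr_raw stores the value verbatim. A sketch of the patch/unpatch round-trip this enables, continuing the toy objects above (get_attr is assumed to be the matching dotted-path getter already present in ldm_patched/modules/utils.py; the sketch below is not its real definition):

    # Assumed mirror of set_attr_raw; the real get_attr may differ in detail.
    def get_attr(obj, attr):
        for name in attr.split("."):
            obj = getattr(obj, name)
        return obj

    replacement = torch.nn.Linear(1, 1)

    backup = get_attr(m, "inner.block")          # patch_model: save the original object
    set_attr_raw(m, "inner.block", replacement)  # patch_model: install the patch
    assert m.inner.block is replacement

    set_attr_raw(m, "inner.block", backup)       # unpatch_model: restore the backup
    assert m.inner.block is backup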