mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-27 09:41:31 +00:00
change patcher method
This commit is contained in:
@@ -190,10 +190,10 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def patch_model(self, device_to=None, patch_weights=True):
|
def patch_model(self, device_to=None, patch_weights=True):
|
||||||
for k in self.object_patches:
|
for k in self.object_patches:
|
||||||
old = getattr(self.model, k)
|
old = ldm_patched.modules.utils.get_attr(self.model, k)
|
||||||
if k not in self.object_patches_backup:
|
if k not in self.object_patches_backup:
|
||||||
self.object_patches_backup[k] = old
|
self.object_patches_backup[k] = old
|
||||||
setattr(self.model, k, self.object_patches[k])
|
ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches[k])
|
||||||
|
|
||||||
if patch_weights:
|
if patch_weights:
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
@@ -378,6 +378,6 @@ class ModelPatcher:
|
|||||||
|
|
||||||
keys = list(self.object_patches_backup.keys())
|
keys = list(self.object_patches_backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
setattr(self.model, k, self.object_patches_backup[k])
|
ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
|
||||||
|
|
||||||
self.object_patches_backup = {}
|
self.object_patches_backup = {}
|
||||||
|
|||||||
@@ -286,6 +286,12 @@ def set_attr(obj, attr, value):
|
|||||||
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
|
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
|
||||||
del prev
|
del prev
|
||||||
|
|
||||||
|
def set_attr_raw(obj, attr, value):
|
||||||
|
attrs = attr.split(".")
|
||||||
|
for name in attrs[:-1]:
|
||||||
|
obj = getattr(obj, name)
|
||||||
|
setattr(obj, attrs[-1], value)
|
||||||
|
|
||||||
def copy_to_param(obj, attr, value):
|
def copy_to_param(obj, attr, value):
|
||||||
# inplace update tensor instead of replacing it
|
# inplace update tensor instead of replacing it
|
||||||
attrs = attr.split(".")
|
attrs = attr.split(".")
|
||||||
|
|||||||
Reference in New Issue
Block a user