diff --git a/backend/nn/unet.py b/backend/nn/unet.py
index c57af1ed..7002ec76 100644
--- a/backend/nn/unet.py
+++ b/backend/nn/unet.py
@@ -670,6 +670,14 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
         dtype = unet_initial_dtype
         device = unet_initial_device
 
+        self.legacy_config = dict(
+            num_res_blocks=num_res_blocks,
+            channel_mult=channel_mult,
+            transformer_depth=transformer_depth,
+            transformer_depth_output=transformer_depth_output,
+            transformer_depth_middle=transformer_depth_middle,
+        )
+
         if context_dim is not None:
             assert use_spatial_transformer
 
diff --git a/ldm_patched/modules/lora.py b/ldm_patched/modules/lora.py
index 7f74f119..b51898f5 100644
--- a/ldm_patched/modules/lora.py
+++ b/ldm_patched/modules/lora.py
@@ -210,7 +210,7 @@ def model_lora_keys_unet(model, key_map={}):
         key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
         key_map["lora_unet_{}".format(key_lora)] = k
 
-    diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.model_config.unet_config)
+    diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.diffusion_model.legacy_config)
     for k in diffusers_keys:
         if k.endswith(".weight"):
             unet_key = "diffusion_model.{}".format(diffusers_keys[k])