This commit is contained in:
layerdiffusion
2024-08-01 23:29:27 -07:00
parent f6981339b0
commit 61ca3bc34f
2 changed files with 9 additions and 1 deletions

View File

@@ -670,6 +670,14 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
dtype = unet_initial_dtype
device = unet_initial_device
self.legacy_config = dict(
num_res_blocks=num_res_blocks,
channel_mult=channel_mult,
transformer_depth=transformer_depth,
transformer_depth_output=transformer_depth_output,
transformer_depth_middle=transformer_depth_middle,
)
if context_dim is not None:
assert use_spatial_transformer

View File

@@ -210,7 +210,7 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.model_config.unet_config)
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.diffusion_model.legacy_config)
for k in diffusers_keys:
if k.endswith(".weight"):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])