From 5ea19b62922ac95374e221cc75e2305eb1b5222a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 30 Mar 2025 20:09:40 -0600 Subject: [PATCH] small bug fixes --- toolkit/models/base_model.py | 2 +- toolkit/models/subpixel_adapter.py | 9 ++++++++- toolkit/stable_diffusion_model.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 88284e54..0a668960 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -242,7 +242,7 @@ class BaseModel: # flux packs this again, if self.is_flux: - divisibility = divisibility * 4 + divisibility = divisibility * 2 return divisibility # these must be implemented in child classes diff --git a/toolkit/models/subpixel_adapter.py b/toolkit/models/subpixel_adapter.py index 5429265d..ca4ed638 100644 --- a/toolkit/models/subpixel_adapter.py +++ b/toolkit/models/subpixel_adapter.py @@ -77,6 +77,13 @@ class InOutModule(torch.nn.Module): model.config.out_channels = num_channels model.config["out_channels"] = num_channels + # if the shape matches, copy the weights + if x_embedder.weight.shape == in_out_module.x_embedder.weight.shape: + in_out_module.x_embedder.weight.data = x_embedder.weight.data.clone().float() + in_out_module.x_embedder.bias.data = x_embedder.bias.data.clone().float() + in_out_module.proj_out.weight.data = proj_out.weight.data.clone().float() + in_out_module.proj_out.bias.data = proj_out.bias.data.clone().float() + # replace the vae of the model sd = adapter.sd_ref() sd.vae = AutoencoderPixelMixer( @@ -160,7 +167,7 @@ class SubpixelAdapter(torch.nn.Module): network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs if hasattr(sd, 'target_lora_modules'): - network_kwargs['target_lin_modules'] = self.sd.target_lora_modules + network_kwargs['target_lin_modules'] = sd.target_lora_modules if 'ignore_if_contains' not in network_kwargs: network_kwargs['ignore_if_contains'] = [] diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 8039d060..6ddd1235 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -257,7 +257,7 @@ class StableDiffusion: # flux packs this again, if self.is_flux: - divisibility = divisibility * 4 + divisibility = divisibility * 2 return divisibility