small bug fixes

This commit is contained in:
Jaret Burkett
2025-03-30 20:09:40 -06:00
parent 58861005a5
commit 5ea19b6292
3 changed files with 10 additions and 3 deletions

View File

@@ -242,7 +242,7 @@ class BaseModel:
# flux packs this again,
if self.is_flux:
divisibility = divisibility * 4
divisibility = divisibility * 2
return divisibility
# these must be implemented in child classes

View File

@@ -77,6 +77,13 @@ class InOutModule(torch.nn.Module):
model.config.out_channels = num_channels
model.config["out_channels"] = num_channels
# if the shape matches, copy the weights
if x_embedder.weight.shape == in_out_module.x_embedder.weight.shape:
in_out_module.x_embedder.weight.data = x_embedder.weight.data.clone().float()
in_out_module.x_embedder.bias.data = x_embedder.bias.data.clone().float()
in_out_module.proj_out.weight.data = proj_out.weight.data.clone().float()
in_out_module.proj_out.bias.data = proj_out.bias.data.clone().float()
# replace the vae of the model
sd = adapter.sd_ref()
sd.vae = AutoencoderPixelMixer(
@@ -160,7 +167,7 @@ class SubpixelAdapter(torch.nn.Module):
network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
if hasattr(sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
network_kwargs['target_lin_modules'] = sd.target_lora_modules
if 'ignore_if_contains' not in network_kwargs:
network_kwargs['ignore_if_contains'] = []

View File

@@ -257,7 +257,7 @@ class StableDiffusion:
# flux packs this again,
if self.is_flux:
divisibility = divisibility * 4
divisibility = divisibility * 2
return divisibility