Handle inpainting training for control_lora adapter

Jaret Burkett
2025-03-24 13:17:47 -06:00
parent f10937e6da
commit 45be82d5d6
9 changed files with 257 additions and 23 deletions


@@ -46,15 +46,24 @@ class ImgEmbedder(torch.nn.Module):
         cls,
         model: FluxTransformer2DModel,
         adapter: 'ControlLoraAdapter',
-        num_control_images=1
+        num_control_images=1,
+        has_inpainting_input=False
     ):
-        if model.__class__.__name__ == 'FluxTransformer2DModel':
+        if model.__class__.__name__ == 'FluxTransformer2DModel':
+            num_adapter_in_channels = model.x_embedder.in_features * num_control_images
+            if has_inpainting_input:
+                # Inpainting concatenates the mask before packing latents: normally 16 ch latents + 1 ch mask.
+                # Packed, that becomes 64 ch + 4 ch mask, so we add 4 to the input channels.
+                num_adapter_in_channels += 4
             x_embedder: torch.nn.Linear = model.x_embedder
             img_embedder = cls(
                 adapter,
                 orig_layer=x_embedder,
-                in_channels=x_embedder.in_features * num_control_images,
-                out_channels=x_embedder.out_features,
+                in_channels=num_adapter_in_channels,
+                out_channels=x_embedder.out_features,
             )
             # hijack the forward method
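The channel arithmetic in this hunk follows from Flux's 2x2 latent packing: every pre-pack channel contributes four values per token, so a 1-channel mask becomes 4 packed channels. A minimal sketch, not part of the commit, that verifies the math under Flux's standard packing scheme:

import torch

batch, height, width = 1, 64, 64
latents = torch.randn(batch, 16, height, width)  # 16-ch VAE latents
mask = torch.randn(batch, 1, height, width)      # 1-ch inpainting mask
x = torch.cat([latents, mask], dim=1)            # 17 ch before packing

# 2x2 patchify, as Flux packs latents: (B, C, H, W) -> (B, H/2 * W/2, C * 4)
packed = (
    x.view(batch, 17, height // 2, 2, width // 2, 2)
    .permute(0, 2, 4, 1, 3, 5)
    .reshape(batch, (height // 2) * (width // 2), 17 * 4)
)
print(packed.shape[-1])  # 68 = 64 (packed latents) + 4 (packed mask)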
@@ -181,7 +190,8 @@ class ControlLoraAdapter(torch.nn.Module):
             self.x_embedder = ImgEmbedder.from_model(
                 sd.unet,
                 self,
-                num_control_images=config.num_control_images
+                num_control_images=config.num_control_images,
+                has_inpainting_input=config.has_inpainting_input
             )
             self.x_embedder.to(self.device_torch)
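With both flags wired through, the adapter's input projection ends up 64 * num_control_images channels wide, plus 4 when the inpainting mask is present. A hedged sketch of the resulting layer shape (the helper function is ours, not the repo's; 3072 is Flux's hidden size, used here only as an example value):

import torch

def adapter_in_channels(num_control_images=1, has_inpainting_input=False):
    channels = 64 * num_control_images  # packed Flux latents are 64-ch
    if has_inpainting_input:
        channels += 4                   # packed 1-ch mask adds 4 more
    return channels

# One control image plus an inpainting mask -> a 68-wide input projection
proj = torch.nn.Linear(adapter_in_channels(has_inpainting_input=True), 3072)
print(proj.in_features)  # 68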