Mirror of https://github.com/ostris/ai-toolkit.git
Handle inpainting training for control_lora adapter
@@ -46,15 +46,24 @@ class ImgEmbedder(torch.nn.Module):
         cls,
         model: FluxTransformer2DModel,
         adapter: 'ControlLoraAdapter',
-        num_control_images=1
+        num_control_images=1,
+        has_inpainting_input=False
     ):
         if model.__class__.__name__ == 'FluxTransformer2DModel':
+            num_adapter_in_channels = model.x_embedder.in_features * num_control_images
+
+            if has_inpainting_input:
+                # inpainting has the mask before packing latents. it is normally 16 ch + 1ch mask
+                # packed it is 64ch + 4ch mask
+                # so we need to add 4 to the input channels
+                num_adapter_in_channels += 4
+
             x_embedder: torch.nn.Linear = model.x_embedder
             img_embedder = cls(
                 adapter,
                 orig_layer=x_embedder,
-                in_channels=x_embedder.in_features * num_control_images,
-                out_channels=x_embedder.out_features,
+                in_channels=num_adapter_in_channels,
+                out_channels=x_embedder.out_features,
             )

             # hijack the forward method
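The channel arithmetic in the comment comes from Flux's 2x2 latent packing: each 2x2 spatial patch is folded into the channel axis, so the 16-channel VAE latent becomes the 64 input features of x_embedder, and the 1-channel inpainting mask becomes 4 channels. A minimal sketch of that arithmetic, assuming Flux-style packing (the pack helper below is illustrative, not a function from ai-toolkit):

    import torch

    def pack(x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, H//2 * W//2, C * 4): fold each 2x2 patch into channels
        b, c, h, w = x.shape
        x = x.view(b, c, h // 2, 2, w // 2, 2)
        return x.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)

    latents = torch.randn(1, 16, 64, 64)  # 16-channel VAE latent
    mask = torch.randn(1, 1, 64, 64)      # 1-channel inpainting mask
    packed = pack(torch.cat([latents, mask], dim=1))  # 17 ch before packing
    print(packed.shape[-1])  # 68 = 64 (packed latent) + 4 (packed mask)

Note that the mask contributes its 4 channels once, regardless of num_control_images, which is why the diff adds a flat 4 rather than scaling it per control image.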
@@ -181,7 +190,8 @@ class ControlLoraAdapter(torch.nn.Module):
         self.x_embedder = ImgEmbedder.from_model(
             sd.unet,
             self,
-            num_control_images=config.num_control_images
+            num_control_images=config.num_control_images,
+            has_inpainting_input=config.has_inpainting_input
         )
         self.x_embedder.to(self.device_torch)
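Both flags are read straight off the adapter config. A hypothetical fragment for illustration (only the num_control_images and has_inpainting_input field names are confirmed by this commit's config.* accesses; the dict framing is assumed):

    adapter_config = {
        "num_control_images": 1,       # control latents concatenated on the channel axis
        "has_inpainting_input": True,  # reserve 4 extra packed channels for the mask
    }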