Handle inpainting training for control_lora adapter

This commit is contained in:
Jaret Burkett
2025-03-24 13:17:47 -06:00
parent f10937e6da
commit 45be82d5d6
9 changed files with 257 additions and 23 deletions

View File

@@ -16,6 +16,10 @@ from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
from diffusers.image_processor import PipelineImageInput
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
if is_torch_xla_available():
@@ -1428,6 +1432,22 @@ class FluxWithCFGPipeline(FluxPipeline):
class FluxAdvancedControlPipeline(FluxControlPipeline):
def __init__(
self,
scheduler,
vae,
text_encoder,
tokenizer,
text_encoder_2,
tokenizer_2,
transformer,
do_inpainting=False,
num_controls=1,
):
self.do_inpainting = do_inpainting
self.num_controls = num_controls
super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer)
@torch.no_grad()
def __call__(
self,
@@ -1581,6 +1601,17 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
# 4. Prepare latent variables
# num_channels_latents = self.transformer.config.in_channels // 8
num_channels_latents = 128 // 8
# pull mask off control image if there is one it is a pil image
mask = None
if control_image is not None and self.do_inpainting and control_image.mode == "RGBA":
control_img_array = np.array(control_image)
mask = control_img_array[:, :, 3:4]
# scale it to 0 - 1
mask = mask / 255.0
# multiply rgb by mask
control_img_array = control_img_array[:, :, :3] * mask
control_image = Image.fromarray(control_img_array.astype(np.uint8))
control_image = self.prepare_image(
image=control_image,
@@ -1593,14 +1624,28 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
)
if control_image.ndim == 4:
num_control_channels = num_channels_latents
control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
if mask is not None:
transform = transforms.Compose([
transforms.ToTensor(),
])
mask = transform(mask).to(device, dtype=control_image.dtype).unsqueeze(0)
# resize mask to match control image
mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False)
mask = mask.to(device)
# invert mask
mask = 1 - mask
control_image = torch.cat([control_image, mask], dim=1)
num_control_channels += 1
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
num_control_channels,
height_control_image,
width_control_image,
)
@@ -1642,9 +1687,6 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# flux has 64 input channels.
total_controls = (self.transformer.config.in_channels // 64) - 1
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1652,7 +1694,16 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
if self.interrupt:
continue
control_image_list = [torch.zeros_like(latents) for _ in range(total_controls)]
control_image_list = []
for idx in range(self.num_controls):
if idx == 0 and self.do_inpainting:
ctrl = torch.zeros_like(latents)
# do ones for mask and zeros for image
ctrl = torch.cat([ctrl, torch.ones_like(ctrl[:, :, :4])], dim=2)
control_image_list.append(ctrl)
else:
control_image_list.append(torch.zeros_like(latents))
control_image_list[control_image_idx] = control_image
latent_model_input = torch.cat([latents] + control_image_list, dim=2)