mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Handle inpainting training for control_lora adapter
This commit is contained in:
@@ -16,6 +16,10 @@ from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
|
||||
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
|
||||
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from PIL import Image
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -1428,6 +1432,22 @@ class FluxWithCFGPipeline(FluxPipeline):
|
||||
|
||||
|
||||
class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
text_encoder_2,
|
||||
tokenizer_2,
|
||||
transformer,
|
||||
do_inpainting=False,
|
||||
num_controls=1,
|
||||
):
|
||||
self.do_inpainting = do_inpainting
|
||||
self.num_controls = num_controls
|
||||
super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -1581,6 +1601,17 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
# 4. Prepare latent variables
|
||||
# num_channels_latents = self.transformer.config.in_channels // 8
|
||||
num_channels_latents = 128 // 8
|
||||
|
||||
# pull mask off control image if there is one it is a pil image
|
||||
mask = None
|
||||
if control_image is not None and self.do_inpainting and control_image.mode == "RGBA":
|
||||
control_img_array = np.array(control_image)
|
||||
mask = control_img_array[:, :, 3:4]
|
||||
# scale it to 0 - 1
|
||||
mask = mask / 255.0
|
||||
# multiply rgb by mask
|
||||
control_img_array = control_img_array[:, :, :3] * mask
|
||||
control_image = Image.fromarray(control_img_array.astype(np.uint8))
|
||||
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
@@ -1593,14 +1624,28 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
)
|
||||
|
||||
if control_image.ndim == 4:
|
||||
num_control_channels = num_channels_latents
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
if mask is not None:
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
mask = transform(mask).to(device, dtype=control_image.dtype).unsqueeze(0)
|
||||
# resize mask to match control image
|
||||
mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False)
|
||||
mask = mask.to(device)
|
||||
# invert mask
|
||||
mask = 1 - mask
|
||||
control_image = torch.cat([control_image, mask], dim=1)
|
||||
num_control_channels += 1
|
||||
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
num_control_channels,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
@@ -1642,9 +1687,6 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# flux has 64 input channels.
|
||||
total_controls = (self.transformer.config.in_channels // 64) - 1
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -1652,7 +1694,16 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
control_image_list = [torch.zeros_like(latents) for _ in range(total_controls)]
|
||||
control_image_list = []
|
||||
for idx in range(self.num_controls):
|
||||
if idx == 0 and self.do_inpainting:
|
||||
ctrl = torch.zeros_like(latents)
|
||||
# do ones for mask and zeros for image
|
||||
ctrl = torch.cat([ctrl, torch.ones_like(ctrl[:, :, :4])], dim=2)
|
||||
control_image_list.append(ctrl)
|
||||
else:
|
||||
control_image_list.append(torch.zeros_like(latents))
|
||||
|
||||
control_image_list[control_image_idx] = control_image
|
||||
|
||||
latent_model_input = torch.cat([latents] + control_image_list, dim=2)
|
||||
|
||||
Reference in New Issue
Block a user