Added blurring to mask for flex2

This commit is contained in:
Jaret Burkett
2025-04-02 07:55:51 -06:00
parent 77763a3e5c
commit ac1ee559c5

View File

@@ -2,6 +2,7 @@ import os
from typing import TYPE_CHECKING, List
import torch
import torchvision
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
@@ -37,6 +38,15 @@ scheduler_config = {
}
def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5):
    """Randomly apply a Gaussian blur to ``img``.

    With probability ``p`` an odd kernel size is drawn and ``img`` is run
    through ``torchvision.transforms.functional.gaussian_blur``; otherwise
    ``img`` is returned untouched.

    Args:
        img: Image handed straight to ``gaussian_blur`` (presumably a tensor;
            it is not inspected here).
        min_kernel_size: Lower bound for the sampled kernel size.
        max_kernel_size: Upper bound for the sampled kernel size.
        p: Probability of applying the blur.

    Returns:
        The blurred image, or ``img`` itself when no blur is applied.

    Raises:
        ValueError: If the blur is applied and ``min_kernel_size`` is greater
            than ``max_kernel_size``.
    """
    if random.random() < p:
        if min_kernel_size > max_kernel_size:
            raise ValueError(
                "min_kernel_size must not be greater than max_kernel_size"
            )
        # The kernel size must be odd. Sample only from the odd sizes that lie
        # inside [min_kernel_size, max_kernel_size] so an even draw can never
        # be bumped past the requested maximum.
        lowest_odd = min_kernel_size if min_kernel_size % 2 else min_kernel_size + 1
        highest_odd = max_kernel_size if max_kernel_size % 2 else max_kernel_size - 1
        if highest_odd < lowest_odd:
            # The range is a single even number, so it holds no odd size;
            # round up to the next odd value rather than failing.
            highest_odd = lowest_odd
        kernel_size = random.randrange(lowest_odd, highest_odd + 1, 2)
        img = torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size)
    return img
class Flex2(BaseModel):
arch = "flex2"
@@ -66,6 +76,8 @@ class Flex2(BaseModel):
self.inpaint_dropout = model_config.model_kwargs.get('inpaint_dropout', 0.0)
self.control_dropout = model_config.model_kwargs.get('control_dropout', 0.0)
self.inpaint_random_chance = model_config.model_kwargs.get('inpaint_random_chance', 0.0)
self.random_blur_mask = model_config.model_kwargs.get('random_blur_mask', False)
self.do_random_inpainting = model_config.model_kwargs.get('do_random_inpainting', False)
# static method to get the noise scheduler
@staticmethod
@@ -370,13 +382,17 @@ class Flex2(BaseModel):
do_dropout = random.random() < self.inpaint_dropout if self.inpaint_dropout > 0.0 else False
# do random mask if we dont have one
inpaint_tensor = batch.inpaint_tensor
if inpaint_tensor is None and batch.mask_tensor is not None:
# we have a mask tensor, use it
inpaint_tensor = batch.mask_tensor
if self.inpaint_random_chance > 0.0:
do_random = random.random() < self.inpaint_random_chance
if do_random:
# force a random tensor
inpaint_tensor = None
if inpaint_tensor is None and not do_dropout:
if inpaint_tensor is None and not do_dropout and self.do_random_inpainting:
# generate a random one since we dont have one
# this will make random blobs, invert the blobs for now as we normally inpaint the alpha
inpaint_tensor = 1 - generate_random_mask(
@@ -388,21 +404,37 @@ class Flex2(BaseModel):
if inpaint_tensor is not None and not do_dropout:
if inpaint_tensor.shape[1] == 4:
# get just the mask
# get just the mask
inpainting_tensor_mask = inpaint_tensor[:, 3:4, :, :].to(latents.device, dtype=latents.dtype)
elif inpaint_tensor.shape[1] == 3:
# rgb mask. Just get one channel
inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :].to(latents.device, dtype=latents.dtype)
# mask is 0-1 with 1 being inpaint area, we need to invert it for now, it is re inverted later
inpaint_tensor = 1 - inpaint_tensor
else:
inpainting_tensor_mask = inpaint_tensor
# # use our batch latents so we cna avoid ancoding again
# # use our batch latents so we can avoid encoding again
inpainting_latent = batch.latents
# resize the mask to match the new encoded size
inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear')
inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype)
if self.random_blur_mask:
# blur the mask
# Give it a channel dim of 1
inpainting_tensor_mask = inpainting_tensor_mask.unsqueeze(1)
# we are at latent size, so keep kernel smaller
inpainting_tensor_mask = random_blur(
inpainting_tensor_mask,
min_kernel_size=3,
max_kernel_size=8,
p=0.5
)
# remove the channel dim
inpainting_tensor_mask = inpainting_tensor_mask.squeeze(1)
do_mask_invert = False
if self.invert_inpaint_mask_chance > 0.0:
do_mask_invert = random.random() < self.invert_inpaint_mask_chance