mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-12 22:19:48 +00:00
Added blurring to mask for flex2
This commit is contained in:
@@ -2,6 +2,7 @@ import os
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import yaml
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
@@ -37,6 +38,15 @@ scheduler_config = {
|
||||
}
|
||||
|
||||
|
||||
def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5):
    """Randomly apply a Gaussian blur to ``img``.

    With probability ``p`` the input is blurred via
    ``torchvision.transforms.functional.gaussian_blur`` using a square kernel
    whose size is drawn uniformly from the odd integers in
    ``[min_kernel_size, max_kernel_size]``. Otherwise ``img`` is returned
    unchanged.

    Args:
        img: Image passed straight through to ``gaussian_blur``.
        min_kernel_size: Smallest kernel size to consider. Rounded up to the
            next odd integer (and to at least 1).
        max_kernel_size: Largest kernel size to consider. Rounded down to the
            previous odd integer so the drawn size never exceeds it.
        p: Probability of applying the blur.

    Returns:
        The blurred image, or ``img`` itself when the blur is skipped.
    """
    if random.random() >= p:
        return img

    # gaussian_blur requires an odd, positive kernel size. Snap the bounds
    # inward to odd values instead of bumping an even draw up by one, which
    # could overshoot max_kernel_size and under-sampled the smallest size.
    lowest = max(1, min_kernel_size)
    if lowest % 2 == 0:
        lowest += 1
    highest = max_kernel_size
    if highest % 2 == 0:
        highest -= 1
    # If the range holds no odd value (e.g. min == max == 4), fall back to the
    # rounded-up lower bound so a valid kernel size is still produced.
    highest = max(lowest, highest)

    kernel_size = random.randrange(lowest, highest + 1, 2)
    return torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size)
|
||||
|
||||
class Flex2(BaseModel):
|
||||
arch = "flex2"
|
||||
|
||||
@@ -66,6 +76,8 @@ class Flex2(BaseModel):
|
||||
self.inpaint_dropout = model_config.model_kwargs.get('inpaint_dropout', 0.0)
|
||||
self.control_dropout = model_config.model_kwargs.get('control_dropout', 0.0)
|
||||
self.inpaint_random_chance = model_config.model_kwargs.get('inpaint_random_chance', 0.0)
|
||||
self.random_blur_mask = model_config.model_kwargs.get('random_blur_mask', False)
|
||||
self.do_random_inpainting = model_config.model_kwargs.get('do_random_inpainting', False)
|
||||
|
||||
# static method to get the noise scheduler
|
||||
@staticmethod
|
||||
@@ -370,13 +382,17 @@ class Flex2(BaseModel):
|
||||
do_dropout = random.random() < self.inpaint_dropout if self.inpaint_dropout > 0.0 else False
|
||||
# do random mask if we dont have one
|
||||
inpaint_tensor = batch.inpaint_tensor
|
||||
if inpaint_tensor is None and batch.mask_tensor is not None:
|
||||
# we have a mask tensor, use it
|
||||
inpaint_tensor = batch.mask_tensor
|
||||
|
||||
if self.inpaint_random_chance > 0.0:
|
||||
do_random = random.random() < self.inpaint_random_chance
|
||||
if do_random:
|
||||
# force a random tensor
|
||||
inpaint_tensor = None
|
||||
|
||||
if inpaint_tensor is None and not do_dropout:
|
||||
if inpaint_tensor is None and not do_dropout and self.do_random_inpainting:
|
||||
# generate a random one since we dont have one
|
||||
# this will make random blobs, invert the blobs for now as we normanlly inpaint the alpha
|
||||
inpaint_tensor = 1 - generate_random_mask(
|
||||
@@ -388,21 +404,37 @@ class Flex2(BaseModel):
|
||||
if inpaint_tensor is not None and not do_dropout:
|
||||
|
||||
if inpaint_tensor.shape[1] == 4:
|
||||
# get just the mask
|
||||
# get just the mask
|
||||
inpainting_tensor_mask = inpaint_tensor[:, 3:4, :, :].to(latents.device, dtype=latents.dtype)
|
||||
elif inpaint_tensor.shape[1] == 3:
|
||||
# rgb mask. Just get one channel
|
||||
inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :].to(latents.device, dtype=latents.dtype)
|
||||
# mask is 0-1 with 1 being inpaint area, we need to invert it for now, it is re inverted later
|
||||
inpaint_tensor = 1 - inpaint_tensor
|
||||
else:
|
||||
inpainting_tensor_mask = inpaint_tensor
|
||||
|
||||
# # use our batch latents so we cna avoid ancoding again
|
||||
# # use our batch latents so we cna avoid encoding again
|
||||
inpainting_latent = batch.latents
|
||||
|
||||
# resize the mask to match the new encoded size
|
||||
inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear')
|
||||
inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype)
|
||||
|
||||
if self.random_blur_mask:
|
||||
# blur the mask
|
||||
# Give it a channel dim of 1
|
||||
inpainting_tensor_mask = inpainting_tensor_mask.unsqueeze(1)
|
||||
# we are at latent size, so keep kernel smaller
|
||||
inpainting_tensor_mask = random_blur(
|
||||
inpainting_tensor_mask,
|
||||
min_kernel_size=3,
|
||||
max_kernel_size=8,
|
||||
p=0.5
|
||||
)
|
||||
# remove the channel dim
|
||||
inpainting_tensor_mask = inpainting_tensor_mask.squeeze(1)
|
||||
|
||||
do_mask_invert = False
|
||||
if self.invert_inpaint_mask_chance > 0.0:
|
||||
do_mask_invert = random.random() < self.invert_inpaint_mask_chance
|
||||
|
||||
Reference in New Issue
Block a user