diff --git a/extensions_built_in/flex2/flex2.py b/extensions_built_in/flex2/flex2.py index 0f966c4b..3340f64a 100644 --- a/extensions_built_in/flex2/flex2.py +++ b/extensions_built_in/flex2/flex2.py @@ -2,6 +2,7 @@ import os from typing import TYPE_CHECKING, List import torch +import torchvision import yaml from toolkit import train_tools from toolkit.config_modules import GenerateImageConfig, ModelConfig @@ -37,6 +38,15 @@ scheduler_config = { } +def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5): + if random.random() < p: + kernel_size = random.randint(min_kernel_size, max_kernel_size) + # make sure it is odd + if kernel_size % 2 == 0: + kernel_size += 1 + img = torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size) + return img + class Flex2(BaseModel): arch = "flex2" @@ -66,6 +76,8 @@ class Flex2(BaseModel): self.inpaint_dropout = model_config.model_kwargs.get('inpaint_dropout', 0.0) self.control_dropout = model_config.model_kwargs.get('control_dropout', 0.0) self.inpaint_random_chance = model_config.model_kwargs.get('inpaint_random_chance', 0.0) + self.random_blur_mask = model_config.model_kwargs.get('random_blur_mask', False) + self.do_random_inpainting = model_config.model_kwargs.get('do_random_inpainting', False) # static method to get the noise scheduler @staticmethod @@ -370,13 +382,17 @@ class Flex2(BaseModel): do_dropout = random.random() < self.inpaint_dropout if self.inpaint_dropout > 0.0 else False # do random mask if we dont have one inpaint_tensor = batch.inpaint_tensor + if inpaint_tensor is None and batch.mask_tensor is not None: + # we have a mask tensor, use it + inpaint_tensor = batch.mask_tensor + if self.inpaint_random_chance > 0.0: do_random = random.random() < self.inpaint_random_chance if do_random: # force a random tensor inpaint_tensor = None - if inpaint_tensor is None and not do_dropout: + if inpaint_tensor is None and not do_dropout and self.do_random_inpainting: # generate a random one since we dont have one # this will make random blobs, invert the blobs for now as we normanlly inpaint the alpha inpaint_tensor = 1 - generate_random_mask( @@ -388,21 +404,37 @@ class Flex2(BaseModel): if inpaint_tensor is not None and not do_dropout: if inpaint_tensor.shape[1] == 4: - # get just the mask + # get just the mask inpainting_tensor_mask = inpaint_tensor[:, 3:4, :, :].to(latents.device, dtype=latents.dtype) elif inpaint_tensor.shape[1] == 3: # rgb mask. Just get one channel inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :].to(latents.device, dtype=latents.dtype) + # mask is 0-1 with 1 being inpaint area, we need to invert it for now, it is re inverted later + inpaint_tensor = 1 - inpaint_tensor else: inpainting_tensor_mask = inpaint_tensor - # # use our batch latents so we cna avoid ancoding again + # # use our batch latents so we cna avoid encoding again inpainting_latent = batch.latents # resize the mask to match the new encoded size inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear') inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype) + if self.random_blur_mask: + # blur the mask + # Give it a channel dim of 1 + inpainting_tensor_mask = inpainting_tensor_mask.unsqueeze(1) + # we are at latent size, so keep kernel smaller + inpainting_tensor_mask = random_blur( + inpainting_tensor_mask, + min_kernel_size=3, + max_kernel_size=8, + p=0.5 + ) + # remove the channel dim + inpainting_tensor_mask = inpainting_tensor_mask.squeeze(1) + do_mask_invert = False if self.invert_inpaint_mask_chance > 0.0: do_mask_invert = random.random() < self.invert_inpaint_mask_chance