mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added differential mask dilation for flex2. Handle video for the i2v adapter.
This commit is contained in:
@@ -16,7 +16,7 @@ from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guid
|
||||
from toolkit.dequantize import patch_dequantization_on_save
|
||||
from toolkit.accelerator import get_accelerator, unwrap_model
|
||||
from optimum.quanto import freeze, QTensor
|
||||
from toolkit.util.mask import generate_random_mask
|
||||
from toolkit.util.mask import generate_random_mask, random_dialate_mask
|
||||
from toolkit.util.quantize import quantize, get_qtype
|
||||
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
|
||||
from .pipeline import Flex2Pipeline
|
||||
@@ -77,6 +77,7 @@ class Flex2(BaseModel):
|
||||
self.control_dropout = model_config.model_kwargs.get('control_dropout', 0.0)
|
||||
self.inpaint_random_chance = model_config.model_kwargs.get('inpaint_random_chance', 0.0)
|
||||
self.random_blur_mask = model_config.model_kwargs.get('random_blur_mask', False)
|
||||
self.random_dialate_mask = model_config.model_kwargs.get('random_dialate_mask', False)
|
||||
self.do_random_inpainting = model_config.model_kwargs.get('do_random_inpainting', False)
|
||||
|
||||
# static method to get the noise scheduler
|
||||
@@ -446,6 +447,14 @@ class Flex2(BaseModel):
|
||||
# we are zeroing out the latents in the inpaint area, not in the pixel space.
|
||||
inpainting_latent = inpainting_latent * inpainting_tensor_mask
|
||||
|
||||
# do the random dilation after the mask is applied so it does not match perfectly.
|
||||
# this will make the model learn to prevent weird edges
|
||||
if self.random_dialate_mask:
|
||||
inpainting_tensor_mask = random_dialate_mask(
|
||||
inpainting_tensor_mask,
|
||||
max_percent=0.05
|
||||
)
|
||||
|
||||
# mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it.
|
||||
inpainting_tensor_mask = 1 - inpainting_tensor_mask
|
||||
# leave the mask as 0-1 and concat on channel of latents
|
||||
|
||||
Reference in New Issue
Block a user