mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added mask diffirential mask dialation for flex2. Handle video for the i2v adapter
This commit is contained in:
@@ -16,7 +16,7 @@ from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guid
|
|||||||
from toolkit.dequantize import patch_dequantization_on_save
|
from toolkit.dequantize import patch_dequantization_on_save
|
||||||
from toolkit.accelerator import get_accelerator, unwrap_model
|
from toolkit.accelerator import get_accelerator, unwrap_model
|
||||||
from optimum.quanto import freeze, QTensor
|
from optimum.quanto import freeze, QTensor
|
||||||
from toolkit.util.mask import generate_random_mask
|
from toolkit.util.mask import generate_random_mask, random_dialate_mask
|
||||||
from toolkit.util.quantize import quantize, get_qtype
|
from toolkit.util.quantize import quantize, get_qtype
|
||||||
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
|
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
|
||||||
from .pipeline import Flex2Pipeline
|
from .pipeline import Flex2Pipeline
|
||||||
@@ -77,6 +77,7 @@ class Flex2(BaseModel):
|
|||||||
self.control_dropout = model_config.model_kwargs.get('control_dropout', 0.0)
|
self.control_dropout = model_config.model_kwargs.get('control_dropout', 0.0)
|
||||||
self.inpaint_random_chance = model_config.model_kwargs.get('inpaint_random_chance', 0.0)
|
self.inpaint_random_chance = model_config.model_kwargs.get('inpaint_random_chance', 0.0)
|
||||||
self.random_blur_mask = model_config.model_kwargs.get('random_blur_mask', False)
|
self.random_blur_mask = model_config.model_kwargs.get('random_blur_mask', False)
|
||||||
|
self.random_dialate_mask = model_config.model_kwargs.get('random_dialate_mask', False)
|
||||||
self.do_random_inpainting = model_config.model_kwargs.get('do_random_inpainting', False)
|
self.do_random_inpainting = model_config.model_kwargs.get('do_random_inpainting', False)
|
||||||
|
|
||||||
# static method to get the noise scheduler
|
# static method to get the noise scheduler
|
||||||
@@ -446,6 +447,14 @@ class Flex2(BaseModel):
|
|||||||
# we are zeroing our the latents in the inpaint area not on the pixel space.
|
# we are zeroing our the latents in the inpaint area not on the pixel space.
|
||||||
inpainting_latent = inpainting_latent * inpainting_tensor_mask
|
inpainting_latent = inpainting_latent * inpainting_tensor_mask
|
||||||
|
|
||||||
|
# do the random dialation after the mask is applied so it does not match perfectly.
|
||||||
|
# this will make the model learn to prevent weird edges
|
||||||
|
if self.random_dialate_mask:
|
||||||
|
inpainting_tensor_mask = random_dialate_mask(
|
||||||
|
inpainting_tensor_mask,
|
||||||
|
max_percent=0.05
|
||||||
|
)
|
||||||
|
|
||||||
# mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it.
|
# mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it.
|
||||||
inpainting_tensor_mask = 1 - inpainting_tensor_mask
|
inpainting_tensor_mask = 1 - inpainting_tensor_mask
|
||||||
# leave the mask as 0-1 and concat on channel of latents
|
# leave the mask as 0-1 and concat on channel of latents
|
||||||
|
|||||||
@@ -771,7 +771,11 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
||||||
self.timer.start('preprocess_batch')
|
self.timer.start('preprocess_batch')
|
||||||
|
if isinstance(self.adapter, CustomAdapter):
|
||||||
|
batch = self.adapter.edit_batch_raw(batch)
|
||||||
batch = self.preprocess_batch(batch)
|
batch = self.preprocess_batch(batch)
|
||||||
|
if isinstance(self.adapter, CustomAdapter):
|
||||||
|
batch = self.adapter.edit_batch_processed(batch)
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
# sanity check
|
# sanity check
|
||||||
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
|
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
|
||||||
|
|||||||
@@ -302,6 +302,16 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
# else:
|
# else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def edit_batch_raw(self, batch: DataLoaderBatchDTO):
|
||||||
|
# happens on a raw batch before latents are created
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def edit_batch_processed(self, batch: DataLoaderBatchDTO):
|
||||||
|
# happens after the latents are processed
|
||||||
|
if self.adapter_type == "i2v":
|
||||||
|
return self.i2v_adapter.edit_batch_processed(batch)
|
||||||
|
return batch
|
||||||
|
|
||||||
def setup_clip(self):
|
def setup_clip(self):
|
||||||
adapter_config = self.config
|
adapter_config = self.config
|
||||||
sd = self.sd_ref()
|
sd = self.sd_ref()
|
||||||
|
|||||||
@@ -592,6 +592,36 @@ class I2VAdapter(torch.nn.Module):
|
|||||||
def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO):
|
def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO):
|
||||||
# todo handle start frame
|
# todo handle start frame
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
def edit_batch_processed(self, batch: DataLoaderBatchDTO):
|
||||||
|
with torch.no_grad():
|
||||||
|
# we will alway get a clip image frame, if one is not passed, use image
|
||||||
|
# or if video, pull from the first frame
|
||||||
|
# edit the batch to pull the first frame out of a video if we have it
|
||||||
|
# videos come in (bs, num_frames, channels, height, width)
|
||||||
|
tensor = batch.tensor
|
||||||
|
if batch.clip_image_tensor is None:
|
||||||
|
if len(tensor.shape) == 5:
|
||||||
|
# we have a video
|
||||||
|
first_frames = tensor[:, 0, :, :, :].clone()
|
||||||
|
else:
|
||||||
|
# we have a single image
|
||||||
|
first_frames = tensor.clone()
|
||||||
|
|
||||||
|
# it is -1 to 1, change it to 0 to 1
|
||||||
|
first_frames = (first_frames + 1) / 2
|
||||||
|
|
||||||
|
# clip image tensors are preprocessed.
|
||||||
|
tensors_0_1 = first_frames.to(dtype=torch.float16)
|
||||||
|
clip_out = self.adapter_ref().clip_image_processor(
|
||||||
|
images=tensors_0_1,
|
||||||
|
return_tensors="pt",
|
||||||
|
do_resize=True,
|
||||||
|
do_rescale=False,
|
||||||
|
).pixel_values
|
||||||
|
|
||||||
|
batch.clip_image_tensor = clip_out.to(self.device_torch)
|
||||||
|
return batch
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_active(self):
|
def is_active(self):
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import time
|
import time
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
def generate_random_mask(
|
def generate_random_mask(
|
||||||
@@ -173,7 +174,7 @@ def get_gaussian_kernel(kernel_size=5, sigma=1.0):
|
|||||||
return gaussian
|
return gaussian
|
||||||
|
|
||||||
|
|
||||||
def save_masks_as_images(masks, output_dir="output"):
|
def save_masks_as_images(masks, suffix="", output_dir="output"):
|
||||||
"""
|
"""
|
||||||
Save generated masks as RGB JPG images using PIL.
|
Save generated masks as RGB JPG images using PIL.
|
||||||
"""
|
"""
|
||||||
@@ -192,7 +193,65 @@ def save_masks_as_images(masks, output_dir="output"):
|
|||||||
|
|
||||||
# Convert to PIL Image and save
|
# Convert to PIL Image and save
|
||||||
img = Image.fromarray(rgb_mask)
|
img = Image.fromarray(rgb_mask)
|
||||||
img.save(os.path.join(output_dir, f"mask_{i:03d}.jpg"), quality=95)
|
img.save(os.path.join(output_dir, f"mask_{i:03d}{suffix}.jpg"), quality=95)
|
||||||
|
|
||||||
|
|
||||||
|
def random_dialate_mask(mask, max_percent=0.05):
|
||||||
|
"""
|
||||||
|
Randomly dialates a binary mask with a kernel of random size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask (torch.Tensor): Input mask of shape [batch_size, channels, height, width]
|
||||||
|
max_percent (float): Maximum kernel size as a percentage of the mask size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Dialated mask with the same shape as input
|
||||||
|
"""
|
||||||
|
|
||||||
|
size = mask.shape[-1]
|
||||||
|
max_size = int(size * max_percent)
|
||||||
|
|
||||||
|
# Handle case where max_size is too small
|
||||||
|
if max_size < 3:
|
||||||
|
max_size = 3
|
||||||
|
|
||||||
|
batch_chunks = torch.chunk(mask, mask.shape[0], dim=0)
|
||||||
|
out_chunks = []
|
||||||
|
|
||||||
|
for i in range(len(batch_chunks)):
|
||||||
|
chunk = batch_chunks[i]
|
||||||
|
|
||||||
|
# Ensure kernel size is odd for proper padding
|
||||||
|
kernel_size = np.random.randint(1, max_size)
|
||||||
|
|
||||||
|
# If kernel_size is less than 2, keep the original mask
|
||||||
|
if kernel_size < 2:
|
||||||
|
out_chunks.append(chunk)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Make sure kernel size is odd
|
||||||
|
if kernel_size % 2 == 0:
|
||||||
|
kernel_size += 1
|
||||||
|
|
||||||
|
# Create normalized dilation kernel
|
||||||
|
kernel = torch.ones((1, 1, kernel_size, kernel_size), device=mask.device) / (kernel_size * kernel_size)
|
||||||
|
|
||||||
|
# Pad the mask for convolution
|
||||||
|
padding = kernel_size // 2
|
||||||
|
padded_mask = F.pad(chunk, (padding, padding, padding, padding), mode='constant', value=0)
|
||||||
|
|
||||||
|
# Apply convolution
|
||||||
|
dilated = F.conv2d(padded_mask, kernel)
|
||||||
|
|
||||||
|
# Random threshold for varied dilation effect
|
||||||
|
threshold = np.random.uniform(0.2, 0.8)
|
||||||
|
|
||||||
|
# Apply threshold
|
||||||
|
dilated = (dilated > threshold).float()
|
||||||
|
|
||||||
|
out_chunks.append(dilated)
|
||||||
|
|
||||||
|
return torch.cat(out_chunks, dim=0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -216,11 +275,14 @@ if __name__ == "__main__":
|
|||||||
max_coverage=0.8,
|
max_coverage=0.8,
|
||||||
num_blobs_range=(1, 3)
|
num_blobs_range=(1, 3)
|
||||||
)
|
)
|
||||||
|
dialation = random_dialate_mask(masks)
|
||||||
|
print(f"Generated {batch_size} masks with shape: {masks.shape}")
|
||||||
end = time.time()
|
end = time.time()
|
||||||
# print time in milliseconds
|
# print time in milliseconds
|
||||||
print(f"Time taken: {(end - start)*1000:.2f} ms")
|
print(f"Time taken: {(end - start)*1000:.2f} ms")
|
||||||
|
|
||||||
print(f"Saving masks to 'output' directory...")
|
print(f"Saving masks to 'output' directory...")
|
||||||
save_masks_as_images(masks)
|
save_masks_as_images(masks)
|
||||||
|
save_masks_as_images(dialation, suffix="_dilated" )
|
||||||
|
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|||||||
Reference in New Issue
Block a user