Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-02-21 12:53:56 +00:00
Fixed a breaking change with diffusers. Allow flowmatch on normal Stable Diffusion models.
```diff
@@ -13,7 +13,7 @@ from transformers import CLIPImageProcessor
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from toolkit.paths import SD_SCRIPTS_ROOT
 import torchvision.transforms.functional
-from toolkit.image_utils import show_img
+from toolkit.image_utils import show_img, show_tensors
 
 sys.path.append(SD_SCRIPTS_ROOT)
 
```
```diff
@@ -34,7 +34,7 @@ parser.add_argument('--epochs', type=int, default=1)
 args = parser.parse_args()
 
 dataset_folder = args.dataset_folder
-resolution = 512
+resolution = 1024
 bucket_tolerance = 64
 batch_size = 1
 
```
```diff
@@ -55,8 +55,8 @@ class FakeSD:
 
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
-    clip_image_path=dataset_folder,
-    square_crop=True,
+    # clip_image_path=dataset_folder,
+    # square_crop=True,
     resolution=resolution,
     # caption_ext='json',
     default_caption='default',
```
```diff
@@ -88,32 +88,37 @@ for epoch in range(args.epochs):
     # img_batch = color_block_imgs(img_batch, neg1_1=True)
 
-    chunks = torch.chunk(img_batch, batch_size, dim=0)
-    # put them so they are size by side
-    big_img = torch.cat(chunks, dim=3)
-    big_img = big_img.squeeze(0)
+    # chunks = torch.chunk(img_batch, batch_size, dim=0)
+    # # put them so they are size by side
+    # big_img = torch.cat(chunks, dim=3)
+    # big_img = big_img.squeeze(0)
+    #
+    # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
+    # big_control_img = torch.cat(control_chunks, dim=3)
+    # big_control_img = big_control_img.squeeze(0) * 2 - 1
+    #
+    #
+    # # resize control image
+    # big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
+    #
+    # big_img = torch.cat([big_img, big_control_img], dim=2)
+    #
+    # min_val = big_img.min()
+    # max_val = big_img.max()
+    #
+    # big_img = (big_img / 2 + 0.5).clamp(0, 1)
 
-    control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
-    big_control_img = torch.cat(control_chunks, dim=3)
-    big_control_img = big_control_img.squeeze(0) * 2 - 1
-
-
-    # resize control image
-    big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
-
-    big_img = torch.cat([big_img, big_control_img], dim=2)
-
-    min_val = big_img.min()
-    max_val = big_img.max()
-
+    big_img = img_batch
+    # big_img = big_img.clamp(-1, 1)
     big_img = (big_img / 2 + 0.5).clamp(0, 1)
+    show_tensors(big_img)
 
     # convert to image
-    img = transforms.ToPILImage()(big_img)
+    # img = transforms.ToPILImage()(big_img)
+    #
+    # show_img(img)
 
-    show_img(img)
-
-    time.sleep(1.0)
+    time.sleep(0.2)
     # if not last epoch
     if epoch < args.epochs - 1:
         trigger_dataloader_setup_epoch(dataloader)
 
```
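The viewer loop above now skips the side-by-side compositing and the PIL round-trip: it rescales the raw batch tensor from [-1, 1] to [0, 1] and hands it straight to `show_tensors`. A minimal check of that rescale (the sample values are ours, for illustration):

```python
import torch

x = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
print(x / 2 + 0.5)  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
# the .clamp(0, 1) in the loop only guards values that stray outside [-1, 1]
```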
```diff
@@ -5,7 +5,6 @@ import sys
 from PIL import Image
 from diffusers import Transformer2DModel
-from diffusers.models.attention_processor import apply_rope
 from torch import nn
 from torch.nn import Parameter
 from torch.nn.modules.module import T
 
@@ -341,7 +340,10 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
             # from ..embeddings import apply_rotary_emb
             # query = apply_rotary_emb(query, image_rotary_emb)
             # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from diffusers.models.embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
```
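This hunk is the breaking change named in the commit message: newer diffusers releases removed `apply_rope` from `diffusers.models.attention_processor`, and the replacement is `apply_rotary_emb` from `diffusers.models.embeddings`, applied to query and key separately. A hedged compatibility sketch for code that must run on both sides of the change (the `rotate_qk` helper name is ours, not the repo's):

```python
# Sketch only: assumes one of the two import paths exists in the installed
# diffusers; both symbol names are taken from this commit's before/after code.
try:
    from diffusers.models.embeddings import apply_rotary_emb

    def rotate_qk(query, key, image_rotary_emb):
        # newer diffusers rotates one tensor per call
        return (
            apply_rotary_emb(query, image_rotary_emb),
            apply_rotary_emb(key, image_rotary_emb),
        )
except ImportError:
    from diffusers.models.attention_processor import apply_rope

    def rotate_qk(query, key, image_rotary_emb):
        # older diffusers rotated query and key in a single call
        return apply_rope(query, key, image_rotary_emb)
```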
```diff
@@ -8,6 +8,7 @@ import torch
 class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.init_noise_sigma = 1.0
 
         with torch.no_grad():
             # create weights for timesteps
```
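The stock `FlowMatchEulerDiscreteScheduler` lacks the `init_noise_sigma` attribute that SD-era schedulers expose and that pipeline/training code multiplies into the initial latents, which is presumably why the commit adds it here. With a value of 1.0 that scaling step becomes a no-op, which is what flow matching wants. A minimal sketch of the effect:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # attribute expected by SD-style code paths; 1.0 makes the
        # "scale the initial noise" step a no-op
        self.init_noise_sigma = 1.0

scheduler = CustomFlowMatchEulerDiscreteScheduler()
latents = torch.randn(1, 4, 64, 64)
latents = latents * scheduler.init_noise_sigma  # unchanged
```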
```diff
@@ -34,6 +34,7 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.sampler import get_sampler
+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
 from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
 from toolkit.sd_device_states_presets import empty_preset
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
```
```diff
@@ -178,7 +179,7 @@ class StableDiffusion:
         self.config_file = None
 
         self.is_flow_matching = False
-        if self.is_flux or self.is_v3 or self.is_auraflow:
+        if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler):
             self.is_flow_matching = True
 
         self.quantize_device = quantize_device if quantize_device is not None else self.device
```
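This is the second half of the commit message: previously only Flux, SD3, and AuraFlow counted as flow-matching models; now any model whose noise scheduler is the custom flow-match scheduler does too, so a plain Stable Diffusion checkpoint can be trained with a flow-matching objective simply by selecting that sampler. A minimal sketch of the detection logic (the standalone function is ours; in the repo it lives inline in `StableDiffusion.__init__`):

```python
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

def detect_flow_matching(sd) -> bool:
    # sd is assumed to carry the same flags and noise_scheduler attribute
    # as the StableDiffusion class in this diff
    return (
        sd.is_flux
        or sd.is_v3
        or sd.is_auraflow
        or isinstance(sd.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler)
    )
```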