From 338c77d67733a2d6d9c4fdd55623ae04f5ed5ead Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 22 Aug 2024 14:36:22 -0600 Subject: [PATCH] Fixed breaking change with diffusers. Allow flowmatch on normal stable diffusion models. --- testing/test_bucket_dataloader.py | 55 +++++++++++--------- toolkit/ip_adapter.py | 6 ++- toolkit/samplers/custom_flowmatch_sampler.py | 1 + toolkit/stable_diffusion_model.py | 3 +- 4 files changed, 37 insertions(+), 28 deletions(-) diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py index 6be8bddd..31d97f2d 100644 --- a/testing/test_bucket_dataloader.py +++ b/testing/test_bucket_dataloader.py @@ -13,7 +13,7 @@ from transformers import CLIPImageProcessor sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from toolkit.paths import SD_SCRIPTS_ROOT import torchvision.transforms.functional -from toolkit.image_utils import show_img +from toolkit.image_utils import show_img, show_tensors sys.path.append(SD_SCRIPTS_ROOT) @@ -34,7 +34,7 @@ parser.add_argument('--epochs', type=int, default=1) args = parser.parse_args() dataset_folder = args.dataset_folder -resolution = 512 +resolution = 1024 bucket_tolerance = 64 batch_size = 1 @@ -55,8 +55,8 @@ class FakeSD: dataset_config = DatasetConfig( dataset_path=dataset_folder, - clip_image_path=dataset_folder, - square_crop=True, + # clip_image_path=dataset_folder, + # square_crop=True, resolution=resolution, # caption_ext='json', default_caption='default', @@ -88,32 +88,37 @@ for epoch in range(args.epochs): # img_batch = color_block_imgs(img_batch, neg1_1=True) - chunks = torch.chunk(img_batch, batch_size, dim=0) - # put them so they are size by side - big_img = torch.cat(chunks, dim=3) - big_img = big_img.squeeze(0) + # chunks = torch.chunk(img_batch, batch_size, dim=0) + # # put them so they are size by side + # big_img = torch.cat(chunks, dim=3) + # big_img = big_img.squeeze(0) + # + # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0) + # big_control_img = torch.cat(control_chunks, dim=3) + # big_control_img = big_control_img.squeeze(0) * 2 - 1 + # + # + # # resize control image + # big_control_img = torchvision.transforms.Resize((width, height))(big_control_img) + # + # big_img = torch.cat([big_img, big_control_img], dim=2) + # + # min_val = big_img.min() + # max_val = big_img.max() + # + # big_img = (big_img / 2 + 0.5).clamp(0, 1) - control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0) - big_control_img = torch.cat(control_chunks, dim=3) - big_control_img = big_control_img.squeeze(0) * 2 - 1 + big_img = img_batch + # big_img = big_img.clamp(-1, 1) - - # resize control image - big_control_img = torchvision.transforms.Resize((width, height))(big_control_img) - - big_img = torch.cat([big_img, big_control_img], dim=2) - - min_val = big_img.min() - max_val = big_img.max() - - big_img = (big_img / 2 + 0.5).clamp(0, 1) + show_tensors(big_img) # convert to image - img = transforms.ToPILImage()(big_img) + # img = transforms.ToPILImage()(big_img) + # + # show_img(img) - show_img(img) - - time.sleep(1.0) + time.sleep(0.2) # if not last epoch if epoch < args.epochs - 1: trigger_dataloader_setup_epoch(dataloader) diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 4264071a..8d735b78 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -5,7 +5,6 @@ import sys from PIL import Image from diffusers import Transformer2DModel -from diffusers.models.attention_processor import apply_rope from torch import nn from torch.nn import Parameter from torch.nn.modules.module import T @@ -341,7 +340,10 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module): # from ..embeddings import apply_rotary_emb # query = apply_rotary_emb(query, image_rotary_emb) # key = apply_rotary_emb(key, image_rotary_emb) - query, key = apply_rope(query, key, image_rotary_emb) + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index 8efc5d06..1cb2eac6 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -8,6 +8,7 @@ import torch class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.init_noise_sigma = 1.0 with torch.no_grad(): # create weights for timesteps diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 0528b4a6..9e27a1ed 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -34,6 +34,7 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds from toolkit.reference_adapter import ReferenceAdapter from toolkit.sampler import get_sampler +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers from toolkit.sd_device_states_presets import empty_preset from toolkit.train_tools import get_torch_dtype, apply_noise_offset @@ -178,7 +179,7 @@ class StableDiffusion: self.config_file = None self.is_flow_matching = False - if self.is_flux or self.is_v3 or self.is_auraflow: + if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): self.is_flow_matching = True self.quantize_device = quantize_device if quantize_device is not None else self.device