Fixed a breaking change with diffusers. Allow flow matching on normal Stable Diffusion models.

This commit is contained in:
Jaret Burkett
2024-08-22 14:36:22 -06:00
parent e07a98a50c
commit 338c77d677
4 changed files with 37 additions and 28 deletions

View File

@@ -13,7 +13,7 @@ from transformers import CLIPImageProcessor
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.paths import SD_SCRIPTS_ROOT
import torchvision.transforms.functional
from toolkit.image_utils import show_img
from toolkit.image_utils import show_img, show_tensors
sys.path.append(SD_SCRIPTS_ROOT)
@@ -34,7 +34,7 @@ parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()
dataset_folder = args.dataset_folder
resolution = 512
resolution = 1024
bucket_tolerance = 64
batch_size = 1
@@ -55,8 +55,8 @@ class FakeSD:
dataset_config = DatasetConfig(
dataset_path=dataset_folder,
clip_image_path=dataset_folder,
square_crop=True,
# clip_image_path=dataset_folder,
# square_crop=True,
resolution=resolution,
# caption_ext='json',
default_caption='default',
@@ -88,32 +88,37 @@ for epoch in range(args.epochs):
# img_batch = color_block_imgs(img_batch, neg1_1=True)
chunks = torch.chunk(img_batch, batch_size, dim=0)
    # put them so they are side by side
big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0)
# chunks = torch.chunk(img_batch, batch_size, dim=0)
    # # put them so they are side by side
# big_img = torch.cat(chunks, dim=3)
# big_img = big_img.squeeze(0)
#
# control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
# big_control_img = torch.cat(control_chunks, dim=3)
# big_control_img = big_control_img.squeeze(0) * 2 - 1
#
#
# # resize control image
# big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
#
# big_img = torch.cat([big_img, big_control_img], dim=2)
#
# min_val = big_img.min()
# max_val = big_img.max()
#
# big_img = (big_img / 2 + 0.5).clamp(0, 1)
control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
big_control_img = torch.cat(control_chunks, dim=3)
big_control_img = big_control_img.squeeze(0) * 2 - 1
big_img = img_batch
# big_img = big_img.clamp(-1, 1)
# resize control image
big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
big_img = torch.cat([big_img, big_control_img], dim=2)
min_val = big_img.min()
max_val = big_img.max()
big_img = (big_img / 2 + 0.5).clamp(0, 1)
show_tensors(big_img)
# convert to image
img = transforms.ToPILImage()(big_img)
# img = transforms.ToPILImage()(big_img)
#
# show_img(img)
show_img(img)
time.sleep(1.0)
time.sleep(0.2)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)

View File

@@ -5,7 +5,6 @@ import sys
from PIL import Image
from diffusers import Transformer2DModel
from diffusers.models.attention_processor import apply_rope
from torch import nn
from torch.nn import Parameter
from torch.nn.modules.module import T
@@ -341,7 +340,10 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

View File

@@ -8,6 +8,7 @@ import torch
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_noise_sigma = 1.0
with torch.no_grad():
# create weights for timesteps

View File

@@ -34,6 +34,7 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.sampler import get_sampler
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
from toolkit.sd_device_states_presets import empty_preset
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
@@ -178,7 +179,7 @@ class StableDiffusion:
self.config_file = None
self.is_flow_matching = False
if self.is_flux or self.is_v3 or self.is_auraflow:
if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler):
self.is_flow_matching = True
self.quantize_device = quantize_device if quantize_device is not None else self.device