Fixed breaking change with diffusers. Allow flowmatch on normal stable diffusion models.

This commit is contained in:
Jaret Burkett
2024-08-22 14:36:22 -06:00
parent e07a98a50c
commit 338c77d677
4 changed files with 37 additions and 28 deletions

View File

@@ -5,7 +5,6 @@ import sys
from PIL import Image
from diffusers import Transformer2DModel
from diffusers.models.attention_processor import apply_rope
from torch import nn
from torch.nn import Parameter
from torch.nn.modules.module import T
@@ -341,7 +340,10 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)