mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Fixed breaking change with diffusers. Allow flowmatch on normal stable diffusion models.
This commit is contained in:
@@ -5,7 +5,6 @@ import sys
|
||||
|
||||
from PIL import Image
|
||||
from diffusers import Transformer2DModel
|
||||
from diffusers.models.attention_processor import apply_rope
|
||||
from torch import nn
|
||||
from torch.nn import Parameter
|
||||
from torch.nn.modules.module import T
|
||||
@@ -341,7 +340,10 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
|
||||
# from ..embeddings import apply_rotary_emb
|
||||
# query = apply_rotary_emb(query, image_rotary_emb)
|
||||
# key = apply_rotary_emb(key, image_rotary_emb)
|
||||
query, key = apply_rope(query, key, image_rotary_emb)
|
||||
from diffusers.models.embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
Reference in New Issue
Block a user