Use a shuffled embedding as unconditional for i2v adapter

This commit is contained in:
Jaret Burkett
2025-04-11 10:44:43 -06:00
parent 059155174a
commit 4a43589666
2 changed files with 61 additions and 1 deletions

View File

@@ -10,6 +10,8 @@ from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProce
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
from toolkit.util.shuffle import shuffle_tensor_along_axis
if TYPE_CHECKING:
from toolkit.models.base_model import BaseModel
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
@@ -318,7 +320,14 @@ def new_wan_forward(
self._do_unconditional = not self._do_unconditional
if self._do_unconditional:
# slightly reduce strength of conditional for the unconditional
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
# encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
# shuffle the embedding tokens so we still have all the information, but it is scrambled
# this will prevent things like color from being cfg overweights, but still sharpen content.
encoder_hidden_states_image = shuffle_tensor_along_axis(
adapter.adapter_ref().conditional_embeds,
axis=1
)
# encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds
else:
# use the conditional