mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Use a shuffled embedding as unconditional for i2v adapter
This commit is contained in:
@@ -10,6 +10,8 @@ from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProce
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
|
||||
|
||||
from toolkit.util.shuffle import shuffle_tensor_along_axis
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
|
||||
@@ -318,7 +320,14 @@ def new_wan_forward(
|
||||
self._do_unconditional = not self._do_unconditional
|
||||
if self._do_unconditional:
|
||||
# slightly reduce strength of conditional for the unconditional
|
||||
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
|
||||
# encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
|
||||
# shuffle the embedding tokens so we still have all the information, but it is scrambled
|
||||
# this will prevent things like color from being cfg overweights, but still sharpen content.
|
||||
|
||||
encoder_hidden_states_image = shuffle_tensor_along_axis(
|
||||
adapter.adapter_ref().conditional_embeds,
|
||||
axis=1
|
||||
)
|
||||
# encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds
|
||||
else:
|
||||
# use the conditional
|
||||
|
||||
Reference in New Issue
Block a user