Use a shuffled embedding as unconditional for i2v adapter

This commit is contained in:
Jaret Burkett
2025-04-11 10:44:43 -06:00
parent 059155174a
commit 4a43589666
2 changed files with 61 additions and 1 deletions

View File

@@ -10,6 +10,8 @@ from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProce
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
from toolkit.util.shuffle import shuffle_tensor_along_axis
if TYPE_CHECKING:
from toolkit.models.base_model import BaseModel
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
@@ -318,7 +320,14 @@ def new_wan_forward(
self._do_unconditional = not self._do_unconditional
if self._do_unconditional:
# slightly reduce strength of conditional for the unconditional
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
# encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
# shuffle the embedding tokens so we still have all the information, but it is scrambled
# this will prevent things like color from being cfg overweights, but still sharpen content.
encoder_hidden_states_image = shuffle_tensor_along_axis(
adapter.adapter_ref().conditional_embeds,
axis=1
)
# encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds
else:
# use the conditional

51
toolkit/util/shuffle.py Normal file
View File

@@ -0,0 +1,51 @@
import torch
import random
import numpy as np
def shuffle_tensor_along_axis(tensor, axis=0, seed=None):
"""
Shuffle a tensor along a specified axis without affecting the global random state.
Args:
tensor (torch.Tensor): The input tensor to shuffle
axis (int, optional): The axis along which to shuffle. Defaults to 0.
seed (int, optional): Random seed for reproducibility. Defaults to None.
Returns:
torch.Tensor: The shuffled tensor
"""
# Clone the tensor to avoid in-place modifications
shuffled_tensor = tensor.clone()
# Store original random states
torch_state = torch.get_rng_state()
np_state = np.random.get_state()
py_state = random.getstate()
try:
# Set seed if provided
if seed is not None:
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# Get the size of the dimension to shuffle
dim_size = tensor.shape[axis]
# Generate random indices for shuffling
indices = torch.randperm(dim_size)
# Create a slice object to shuffle along the specified axis
slices = [slice(None)] * tensor.dim()
slices[axis] = indices
# Apply the shuffle
shuffled_tensor = tensor[slices]
finally:
# Restore original random states
torch.set_rng_state(torch_state)
np.random.set_state(np_state)
random.setstate(py_state)
return shuffled_tensor