diff --git a/toolkit/models/i2v_adapter.py b/toolkit/models/i2v_adapter.py index dcf4c2f2..50b24cd4 100644 --- a/toolkit/models/i2v_adapter.py +++ b/toolkit/models/i2v_adapter.py @@ -10,6 +10,8 @@ from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProce from diffusers.models.attention_processor import Attention from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding +from toolkit.util.shuffle import shuffle_tensor_along_axis + if TYPE_CHECKING: from toolkit.models.base_model import BaseModel from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig @@ -318,7 +320,14 @@ def new_wan_forward( self._do_unconditional = not self._do_unconditional if self._do_unconditional: # slightly reduce strength of conditional for the unconditional - encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5 + # encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5 + # shuffle the embedding tokens so we still have all the information, but it is scrambled + # this will prevent things like color from being cfg overweights, but still sharpen content. + + encoder_hidden_states_image = shuffle_tensor_along_axis( + adapter.adapter_ref().conditional_embeds, + axis=1 + ) # encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds else: # use the conditional diff --git a/toolkit/util/shuffle.py b/toolkit/util/shuffle.py new file mode 100644 index 00000000..7d735f9d --- /dev/null +++ b/toolkit/util/shuffle.py @@ -0,0 +1,51 @@ +import torch +import random +import numpy as np + +def shuffle_tensor_along_axis(tensor, axis=0, seed=None): + """ + Shuffle a tensor along a specified axis without affecting the global random state. + + Args: + tensor (torch.Tensor): The input tensor to shuffle + axis (int, optional): The axis along which to shuffle. Defaults to 0. + seed (int, optional): Random seed for reproducibility. Defaults to None. + + Returns: + torch.Tensor: The shuffled tensor + """ + # Clone the tensor to avoid in-place modifications + shuffled_tensor = tensor.clone() + + # Store original random states + torch_state = torch.get_rng_state() + np_state = np.random.get_state() + py_state = random.getstate() + + try: + # Set seed if provided + if seed is not None: + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + # Get the size of the dimension to shuffle + dim_size = tensor.shape[axis] + + # Generate random indices for shuffling + indices = torch.randperm(dim_size) + + # Create a slice object to shuffle along the specified axis + slices = [slice(None)] * tensor.dim() + slices[axis] = indices + + # Apply the shuffle + shuffled_tensor = tensor[slices] + + finally: + # Restore original random states + torch.set_rng_state(torch_state) + np.random.set_state(np_state) + random.setstate(py_state) + + return shuffled_tensor \ No newline at end of file