mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Use a shuffled embedding as unconditional for i2v adapter
This commit is contained in:
@@ -10,6 +10,8 @@ from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProce
|
|||||||
from diffusers.models.attention_processor import Attention
|
from diffusers.models.attention_processor import Attention
|
||||||
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
|
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
|
||||||
|
|
||||||
|
from toolkit.util.shuffle import shuffle_tensor_along_axis
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.models.base_model import BaseModel
|
from toolkit.models.base_model import BaseModel
|
||||||
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
|
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
|
||||||
@@ -318,7 +320,14 @@ def new_wan_forward(
|
|||||||
self._do_unconditional = not self._do_unconditional
|
self._do_unconditional = not self._do_unconditional
|
||||||
if self._do_unconditional:
|
if self._do_unconditional:
|
||||||
# slightly reduce strength of conditional for the unconditional
|
# slightly reduce strength of conditional for the unconditional
|
||||||
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
|
# encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
|
||||||
|
# shuffle the embedding tokens so we still have all the information, but it is scrambled
|
||||||
|
# this prevents attributes such as color from being over-weighted by CFG, while still sharpening content.
|
||||||
|
|
||||||
|
encoder_hidden_states_image = shuffle_tensor_along_axis(
|
||||||
|
adapter.adapter_ref().conditional_embeds,
|
||||||
|
axis=1
|
||||||
|
)
|
||||||
# encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds
|
# encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds
|
||||||
else:
|
else:
|
||||||
# use the conditional
|
# use the conditional
|
||||||
|
|||||||
51
toolkit/util/shuffle.py
Normal file
51
toolkit/util/shuffle.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import torch
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def shuffle_tensor_along_axis(tensor, axis=0, seed=None):
    """
    Shuffle a tensor along a specified axis without affecting the global random state.

    The permutation is drawn from a private ``torch.Generator`` so that no
    global RNG (torch CPU/CUDA, numpy, or Python ``random``) is ever modified,
    not even temporarily.

    Args:
        tensor (torch.Tensor): The input tensor to shuffle. It is not modified.
        axis (int, optional): The axis along which to shuffle. Defaults to 0.
        seed (int, optional): Random seed for reproducibility. When given, the
            same seed always yields the same permutation. When None, the
            permutation is derived from a copy of the current global torch CPU
            RNG state; because that state is left untouched, consecutive calls
            return the same permutation unless something else advances the
            global torch RNG in between. Defaults to None.

    Returns:
        torch.Tensor: A new tensor with the entries along ``axis`` permuted.

    Raises:
        IndexError: If ``axis`` is out of range for ``tensor``.
    """
    # Get the size of the dimension to shuffle
    dim_size = tensor.shape[axis]

    # Use a private generator instead of seeding/restoring the global ones.
    # torch.manual_seed() would also reseed every CUDA device, and that state
    # is not captured by torch.get_rng_state(), so it could not be restored.
    generator = torch.Generator()
    if seed is not None:
        generator.manual_seed(seed)
    else:
        # Mirror the global CPU RNG so the result is the one the global RNG
        # would have produced, without consuming anything from it.
        generator.set_state(torch.get_rng_state())

    # Generate random indices for shuffling
    indices = torch.randperm(dim_size, generator=generator)

    # Build the index as a tuple: full slices everywhere except `axis`.
    # (Indexing with a list of slices is the deprecated non-tuple form.)
    index = [slice(None)] * tensor.dim()
    index[axis] = indices

    # Advanced indexing returns a new tensor, so the input is left untouched
    return tensor[tuple(index)]
|
||||||
Reference in New Issue
Block a user