Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Fix bug that prevented training Wan 2.2 with batch size greater than 1
@@ -1,3 +1,4 @@
+from functools import partial
 import torch
 from toolkit.prompt_utils import PromptEmbeds
 from PIL import Image
@@ -55,6 +56,29 @@ scheduler_config = {
     "use_dynamic_shifting": False,
 }
 
+# TODO: this is a temporary monkeypatch to fix the time text embedding to allow for batch sizes greater than 1. Remove this when the diffusers library is fixed.
+def time_text_monkeypatch(
+    self,
+    timestep: torch.Tensor,
+    encoder_hidden_states,
+    encoder_hidden_states_image=None,
+    timestep_seq_len=None,
+):
+    timestep = self.timesteps_proj(timestep)
+    if timestep_seq_len is not None:
+        timestep = timestep.unflatten(0, (encoder_hidden_states.shape[0], timestep_seq_len))
+
+    time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+    if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+        timestep = timestep.to(time_embedder_dtype)
+    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+    timestep_proj = self.time_proj(self.act_fn(temb))
+
+    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+    if encoder_hidden_states_image is not None:
+        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
 
 class Wan22Model(Wan21):
     arch = "wan22_5b"
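The key line in this patch is the `unflatten` call: Wan 2.2 expands timesteps per token, so the projection sees a tensor flattened to `(batch * seq_len,)`, and the batch axis has to be restored using the actual batch size taken from `encoder_hidden_states.shape[0]`. A minimal sketch of why that matters, using made-up shapes (the real model uses much larger dimensions):

```python
import torch

# Illustration-only sizes: 2 samples, 4 per-token timesteps each, 8-dim projection.
batch, seq_len, dim = 2, 4, 8

# Per-token timesteps are flattened to (batch * seq_len,) before projection,
# so the projected output arrives as (batch * seq_len, dim).
projected = torch.randn(batch * seq_len, dim)

# Restoring the batch axis from the real batch size keeps samples separate,
# mirroring timestep.unflatten(0, (encoder_hidden_states.shape[0], timestep_seq_len)).
ok = projected.unflatten(0, (batch, seq_len))
assert ok.shape == (batch, seq_len, dim)

# Assuming a batch of 1 instead would fold every sample's timesteps into one
# row, which is the wrong layout whenever the batch size is greater than 1.
broken = projected.unflatten(0, (1, batch * seq_len))
assert broken.shape == (1, batch * seq_len, dim)
```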
@@ -80,6 +104,12 @@ class Wan22Model(Wan21):
         )
 
         self._wan_cache = None
 
+    def load_model(self):
+        super().load_model()
+
+        # patch the condition embedder
+        self.model.condition_embedder.forward = partial(time_text_monkeypatch, self.model.condition_embedder)
+
     def get_bucket_divisibility(self):
         # 16x compression and 2x2 patch size
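`functools.partial` is what turns the free function into a drop-in replacement for the bound method: it pins the embedder instance as the function's `self`, so calling `forward(...)` on the instance behaves like a method call. A self-contained sketch of the same pattern, with a hypothetical `Greeter` class standing in for the condition embedder:

```python
from functools import partial

class Greeter:
    def __init__(self, name: str):
        self.name = name

    def greet(self) -> str:
        return f"hello, {self.name}"

def patched_greet(self) -> str:
    # Free function standing in for the method; `self` is supplied by partial().
    return f"hi, {self.name}!"

g = Greeter("wan")
# Overwrite the instance attribute, binding `g` as the first argument,
# just as the commit does for model.condition_embedder.forward.
g.greet = partial(patched_greet, g)
assert g.greet() == "hi, wan!"
```

Patching the instance attribute rather than the class keeps the workaround scoped to the one loaded model and makes it easy to delete once the upstream diffusers fix lands, as the TODO notes.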