from functools import partial

import torch
from PIL import Image
from diffusers import UniPCMultistepScheduler
from torchvision.transforms import functional as TF

from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.wan21.wan21 import Wan21, AggressiveWanUnloadPipeline
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)

from .wan22_pipeline import Wan22Pipeline

# UniPC scheduler config, used only for the generation (sampling) pipeline
scheduler_configUniPC = {
    "_class_name": "UniPCMultistepScheduler",
    "_diffusers_version": "0.35.0.dev0",
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "disable_corrector": [],
    "dynamic_thresholding_ratio": 0.995,
    "final_sigmas_type": "zero",
    "flow_shift": 5.0,
    "lower_order_final": True,
    "num_train_timesteps": 1000,
    "predict_x0": True,
    "prediction_type": "flow_prediction",
    "rescale_betas_zero_snr": False,
    "sample_max_value": 1.0,
    "solver_order": 2,
    "solver_p": None,
    "solver_type": "bh2",
    "steps_offset": 0,
    "thresholding": False,
    "time_shift_type": "exponential",
    "timestep_spacing": "linspace",
    "trained_betas": None,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": False,
    "use_exponential_sigmas": False,
    "use_flow_sigmas": True,
    "use_karras_sigmas": False,
}

# flow-match scheduler config used for training
scheduler_config = {
    "num_train_timesteps": 1000,
    "shift": 5.0,
    "use_dynamic_shifting": False,
}
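
# note (assumption): CustomFlowMatchEulerDiscreteScheduler follows the stock
# FlowMatchEulerDiscreteScheduler here, so with use_dynamic_shifting=False the
# static shift sigma' = shift * sigma / (1 + (shift - 1) * sigma) is applied
# to the training sigmas, mirroring the flow_shift of 5.0 used for generation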


# TODO: this is a temporary monkeypatch to fix the time text embedding to allow
# for batch sizes greater than 1. Remove this when the diffusers library is fixed.
def time_text_monkeypatch(
    self,
    timestep: torch.Tensor,
    encoder_hidden_states,
    encoder_hidden_states_image=None,
    timestep_seq_len=None,
):
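    """Patched WanTimeTextImageEmbedding.forward.

    When per-token timesteps are used (timestep_seq_len is not None), the
    flattened (batch * seq_len,) timestep projection is unflattened with the
    real batch size taken from encoder_hidden_states rather than assuming a
    batch of 1, so batch sizes greater than 1 work.
    """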
    timestep = self.timesteps_proj(timestep)
    if timestep_seq_len is not None:
        timestep = timestep.unflatten(0, (encoder_hidden_states.shape[0], timestep_seq_len))

    time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
    if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
        timestep = timestep.to(time_embedder_dtype)
    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
    timestep_proj = self.time_proj(self.act_fn(temb))

    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
    if encoder_hidden_states_image is not None:
        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image


class Wan225bModel(Wan21):
    arch = "wan22_5b"
    _wan_generation_scheduler_config = scheduler_configUniPC
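    # expanded (per-token) timesteps let conditioned latent positions be held
    # at t=0 while the rest of the video is denoised; see get_noise_prediction
    # and the pipeline's expand_timesteps flag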
    _wan_expand_timesteps = True

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device=device,
            model_config=model_config,
            dtype=dtype,
            custom_pipeline=custom_pipeline,
            noise_scheduler=noise_scheduler,
            **kwargs,
        )

        self._wan_cache = None

    def load_model(self):
        super().load_model()

        # patch the condition embedder
        self.model.condition_embedder.forward = partial(
            time_text_monkeypatch, self.model.condition_embedder
        )
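        # (partial is used instead of plain assignment so the embedder
        # instance is bound as the first argument; assigning a bare function
        # to an instance attribute would not create a bound method)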

    def get_bucket_divisibility(self):
        # 16x compression and 2x2 patch size
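        # e.g. a height of 720 later snaps down to 704 (720 // 32 * 32)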
        return 32

    def get_generation_pipeline(self):
        scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
        pipeline = Wan22Pipeline(
            vae=self.vae,
            transformer=self.model,
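            # the 5B variant has a single transformer, so the same model fills
            # both slots of the two-stage Wan 2.2 pipeline interface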
            transformer_2=self.model,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
            expand_timesteps=self._wan_expand_timesteps,
            device=self.device_torch,
            aggressive_offload=self.model_config.low_vram,
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    # training uses the custom flow-match Euler scheduler rather than UniPC
    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def get_base_model_version(self):
        return "wan_2.2_5b"

    def generate_single_image(
        self,
        pipeline: AggressiveWanUnloadPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # re-enable the progress bar since video generation is slow
        pipeline.set_progress_bar_config(disable=False)

        # frame count must be of the form 4k + 1 for the temporal VAE compression
        num_frames = ((gen_config.num_frames - 1) // 4) * 4 + 1
        gen_config.num_frames = num_frames
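        # e.g. a request for 16 frames becomes 13, since (16 - 1) // 4 * 4 + 1 == 13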

        height = gen_config.height
        width = gen_config.width
        noise_mask = None
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img).convert("RGB")

            d = self.get_bucket_divisibility()

            # make sure the dimensions are divisible by d
            height = height // d * d
            width = width // d * d

            # resize the control image to match
            control_img = control_img.resize((width, height), Image.LANCZOS)

            # prepare the latent variables
            num_channels_latents = self.transformer.config.in_channels
            latents = pipeline.prepare_latents(
                1,
                num_channels_latents,
                height,
                width,
                gen_config.num_frames,
                torch.float32,
                self.device_torch,
                generator,
                None,
            ).to(self.torch_dtype)

            # normalize the control image to [-1, 1]
            first_frame_n1p1 = (
                TF.to_tensor(control_img)
                .unsqueeze(0)
                .to(self.device_torch, dtype=self.torch_dtype)
                * 2.0
                - 1.0
            )
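
            # noise_mask marks which latent positions are denoised (1) versus
            # held as first-frame conditioning (0); the pipeline uses it when
            # expanding per-token timesteps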
            gen_config.latents, noise_mask = add_first_frame_conditioning_v22(
                latent_model_input=latents,
                first_frame=first_frame_n1p1,
                vae=self.vae,
            )

        output = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            height=height,
            width=width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="pil",
            noise_mask=noise_mask,
            **extra,
        )[0]

        # with output_type="pil", output is a list over the batch;
        # each item is a list of PIL frames
        batch_item = output[0]
        if gen_config.num_frames > 1:
            # video: return all of the frames
            return batch_item
        else:
            # single image: return just the first frame
            img = batch_item[0]
            return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: DataLoaderBatchDTO,
        **kwargs,
    ):
        # videos come in as (bs, num_frames, channels, height, width)
        # images come in as (bs, channels, height, width)

        # for Wan, only do i2v for video for now; images do normal t2i
        conditioned_latent = latent_model_input
        noise_mask = None

        if batch.dataset_config.do_i2v:
            with torch.no_grad():
                frames = batch.tensor
                if len(frames.shape) == 4:
                    first_frames = frames
                elif len(frames.shape) == 5:
                    first_frames = frames[:, 0]
                    # add conditioning using the standalone function
                    conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
                        latent_model_input=latent_model_input.to(
                            self.device_torch, self.torch_dtype
                        ),
                        first_frame=first_frames.to(self.device_torch, self.torch_dtype),
                        vae=self.vae,
                    )
                else:
                    raise ValueError(f"Unknown frame shape {frames.shape}")

        # make the noise mask
        if noise_mask is None:
            noise_mask = torch.ones(
                conditioned_latent.shape,
                dtype=conditioned_latent.dtype,
                device=conditioned_latent.device,
            )
        # TODO: write this better
        t_chunks = torch.chunk(timestep, timestep.shape[0])
        out_t_chunks = []
        for t in t_chunks:
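            # one timestep per transformer token: the mask is sampled at the
            # 2x2 patch stride, so conditioned positions (mask 0) get t=0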
            # seq_len: num_latent_frames * latent_height//2 * latent_width//2
            temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
            # batch_size, seq_len
            temp_ts = temp_ts.unsqueeze(0)
            out_t_chunks.append(temp_ts)
        timestep = torch.cat(out_t_chunks, dim=0)

        noise_pred = self.model(
            hidden_states=conditioned_latent,
            timestep=timestep,
            encoder_hidden_states=text_embeddings.text_embeds,
            return_dict=False,
            **kwargs,
        )[0]
        return noise_pred