mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Support img 2 vid training for ltx-2
This commit is contained in:
@@ -20,9 +20,10 @@ from optimum.quanto import freeze
|
||||
from toolkit.util.quantize import quantize, get_qtype, quantize_model
|
||||
from toolkit.memory_management import MemoryManager
|
||||
from safetensors.torch import load_file
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
from diffusers import LTX2Pipeline
|
||||
from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline
|
||||
from diffusers.models.autoencoders import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
@@ -518,9 +519,6 @@ class LTX2Model(BaseModel):
|
||||
pipeline.transformer = unwrap_model(self.model)
|
||||
pipeline.text_encoder = unwrap_model(self.text_encoder[0])
|
||||
|
||||
# if self.low_vram:
|
||||
# pipeline.enable_model_cpu_offload(device=self.device_torch)
|
||||
|
||||
pipeline = pipeline.to(self.device_torch)
|
||||
|
||||
return pipeline
|
||||
@@ -536,6 +534,20 @@ class LTX2Model(BaseModel):
|
||||
):
|
||||
if self.model.device == torch.device("cpu"):
|
||||
self.model.to(self.device_torch)
|
||||
|
||||
# handle control image
|
||||
if gen_config.ctrl_img is not None:
|
||||
# switch to image to video pipeline
|
||||
pipeline = LTX2ImageToVideoPipeline(
|
||||
scheduler=pipeline.scheduler,
|
||||
vae=pipeline.vae,
|
||||
audio_vae=pipeline.audio_vae,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
connectors=pipeline.connectors,
|
||||
transformer=pipeline.transformer,
|
||||
vocoder=pipeline.vocoder,
|
||||
)
|
||||
|
||||
is_video = gen_config.num_frames > 1
|
||||
# override the generate single image to handle video + audio generation
|
||||
@@ -554,6 +566,14 @@ class LTX2Model(BaseModel):
|
||||
bd = self.get_bucket_divisibility()
|
||||
gen_config.height = (gen_config.height // bd) * bd
|
||||
gen_config.width = (gen_config.width // bd) * bd
|
||||
|
||||
# handle control image
|
||||
if gen_config.ctrl_img is not None:
|
||||
control_img = Image.open(gen_config.ctrl_img).convert("RGB")
|
||||
# resize the control image
|
||||
control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
|
||||
# add the control image to the extra dict
|
||||
extra["image"] = control_img
|
||||
|
||||
# frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc.
|
||||
if gen_config.num_frames != 1:
|
||||
@@ -681,10 +701,40 @@ class LTX2Model(BaseModel):
|
||||
with torch.no_grad():
|
||||
if self.model.device == torch.device("cpu"):
|
||||
self.model.to(self.device_torch)
|
||||
|
||||
|
||||
batch_size, C, latent_num_frames, latent_height, latent_width = (
|
||||
latent_model_input.shape
|
||||
)
|
||||
|
||||
video_timestep = timestep.clone()
|
||||
|
||||
# i2v from first frame
|
||||
if batch.dataset_config.do_i2v:
|
||||
# videos come in (bs, num_frames, channels, height, width)
|
||||
# images come in (bs, channels, height, width)
|
||||
frames = batch.tensor
|
||||
if len(frames.shape) == 4:
|
||||
first_frames = frames
|
||||
elif len(frames.shape) == 5:
|
||||
first_frames = frames[:, 0]
|
||||
else:
|
||||
raise ValueError(f"Unknown frame shape {frames.shape}")
|
||||
# first frame doesnt have time dim, add it back
|
||||
init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
|
||||
init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
|
||||
mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
|
||||
# First condition is image latents and those should be kept clean.
|
||||
conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
|
||||
# use conditioning mask to replace latents
|
||||
latent_model_input = (
|
||||
latent_model_input * (1 - conditioning_mask)
|
||||
+ init_latents * conditioning_mask
|
||||
)
|
||||
|
||||
# set video timestep
|
||||
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
|
||||
|
||||
# todo get this somehow
|
||||
frame_rate = 24
|
||||
@@ -764,7 +814,8 @@ class LTX2Model(BaseModel):
|
||||
audio_hidden_states=audio_latents.to(self.transformer.dtype),
|
||||
encoder_hidden_states=connector_prompt_embeds,
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
|
||||
Reference in New Issue
Block a user