Support img 2 vid training for ltx-2

This commit is contained in:
Jaret Burkett
2026-01-13 19:04:56 -07:00
parent 5b5aadadb8
commit 64fe29b182
3 changed files with 59 additions and 8 deletions

View File

@@ -20,9 +20,10 @@ from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype, quantize_model from toolkit.util.quantize import quantize, get_qtype, quantize_model
from toolkit.memory_management import MemoryManager from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file from safetensors.torch import load_file
from PIL import Image
try: try:
from diffusers import LTX2Pipeline from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline
from diffusers.models.autoencoders import ( from diffusers.models.autoencoders import (
AutoencoderKLLTX2Audio, AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video, AutoencoderKLLTX2Video,
@@ -518,9 +519,6 @@ class LTX2Model(BaseModel):
pipeline.transformer = unwrap_model(self.model) pipeline.transformer = unwrap_model(self.model)
pipeline.text_encoder = unwrap_model(self.text_encoder[0]) pipeline.text_encoder = unwrap_model(self.text_encoder[0])
# if self.low_vram:
# pipeline.enable_model_cpu_offload(device=self.device_torch)
pipeline = pipeline.to(self.device_torch) pipeline = pipeline.to(self.device_torch)
return pipeline return pipeline
@@ -536,6 +534,20 @@ class LTX2Model(BaseModel):
): ):
if self.model.device == torch.device("cpu"): if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch) self.model.to(self.device_torch)
# handle control image
if gen_config.ctrl_img is not None:
# switch to image to video pipeline
pipeline = LTX2ImageToVideoPipeline(
scheduler=pipeline.scheduler,
vae=pipeline.vae,
audio_vae=pipeline.audio_vae,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
connectors=pipeline.connectors,
transformer=pipeline.transformer,
vocoder=pipeline.vocoder,
)
is_video = gen_config.num_frames > 1 is_video = gen_config.num_frames > 1
# override the generate single image to handle video + audio generation # override the generate single image to handle video + audio generation
@@ -554,6 +566,14 @@ class LTX2Model(BaseModel):
bd = self.get_bucket_divisibility() bd = self.get_bucket_divisibility()
gen_config.height = (gen_config.height // bd) * bd gen_config.height = (gen_config.height // bd) * bd
gen_config.width = (gen_config.width // bd) * bd gen_config.width = (gen_config.width // bd) * bd
# handle control image
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img).convert("RGB")
# resize the control image
control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
# add the control image to the extra dict
extra["image"] = control_img
# frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc. # frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc.
if gen_config.num_frames != 1: if gen_config.num_frames != 1:
@@ -681,10 +701,40 @@ class LTX2Model(BaseModel):
with torch.no_grad(): with torch.no_grad():
if self.model.device == torch.device("cpu"): if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch) self.model.to(self.device_torch)
batch_size, C, latent_num_frames, latent_height, latent_width = ( batch_size, C, latent_num_frames, latent_height, latent_width = (
latent_model_input.shape latent_model_input.shape
) )
video_timestep = timestep.clone()
# i2v from first frame
if batch.dataset_config.do_i2v:
# videos come in (bs, num_frames, channels, height, width)
# images come in (bs, channels, height, width)
frames = batch.tensor
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
# first frame doesnt have time dim, add it back
init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
# First condition is image latents and those should be kept clean.
conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
conditioning_mask[:, :, 0] = 1.0
# use conditioning mask to replace latents
latent_model_input = (
latent_model_input * (1 - conditioning_mask)
+ init_latents * conditioning_mask
)
# set video timestep
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
# todo get this somehow # todo get this somehow
frame_rate = 24 frame_rate = 24
@@ -764,7 +814,8 @@ class LTX2Model(BaseModel):
audio_hidden_states=audio_latents.to(self.transformer.dtype), audio_hidden_states=audio_latents.to(self.transformer.dtype),
encoder_hidden_states=connector_prompt_embeds, encoder_hidden_states=connector_prompt_embeds,
audio_encoder_hidden_states=connector_audio_prompt_embeds, audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=timestep, timestep=video_timestep,
audio_timestep=timestep,
encoder_attention_mask=connector_attention_mask, encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames, num_frames=latent_num_frames,

View File

@@ -628,7 +628,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].datasets[x].fps': [24, undefined], 'config.process[0].datasets[x].fps': [24, undefined],
}, },
disableSections: ['network.conv'], disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch'], additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'],
}, },
].sort((a, b) => { ].sort((a, b) => {
// Sort by label, case-insensitive // Sort by label, case-insensitive

View File

@@ -1 +1 @@
VERSION = "0.7.17" VERSION = "0.7.18"