Support img 2 vid training for ltx-2
@@ -20,9 +20,10 @@ from optimum.quanto import freeze
 from toolkit.util.quantize import quantize, get_qtype, quantize_model
 from toolkit.memory_management import MemoryManager
 from safetensors.torch import load_file
+from PIL import Image
 
 try:
-    from diffusers import LTX2Pipeline
+    from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline
     from diffusers.models.autoencoders import (
         AutoencoderKLLTX2Audio,
         AutoencoderKLLTX2Video,
@@ -518,9 +519,6 @@ class LTX2Model(BaseModel):
         pipeline.transformer = unwrap_model(self.model)
         pipeline.text_encoder = unwrap_model(self.text_encoder[0])
 
-        # if self.low_vram:
-        #     pipeline.enable_model_cpu_offload(device=self.device_torch)
-
         pipeline = pipeline.to(self.device_torch)
 
         return pipeline
@@ -537,6 +535,20 @@ class LTX2Model(BaseModel):
         if self.model.device == torch.device("cpu"):
             self.model.to(self.device_torch)
 
+        # handle control image
+        if gen_config.ctrl_img is not None:
+            # switch to image to video pipeline
+            pipeline = LTX2ImageToVideoPipeline(
+                scheduler=pipeline.scheduler,
+                vae=pipeline.vae,
+                audio_vae=pipeline.audio_vae,
+                text_encoder=pipeline.text_encoder,
+                tokenizer=pipeline.tokenizer,
+                connectors=pipeline.connectors,
+                transformer=pipeline.transformer,
+                vocoder=pipeline.vocoder,
+            )
+
         is_video = gen_config.num_frames > 1
         # override the generate single image to handle video + audio generation
         if is_video:
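
When a control image is supplied at sample time, the already-loaded components are rewrapped into LTX2ImageToVideoPipeline instead of the text-to-video pipeline. A minimal sketch of how the swapped pipeline might then be invoked; it assumes the pipeline call accepts an `image` keyword (as suggested by the `extra["image"]` handoff in the next hunk), and the prompt, size, and frame count are placeholders:

from PIL import Image

# sketch only: argument names besides `image` follow common diffusers conventions
# and may differ from the actual LTX2ImageToVideoPipeline signature
first_frame = Image.open("first_frame.png").convert("RGB")
result = pipeline(
    prompt="a dog running on the beach",
    image=first_frame,      # first-frame conditioning
    num_frames=49,          # valid counts satisfy (n - 1) % 8 == 0
    height=512,
    width=768,
)
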
@@ -555,6 +567,14 @@ class LTX2Model(BaseModel):
         gen_config.height = (gen_config.height // bd) * bd
         gen_config.width = (gen_config.width // bd) * bd
 
+        # handle control image
+        if gen_config.ctrl_img is not None:
+            control_img = Image.open(gen_config.ctrl_img).convert("RGB")
+            # resize the control image
+            control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
+            # add the control image to the extra dict
+            extra["image"] = control_img
+
         # frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc.
         if gen_config.num_frames != 1:
             if (gen_config.num_frames - 1) % 8 != 0:
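
The comment above captures the LTX-2 frame-count constraint: valid values are of the form 8k + 1 (1, 9, 17, 25, ...). A hypothetical helper, not part of the commit, that rounds an arbitrary request down to the nearest valid count:

def round_to_valid_frame_count(num_frames: int) -> int:
    # valid frame counts are 1, 9, 17, 25, ... i.e. (n - 1) divisible by 8
    if num_frames <= 1:
        return 1
    return ((num_frames - 1) // 8) * 8 + 1

assert round_to_valid_frame_count(50) == 49
assert round_to_valid_frame_count(9) == 9
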
@@ -686,6 +706,36 @@ class LTX2Model(BaseModel):
             latent_model_input.shape
         )
 
+        video_timestep = timestep.clone()
+
+        # i2v from first frame
+        if batch.dataset_config.do_i2v:
+            # videos come in (bs, num_frames, channels, height, width)
+            # images come in (bs, channels, height, width)
+            frames = batch.tensor
+            if len(frames.shape) == 4:
+                first_frames = frames
+            elif len(frames.shape) == 5:
+                first_frames = frames[:, 0]
+            else:
+                raise ValueError(f"Unknown frame shape {frames.shape}")
+            # first frame doesnt have time dim, add it back
+            init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
+            init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
+            mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
+            # First condition is image latents and those should be kept clean.
+            conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
+            conditioning_mask[:, :, 0] = 1.0
+
+            # use conditioning mask to replace latents
+            latent_model_input = (
+                latent_model_input * (1 - conditioning_mask)
+                + init_latents * conditioning_mask
+            )
+
+            # set video timestep
+            video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
+
         # todo get this somehow
         frame_rate = 24
         # check frame dimension
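
The i2v branch above keeps the first latent frame clean: the encoded first frame is tiled across the time axis, a mask marks frame 0, the noisy latents are overwritten there, and that frame's effective timestep is zeroed so the model treats it as already denoised. A self-contained sketch of the mask arithmetic on dummy tensors (shapes mirror the diff; the values and the 5-D timestep layout are illustrative, since the real timestep tensor may be shaped differently before packing):

import torch

bs, ch, t, h, w = 1, 4, 7, 8, 8                  # illustrative latent shape
noisy = torch.randn(bs, ch, t, h, w)             # plays the role of latent_model_input
first = torch.randn(bs, ch, 1, h, w)             # encoded first frame
init = first.repeat(1, 1, t, 1, 1)               # tile across time

mask = torch.zeros(bs, 1, t, h, w)
mask[:, :, 0] = 1.0                              # frame 0 is the clean condition

mixed = noisy * (1 - mask) + init * mask         # replace only frame 0
timestep = torch.full((bs, 1, 1, 1, 1), 500.0)
video_timestep = timestep * (1 - mask)           # 0 where the latent is clean

assert torch.allclose(mixed[:, :, 0], init[:, :, 0])
assert (video_timestep[:, :, 0] == 0).all()
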
@@ -764,7 +814,8 @@ class LTX2Model(BaseModel):
             audio_hidden_states=audio_latents.to(self.transformer.dtype),
             encoder_hidden_states=connector_prompt_embeds,
             audio_encoder_hidden_states=connector_audio_prompt_embeds,
-            timestep=timestep,
+            timestep=video_timestep,
+            audio_timestep=timestep,
             encoder_attention_mask=connector_attention_mask,
             audio_encoder_attention_mask=connector_attention_mask,
             num_frames=latent_num_frames,
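
With this change the video stream receives the masked video_timestep, so the clean first frame acts as conditioning rather than a denoising target, while the audio stream keeps the unmodified timestep.
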
@@ -628,7 +628,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].datasets[x].fps': [24, undefined],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch'],
+    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'],
   },
 ].sort((a, b) => {
   // Sort by label, case-insensitive
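
On the UI side, the LTX-2 arch entry now also exposes the sample.ctrl_img and datasets.do_i2v sections, so the control-image sample option and the i2v dataset flag added above can be toggled from the trainer config.
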
@@ -1 +1 @@
-VERSION = "0.7.17"
+VERSION = "0.7.18"