mirror of https://github.com/ostris/ai-toolkit.git
Support img 2 vid training for ltx-2
@@ -20,9 +20,10 @@ from optimum.quanto import freeze
 from toolkit.util.quantize import quantize, get_qtype, quantize_model
 from toolkit.memory_management import MemoryManager
 from safetensors.torch import load_file
+from PIL import Image
 
 try:
-    from diffusers import LTX2Pipeline
+    from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline
     from diffusers.models.autoencoders import (
         AutoencoderKLLTX2Audio,
         AutoencoderKLLTX2Video,
@@ -518,9 +519,6 @@ class LTX2Model(BaseModel):
         pipeline.transformer = unwrap_model(self.model)
         pipeline.text_encoder = unwrap_model(self.text_encoder[0])
 
-        # if self.low_vram:
-        #     pipeline.enable_model_cpu_offload(device=self.device_torch)
-
         pipeline = pipeline.to(self.device_torch)
 
         return pipeline
@@ -536,6 +534,20 @@ class LTX2Model(BaseModel):
     ):
         if self.model.device == torch.device("cpu"):
             self.model.to(self.device_torch)
 
+        # handle control image
+        if gen_config.ctrl_img is not None:
+            # switch to image to video pipeline
+            pipeline = LTX2ImageToVideoPipeline(
+                scheduler=pipeline.scheduler,
+                vae=pipeline.vae,
+                audio_vae=pipeline.audio_vae,
+                text_encoder=pipeline.text_encoder,
+                tokenizer=pipeline.tokenizer,
+                connectors=pipeline.connectors,
+                transformer=pipeline.transformer,
+                vocoder=pipeline.vocoder,
+            )
+
         is_video = gen_config.num_frames > 1
         # override the generate single image to handle video + audio generation
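For clarity, the image-to-video pipeline above is assembled from the modules of the already-loaded text-to-video pipeline, so the trained transformer and the rest of the stack are shared rather than reloaded. A minimal sketch of that idea, assuming `t2v` is a loaded diffusers LTX2Pipeline (the helper name is illustrative, not part of the commit):

import torch
from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline

def as_image_to_video(t2v: LTX2Pipeline) -> LTX2ImageToVideoPipeline:
    # Reuse the exact module objects from the text-to-video pipeline;
    # nothing is reloaded from disk and no extra VRAM is consumed.
    return LTX2ImageToVideoPipeline(
        scheduler=t2v.scheduler,
        vae=t2v.vae,
        audio_vae=t2v.audio_vae,
        text_encoder=t2v.text_encoder,
        tokenizer=t2v.tokenizer,
        connectors=t2v.connectors,
        transformer=t2v.transformer,
        vocoder=t2v.vocoder,
    )

# i2v = as_image_to_video(t2v)
# assert i2v.transformer is t2v.transformer  # same object, so the trained weights carry over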
@@ -554,6 +566,14 @@ class LTX2Model(BaseModel):
         bd = self.get_bucket_divisibility()
         gen_config.height = (gen_config.height // bd) * bd
         gen_config.width = (gen_config.width // bd) * bd
 
+        # handle control image
+        if gen_config.ctrl_img is not None:
+            control_img = Image.open(gen_config.ctrl_img).convert("RGB")
+            # resize the control image
+            control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
+            # add the control image to the extra dict
+            extra["image"] = control_img
+
         # frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc.
         if gen_config.num_frames != 1:
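The "8n + 1" frame-count rule noted in the comment above can be made concrete with a small helper. This is an illustrative sketch only, not the repository's rounding code (the actual adjustment lives in the unchanged lines below the hunk), and `snap_num_frames` is a hypothetical name:

def snap_num_frames(requested: int) -> int:
    # Snap a requested frame count onto the 8 * k + 1 grid: 1, 9, 17, 25, ...
    if requested <= 1:
        return 1
    return ((requested - 1 + 4) // 8) * 8 + 1  # round (requested - 1) to the nearest multiple of 8

# snap_num_frames(1) -> 1, snap_num_frames(16) -> 17, snap_num_frames(24) -> 25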
@@ -681,10 +701,40 @@ class LTX2Model(BaseModel):
         with torch.no_grad():
             if self.model.device == torch.device("cpu"):
                 self.model.to(self.device_torch)
 
         batch_size, C, latent_num_frames, latent_height, latent_width = (
             latent_model_input.shape
         )
 
+        video_timestep = timestep.clone()
+
+        # i2v from first frame
+        if batch.dataset_config.do_i2v:
+            # videos come in (bs, num_frames, channels, height, width)
+            # images come in (bs, channels, height, width)
+            frames = batch.tensor
+            if len(frames.shape) == 4:
+                first_frames = frames
+            elif len(frames.shape) == 5:
+                first_frames = frames[:, 0]
+            else:
+                raise ValueError(f"Unknown frame shape {frames.shape}")
+            # first frame doesnt have time dim, add it back
+            init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
+            init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
+            mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
+            # First condition is image latents and those should be kept clean.
+            conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
+            conditioning_mask[:, :, 0] = 1.0
+
+            # use conditioning mask to replace latents
+            latent_model_input = (
+                latent_model_input * (1 - conditioning_mask)
+                + init_latents * conditioning_mask
+            )
+
+            # set video timestep
+            video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
+
         # todo get this somehow
         frame_rate = 24
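A minimal, CPU-only sketch of the first-frame conditioning added above, using toy shapes (illustrative, not part of the commit): the mask is 1.0 only on the first latent frame, so that frame is replaced by the clean image latents while the remaining frames keep their noised values.

import torch

batch_size, C, latent_num_frames, latent_height, latent_width = 1, 4, 3, 2, 2

noisy_latents = torch.randn(batch_size, C, latent_num_frames, latent_height, latent_width)
# stand-in for encode_images(first_frames) repeated across the frame axis
init_latents = torch.ones(batch_size, C, 1, latent_height, latent_width).repeat(1, 1, latent_num_frames, 1, 1)

conditioning_mask = torch.zeros(batch_size, 1, latent_num_frames, latent_height, latent_width)
conditioning_mask[:, :, 0] = 1.0

blended = noisy_latents * (1 - conditioning_mask) + init_latents * conditioning_mask

assert torch.equal(blended[:, :, 0], init_latents[:, :, 0])     # first frame is the clean image latent
assert torch.equal(blended[:, :, 1:], noisy_latents[:, :, 1:])  # later frames are left noised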
@@ -764,7 +814,8 @@ class LTX2Model(BaseModel):
             audio_hidden_states=audio_latents.to(self.transformer.dtype),
             encoder_hidden_states=connector_prompt_embeds,
             audio_encoder_hidden_states=connector_audio_prompt_embeds,
-            timestep=timestep,
+            timestep=video_timestep,
+            audio_timestep=timestep,
             encoder_attention_mask=connector_attention_mask,
             audio_encoder_attention_mask=connector_attention_mask,
             num_frames=latent_num_frames,
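The timestep split above means the video stream sees a per-frame timestep that is zeroed wherever the conditioning mask is 1 (the clean first frame), while the audio stream keeps the original timestep. A toy illustration with simplified shapes (not the repository's code):

import torch

timestep = torch.tensor([800.0])                       # (batch,)
conditioning_mask = torch.tensor([[[1.0, 0.0, 0.0]]])  # toy (batch, 1, frames) mask, first frame conditioned

video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
audio_timestep = timestep  # audio branch is denoised on the unmodified schedule

print(video_timestep)  # tensor([[[0., 800., 800.]]]) -> the conditioned frame is treated as already clean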
@@ -628,7 +628,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].datasets[x].fps': [24, undefined],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch'],
+    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'],
   },
 ].sort((a, b) => {
   // Sort by label, case-insensitive
@@ -1 +1 @@
-VERSION = "0.7.17"
+VERSION = "0.7.18"