diff --git a/extensions_built_in/diffusion_models/ltx2/ltx2.py b/extensions_built_in/diffusion_models/ltx2/ltx2.py
index ddd9568d..37f50567 100644
--- a/extensions_built_in/diffusion_models/ltx2/ltx2.py
+++ b/extensions_built_in/diffusion_models/ltx2/ltx2.py
@@ -20,9 +20,10 @@ from optimum.quanto import freeze
 from toolkit.util.quantize import quantize, get_qtype, quantize_model
 from toolkit.memory_management import MemoryManager
 from safetensors.torch import load_file
+from PIL import Image
 
 try:
-    from diffusers import LTX2Pipeline
+    from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline
     from diffusers.models.autoencoders import (
         AutoencoderKLLTX2Audio,
         AutoencoderKLLTX2Video,
@@ -518,9 +519,6 @@ class LTX2Model(BaseModel):
         pipeline.transformer = unwrap_model(self.model)
         pipeline.text_encoder = unwrap_model(self.text_encoder[0])
 
-        # if self.low_vram:
-        #     pipeline.enable_model_cpu_offload(device=self.device_torch)
-
         pipeline = pipeline.to(self.device_torch)
 
         return pipeline
@@ -536,6 +534,20 @@
     ):
         if self.model.device == torch.device("cpu"):
             self.model.to(self.device_torch)
+
+        # handle control image
+        if gen_config.ctrl_img is not None:
+            # switch to the image-to-video pipeline
+            pipeline = LTX2ImageToVideoPipeline(
+                scheduler=pipeline.scheduler,
+                vae=pipeline.vae,
+                audio_vae=pipeline.audio_vae,
+                text_encoder=pipeline.text_encoder,
+                tokenizer=pipeline.tokenizer,
+                connectors=pipeline.connectors,
+                transformer=pipeline.transformer,
+                vocoder=pipeline.vocoder,
+            )
 
         is_video = gen_config.num_frames > 1
         # override the generate single image to handle video + audio generation
@@ -554,6 +566,14 @@
         bd = self.get_bucket_divisibility()
         gen_config.height = (gen_config.height // bd) * bd
         gen_config.width = (gen_config.width // bd) * bd
+
+        # handle control image
+        if gen_config.ctrl_img is not None:
+            control_img = Image.open(gen_config.ctrl_img).convert("RGB")
+            # resize the control image
+            control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
+            # add the control image to the extra dict
+            extra["image"] = control_img
 
         # frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc.
         if gen_config.num_frames != 1:
@@ -681,10 +701,41 @@
         with torch.no_grad():
             if self.model.device == torch.device("cpu"):
                 self.model.to(self.device_torch)
-
+
         batch_size, C, latent_num_frames, latent_height, latent_width = (
             latent_model_input.shape
         )
+
+        video_timestep = timestep.clone()
+
+        # i2v from first frame
+        if batch.dataset_config.do_i2v:
+            # videos come in (bs, num_frames, channels, height, width)
+            # images come in (bs, channels, height, width)
+            frames = batch.tensor
+            if len(frames.shape) == 4:
+                first_frames = frames
+            elif len(frames.shape) == 5:
+                first_frames = frames[:, 0]
+            else:
+                raise ValueError(f"Unknown frame shape {frames.shape}")
+            # first frame doesn't have a time dim, add it back
+            init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
+            init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
+            mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
+            # The first condition is the image latents and those should be kept clean.
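+            # 1.0 = keep these latents clean (the image condition frame), 0.0 = latents to denoise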
+            conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
+            conditioning_mask[:, :, 0] = 1.0
+
+            # use the conditioning mask to replace the first-frame latents with the clean image latents
+            latent_model_input = (
+                latent_model_input * (1 - conditioning_mask)
+                + init_latents * conditioning_mask
+            )
+
+            # zero the video timestep where latents are clean
+            video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
 
         # todo get this somehow
         frame_rate = 24
@@ -764,7 +815,8 @@
             audio_hidden_states=audio_latents.to(self.transformer.dtype),
             encoder_hidden_states=connector_prompt_embeds,
             audio_encoder_hidden_states=connector_audio_prompt_embeds,
-            timestep=timestep,
+            timestep=video_timestep,
+            audio_timestep=timestep,
             encoder_attention_mask=connector_attention_mask,
             audio_encoder_attention_mask=connector_attention_mask,
             num_frames=latent_num_frames,
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index 992d677b..d10875fe 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -628,7 +628,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].datasets[x].fps': [24, undefined],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch'],
+    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'],
   },
 ].sort((a, b) => {
   // Sort by label, case-insensitive
diff --git a/version.py b/version.py
index 3db333f0..63d7599f 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.7.17"
+VERSION = "0.7.18"