diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index cc08cdf6..682c87f7 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -618,277 +618,6 @@ class SDTrainer(BaseSDTrainProcess): return loss - def get_guided_loss_targeted_polarity( - self, - noisy_latents: torch.Tensor, - conditional_embeds: PromptEmbeds, - match_adapter_assist: bool, - network_weight_list: list, - timesteps: torch.Tensor, - pred_kwargs: dict, - batch: 'DataLoaderBatchDTO', - noise: torch.Tensor, - **kwargs - ): - with torch.no_grad(): - # Perform targeted guidance (working title) - dtype = get_torch_dtype(self.train_config.dtype) - - conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach() - unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach() - - mean_latents = (conditional_latents + unconditional_latents) / 2.0 - - unconditional_diff = (unconditional_latents - mean_latents) - conditional_diff = (conditional_latents - mean_latents) - - # we need to determine the amount of signal and noise that would be present at the current timestep - # conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps) - # unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps) - # unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps) - # conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps) - # unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps) - - # target_noise = noise + unconditional_signal - - conditional_noisy_latents = self.sd.add_noise( - mean_latents, - noise, - timesteps - ).detach() - - unconditional_noisy_latents = self.sd.add_noise( - mean_latents, - noise, - timesteps - ).detach() - - # Disable the LoRA network so we can predict parent network knowledge without it - self.network.is_active = False - self.sd.unet.eval() - - # Predict noise to get a baseline of what the parent network wants to do with the latents + noise. - # This acts as our control to preserve the unaltered parts of the image. - baseline_prediction = self.sd.predict_noise( - latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), - timestep=timesteps, - guidance_scale=1.0, - **pred_kwargs # adapter residuals in here - ).detach() - - # double up everything to run it through all at once - cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) - cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0) - cat_timesteps = torch.cat([timesteps, timesteps], dim=0) - - # since we are dividing the polarity from the middle out, we need to double our network - # weights on training since the convergent point will be at half network strength - - negative_network_weights = [weight * -2.0 for weight in network_weight_list] - positive_network_weights = [weight * 2.0 for weight in network_weight_list] - cat_network_weight_list = positive_network_weights + negative_network_weights - - # turn the LoRA network back on. 
- self.sd.unet.train() - self.network.is_active = True - - self.network.multiplier = cat_network_weight_list - - # do our prediction with LoRA active on the scaled guidance latents - prediction = self.sd.predict_noise( - latents=cat_latents.to(self.device_torch, dtype=dtype).detach(), - conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(), - timestep=cat_timesteps, - guidance_scale=1.0, - **pred_kwargs # adapter residuals in here - ) - - pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) - - pred_pos = pred_pos - baseline_prediction - pred_neg = pred_neg - baseline_prediction - - pred_loss = torch.nn.functional.mse_loss( - pred_pos.float(), - unconditional_diff.float(), - reduction="none" - ) - pred_loss = pred_loss.mean([1, 2, 3]) - - pred_neg_loss = torch.nn.functional.mse_loss( - pred_neg.float(), - conditional_diff.float(), - reduction="none" - ) - pred_neg_loss = pred_neg_loss.mean([1, 2, 3]) - - loss = (pred_loss + pred_neg_loss) / 2.0 - - # loss = self.apply_snr(loss, timesteps) - loss = loss.mean() - self.accelerator.backward(loss) - - # detach it so parent class can run backward on no grads without throwing error - loss = loss.detach() - loss.requires_grad_(True) - - return loss - - def get_guided_loss_masked_polarity( - self, - noisy_latents: torch.Tensor, - conditional_embeds: PromptEmbeds, - match_adapter_assist: bool, - network_weight_list: list, - timesteps: torch.Tensor, - pred_kwargs: dict, - batch: 'DataLoaderBatchDTO', - noise: torch.Tensor, - **kwargs - ): - with torch.no_grad(): - # Perform targeted guidance (working title) - dtype = get_torch_dtype(self.train_config.dtype) - - conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach() - unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach() - inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents) - - mean_latents = (conditional_latents + unconditional_latents) / 2.0 - - # unconditional_diff = (unconditional_latents - mean_latents) - # conditional_diff = (conditional_latents - mean_latents) - - # we need to determine the amount of signal and noise that would be present at the current timestep - # conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps) - # unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps) - # unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps) - # conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps) - # unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps) - - # make a differential mask - differential_mask = torch.abs(conditional_latents - unconditional_latents) - max_differential = \ - differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] - differential_scaler = 1.0 / max_differential - differential_mask = differential_mask * differential_scaler - spread_point = 0.1 - # adjust mask to amplify the differential at 0.1 - differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point - # clip it - differential_mask = torch.clamp(differential_mask, 0.0, 1.0) - - # target_noise = noise + unconditional_signal - - conditional_noisy_latents = self.sd.add_noise( - conditional_latents, - noise, - timesteps - ).detach() - - unconditional_noisy_latents = self.sd.add_noise( - unconditional_latents, - noise, - timesteps - 
).detach() - - inverse_noisy_latents = self.sd.add_noise( - inverse_latents, - noise, - timesteps - ).detach() - - # Disable the LoRA network so we can predict parent network knowledge without it - self.network.is_active = False - self.sd.unet.eval() - - # Predict noise to get a baseline of what the parent network wants to do with the latents + noise. - # This acts as our control to preserve the unaltered parts of the image. - # baseline_prediction = self.sd.predict_noise( - # latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), - # conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), - # timestep=timesteps, - # guidance_scale=1.0, - # **pred_kwargs # adapter residuals in here - # ).detach() - - # double up everything to run it through all at once - cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) - cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) - cat_timesteps = torch.cat([timesteps, timesteps], dim=0) - - # since we are dividing the polarity from the middle out, we need to double our network - # weights on training since the convergent point will be at half network strength - - negative_network_weights = [weight * -1.0 for weight in network_weight_list] - positive_network_weights = [weight * 1.0 for weight in network_weight_list] - cat_network_weight_list = positive_network_weights + negative_network_weights - - # turn the LoRA network back on. - self.sd.unet.train() - self.network.is_active = True - - self.network.multiplier = cat_network_weight_list - - # do our prediction with LoRA active on the scaled guidance latents - prediction = self.sd.predict_noise( - latents=cat_latents.to(self.device_torch, dtype=dtype).detach(), - conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(), - timestep=cat_timesteps, - guidance_scale=1.0, - **pred_kwargs # adapter residuals in here - ) - - pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) - - # create a loss to balance the mean to 0 between the two predictions - differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0 - - # pred_pos = pred_pos - baseline_prediction - # pred_neg = pred_neg - baseline_prediction - - pred_loss = torch.nn.functional.mse_loss( - pred_pos.float(), - noise.float(), - reduction="none" - ) - # apply mask - pred_loss = pred_loss * (1.0 + differential_mask) - pred_loss = pred_loss.mean([1, 2, 3]) - - pred_neg_loss = torch.nn.functional.mse_loss( - pred_neg.float(), - noise.float(), - reduction="none" - ) - # apply inverse mask - pred_neg_loss = pred_neg_loss * (1.0 - differential_mask) - pred_neg_loss = pred_neg_loss.mean([1, 2, 3]) - - # make a loss to balance to losses of the pos and neg so they are equal - # differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss) - # - # differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss - # - # # add a multiplier to balancing losses to make them the top priority - # differential_mean_loss = differential_mean_loss - - # remove the grads from the negative as it is only a balancing loss - # pred_neg_loss = pred_neg_loss.detach() - - # loss = pred_loss + pred_neg_loss + differential_mean_loss - loss = pred_loss + pred_neg_loss - - # loss = self.apply_snr(loss, timesteps) - loss = loss.mean() - self.accelerator.backward(loss) - - # detach it so parent class can run backward on no grads without throwing error - loss = loss.detach() - loss.requires_grad_(True) - - 
return loss def get_prior_prediction( self, @@ -985,6 +714,7 @@ class SDTrainer(BaseSDTrainProcess): timestep=timesteps, guidance_scale=self.train_config.cfg_scale, rescale_cfg=self.train_config.cfg_rescale, + batch=batch, **pred_kwargs # adapter residuals in here ) if was_unet_training: @@ -1021,6 +751,7 @@ class SDTrainer(BaseSDTrainProcess): timesteps: Union[int, torch.Tensor] = 1, conditional_embeds: Union[PromptEmbeds, None] = None, unconditional_embeds: Union[PromptEmbeds, None] = None, + batch: Optional['DataLoaderBatchDTO'] = None, **kwargs, ): dtype = get_torch_dtype(self.train_config.dtype) @@ -1034,6 +765,7 @@ class SDTrainer(BaseSDTrainProcess): detach_unconditional=False, rescale_cfg=self.train_config.cfg_rescale, bypass_guidance_embedding=self.train_config.bypass_guidance_embedding, + batch=batch, **kwargs ) @@ -1690,6 +1422,7 @@ class SDTrainer(BaseSDTrainProcess): timesteps=timesteps, conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), unconditional_embeds=unconditional_embeds, + batch=batch, **pred_kwargs ) self.after_unet_predict() @@ -1723,6 +1456,7 @@ class SDTrainer(BaseSDTrainProcess): timesteps=timesteps, conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype), unconditional_embeds=unconditional_embeds, + batch=batch, **pred_kwargs ) dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 3c7f57ed..51b60170 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1313,6 +1313,10 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.network_config is not None: adapter_name = f"{adapter_name}_{suffix}" latest_save_path = self.get_latest_save_path(adapter_name) + + if latest_save_path is not None and not self.adapter_config.train: + # the save path is for something else since we are not training + latest_save_path = self.adapter_config.name_or_path dtype = get_torch_dtype(self.train_config.dtype) if is_t2i: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index f76def92..6346d554 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -540,6 +540,11 @@ class ModelConfig: self.arch: ModelArch = kwargs.get("arch", None) + # can be used to load the extras like text encoder or vae from here + # only setup for some models but will prevent having to download the te for + # 20 different model variants + self.extras_name_or_path = kwargs.get("extras_name_or_path", self.name_or_path) + # kwargs to pass to the model self.model_kwargs = kwargs.get("model_kwargs", {}) diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index 6f8be862..d2e4572d 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -1161,8 +1161,9 @@ class CustomAdapter(torch.nn.Module): else: with torch.no_grad(): self.vision_encoder.eval() + self.vision_encoder.to(self.device) clip_output = self.vision_encoder( - clip_image, + clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)), output_hidden_states=True, ) if self.config.clip_layer == 'penultimate_hidden_states': diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 0a668960..7c8a294b 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -703,6 +703,7 @@ class BaseModel: return_conditional_pred=False, guidance_embedding_scale=1.0, bypass_guidance_embedding=False, + batch: Union[None, 'DataLoaderBatchDTO'] = 
None, **kwargs, ): conditional_pred = None @@ -821,6 +822,8 @@ class BaseModel: kwargs['guidance_embedding_scale'] = guidance_embedding_scale if 'bypass_guidance_embedding' in signatures: kwargs['bypass_guidance_embedding'] = bypass_guidance_embedding + if 'batch' in signatures: + kwargs['batch'] = batch noise_pred = self.get_noise_prediction( latent_model_input=latent_model_input, diff --git a/toolkit/models/wan21/__init__.py b/toolkit/models/wan21/__init__.py index 9e2aa3ca..8c2706a1 100644 --- a/toolkit/models/wan21/__init__.py +++ b/toolkit/models/wan21/__init__.py @@ -1 +1,2 @@ -from .wan21 import Wan21 \ No newline at end of file +from .wan21 import Wan21 +from .wan21_i2v import Wan21I2V \ No newline at end of file diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index 48992cfb..fb7edd66 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -330,23 +330,22 @@ class Wan21(BaseModel): def load_model(self): dtype = self.torch_dtype - # todo , will this work with other wan models? - base_model_path = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" model_path = self.model_config.name_or_path self.print_and_status_update("Loading Wan2.1 model") - # base_model_path = "black-forest-labs/FLUX.1-schnell" - base_model_path = self.model_config.name_or_path_original subfolder = 'transformer' transformer_path = model_path if os.path.exists(transformer_path): subfolder = None transformer_path = os.path.join(transformer_path, 'transformer') - # check if the path is a full checkpoint. - te_folder_path = os.path.join(model_path, 'text_encoder') - # if we have the te, this folder is a full checkpoint, use it as the base - if os.path.exists(te_folder_path): - base_model_path = model_path + + te_path = self.model_config.extras_name_or_path + if os.path.exists(os.path.join(model_path, 'text_encoder')): + te_path = model_path + + vae_path = self.model_config.extras_name_or_path + if os.path.exists(os.path.join(model_path, 'vae')): + vae_path = model_path self.print_and_status_update("Loading transformer") transformer = WanTransformer3DModel.from_pretrained( @@ -420,9 +419,9 @@ class Wan21(BaseModel): self.print_and_status_update("Loading UMT5EncoderModel") tokenizer = AutoTokenizer.from_pretrained( - base_model_path, subfolder="tokenizer", torch_dtype=dtype) + te_path, subfolder="tokenizer", torch_dtype=dtype) text_encoder = UMT5EncoderModel.from_pretrained( - base_model_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype) + te_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype) text_encoder.to(self.device_torch, dtype=dtype) flush() @@ -442,7 +441,7 @@ class Wan21(BaseModel): self.print_and_status_update("Loading VAE") # todo, example does float 32? 
check if quality suffers vae = AutoencoderKLWan.from_pretrained( - base_model_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype) + vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype) flush() self.print_and_status_update("Making pipe") diff --git a/toolkit/models/wan21/wan21_i2v.py b/toolkit/models/wan21/wan21_i2v.py new file mode 100644 index 00000000..ace2593e --- /dev/null +++ b/toolkit/models/wan21/wan21_i2v.py @@ -0,0 +1,540 @@ +# WIP, coming soon ish +from functools import partial +import torch +import yaml +from toolkit.accelerator import unwrap_model +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.prompt_utils import PromptEmbeds +from toolkit.paths import REPOS_ROOT +from transformers import AutoTokenizer, UMT5EncoderModel +from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel +import os +import sys + +import weakref +import torch +import yaml +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.prompt_utils import PromptEmbeds + +import os +import copy +from toolkit.config_modules import ModelConfig, GenerateImageConfig +import torch +from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler +from transformers import CLIPVisionModel, CLIPImageProcessor +import torch.nn.functional as F + +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.video_processor import VideoProcessor +from diffusers.image_processor import PipelineImageInput +from PIL import Image + +from .wan21 import \ + scheduler_configUniPC, \ + scheduler_config, \ + Wan21 + + +class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline): + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + device: torch.device = torch.device("cuda"), + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + self._exec_device = device + + @property + def _execution_device(self): + return self._exec_device + + @torch.no_grad() + def __call__( + self, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + if 
isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # unload vae and transformer + device = self.transformer.device + + self.text_encoder.to(device) + + self.vae.to('cpu') + self.image_encoder.to('cpu') + flush() + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # unload text encoder + print("Unloading text encoder") + self.text_encoder.to("cpu") + self.transformer.to(device) + flush() + + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + self.image_encoder.to(device) + self.vae.to(device) + image_embeds = self.encode_image(image) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.bfloat16, + device, + generator, + latents, + ) + self.image_encoder.to('cpu') + self.vae.to('cpu') + flush() + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + self.vae.to(device) + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) + + def encode_image(self, image: PipelineImageInput): + image = self.image_processor(images=image, return_tensors="pt") + image = {k: v.to(self.image_encoder.device, dtype=self.image_encoder.dtype) for k, v in image.items()} + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + +class Wan21I2V(Wan21): + arch = 'wan21_i2v' + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, model_config, dtype, + custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['WanTransformer3DModel'] + self.image_encoder: CLIPVisionModel = None + self.image_processor: CLIPImageProcessor = None + + def load_model(self): + # call the super class to load most of the model + super().load_model() + if 
self.model_config.low_vram: + # unload text encoder + self.text_encoder.to("cpu") + # all the base stuff is loaded. We now need to load the vision encoder stuff + dtype = self.torch_dtype + try: + self.image_processor = CLIPImageProcessor.from_pretrained( + self.model_config.extras_name_or_path , + subfolder="image_processor" + ) + self.image_encoder = CLIPVisionModel.from_pretrained( + self.model_config.extras_name_or_path, + subfolder="image_encoder", + torch_dtype=dtype, + ) + except Exception as e: + # load from name_or_path + self.image_processor = CLIPImageProcessor.from_pretrained( + self.model_config.name_or_path_original, + subfolder="image_processor" + ) + self.image_encoder = CLIPVisionModel.from_pretrained( + self.model_config.name_or_path_original, + subfolder="image_encoder", + torch_dtype=dtype, + ) + self.image_encoder.to(self.device_torch, dtype=dtype) + self.image_encoder.eval() + self.image_encoder.requires_grad_(False) + + if self.model_config.low_vram: + # unload image encoder + self.image_encoder.to("cpu") + + # rebuild the pipeline + self.pipeline = self.get_generation_pipeline() + flush() + + def generate_images( + self, + image_configs, + sampler=None, + pipeline=None, + ): + # will oom on 24gb vram if we dont unload vision encoder first + if self.model_config.low_vram: + # unload image encoder + self.image_encoder.to("cpu") + self.vae.to("cpu") + self.transformer.to("cpu") + flush() + super().generate_images( + image_configs, + sampler=sampler, + pipeline=pipeline, + ) + + def set_device_state_preset(self, *args, **kwargs): + # set the device state to cpu for the image encoder + if self.model_config.low_vram: + return + super().set_device_state_preset(*args, **kwargs) + + + def get_generation_pipeline(self): + scheduler = UniPCMultistepScheduler(**scheduler_configUniPC) + if self.model_config.low_vram: + pipeline = AggressiveWanI2VUnloadPipeline( + vae=self.vae, + transformer=self.model, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + image_encoder=self.image_encoder, + image_processor=self.image_processor, + device=self.device_torch + ) + else: + pipeline = WanImageToVideoPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + image_encoder=self.image_encoder, + image_processor=self.image_processor, + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: WanImageToVideoPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + # pipeline = pipeline.to(self.device_torch) + + + if gen_config.ctrl_img is None: + raise ValueError("I2V samples must have a control image") + + control_img = Image.open(gen_config.ctrl_img).convert("RGB") + + height = gen_config.height + width = gen_config.width + + # make sure they are divisible by 16 + height = height // 16 * 16 + width = width // 16 * 16 + + # resize the control image + control_img = control_img.resize((width, height), Image.LANCZOS) + + output = pipeline( + image=control_img, + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + height=height, + width=width, + 
num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="pil", + **extra + )[0] + + # shape = [1, frames, channels, height, width] + batch_item = output[0] # list of pil images + if gen_config.num_frames > 1: + return batch_item # return the frames. + else: + # get just the first image + img = batch_item[0] + return img + + + def preprocess_clip_image(self, image_n1p1): + # tensor shape: (bs, ch, height, width) with values in range [-1, 1] + # Convert from [-1, 1] to [0, 1] range + tensor = (image_n1p1 + 1) / 2 + + # Resize to 224x224 (using bilinear interpolation, which is resample=3 in PIL) + if tensor.shape[2] != 224 or tensor.shape[3] != 224: + tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False) + + # Normalize with mean and std + mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(tensor.device) + std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(tensor.device) + tensor = (tensor - mean) / std + + return tensor + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: DataLoaderBatchDTO, + **kwargs + ): + # videos come in (bs, num_frames, channels, height, width) + # images come in (bs, channels, height, width) + with torch.no_grad(): + frames = batch.tensor + if len(frames.shape) == 4: + first_frames = frames + elif len(frames.shape) == 5: + first_frames = frames[:, 0] + else: + raise ValueError(f"Unknown frame shape {frames.shape}") + + # first_frames shape is (bs, channels, height, width), -1 to 1 + preprocessed_frames = self.preprocess_clip_image(first_frames) + preprocessed_frames = preprocessed_frames.to(self.device_torch, dtype=self.torch_dtype) + # preprocessed_frame shape is (bs, 3, 224, 224) + self.image_encoder.to(self.device_torch) + image_embeds_full = self.image_encoder(preprocessed_frames, output_hidden_states=True) + image_embeds = image_embeds_full.hidden_states[-2] + image_embeds = image_embeds.to(self.device_torch, dtype=self.torch_dtype) + + # condition latent + # first_frames shape is (bs, channels, height, width) + # wan needs latends in (bs, channels, num_frames, height, width) + first_frames = first_frames.unsqueeze(2) + # video condition is first frame is the frame, the rest are zeros + num_frames = frames.shape[1] + + zero_frame = torch.zeros_like(first_frames) + video_condition = torch.cat([ + first_frames, + *[zero_frame for _ in range(num_frames - 1)] + ], dim=2) + + # our vae encoder expects (bs, num_frames, channels, height, width) + # permute to (bs, channels, num_frames, height, width) + video_condition = video_condition.permute(0, 2, 1, 3, 4) + + latent_condition = self.encode_images( + video_condition, + device=self.device_torch, + dtype=self.torch_dtype, + ) + latent_condition = latent_condition.to(self.device_torch, dtype=self.torch_dtype) + + batch_size = frames.shape[0] + latent_height = latent_condition.shape[3] + latent_width = latent_condition.shape[4] + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.pipeline.vae_scale_factor_temporal) + mask_lat_size = 
torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.pipeline.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(self.device_torch, dtype=self.torch_dtype) + + # return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + first_frame_condition = torch.concat([mask_lat_size, latent_condition], dim=1) + conditioned_latent = torch.cat([latent_model_input, first_frame_condition], dim=1) + + noise_pred = self.model( + hidden_states=conditioned_latent, + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + **kwargs + )[0] + return noise_pred diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index de5f81dd..a08b043e 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1831,6 +1831,7 @@ class StableDiffusion: return_conditional_pred=False, guidance_embedding_scale=1.0, bypass_guidance_embedding=False, + batch: Union[None, 'DataLoaderBatchDTO'] = None, **kwargs, ): conditional_pred = None diff --git a/toolkit/util/get_model.py b/toolkit/util/get_model.py index 280fc632..545175dd 100644 --- a/toolkit/util/get_model.py +++ b/toolkit/util/get_model.py @@ -7,11 +7,12 @@ from toolkit.paths import TOOLKIT_ROOT import importlib import pkgutil -from toolkit.models.wan21 import Wan21 +from toolkit.models.wan21 import Wan21, Wan21I2V from toolkit.models.cogview4 import CogView4 BUILT_IN_MODELS = [ Wan21, + Wan21I2V, CogView4, ]
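
Usage sketch for the new `extras_name_or_path` option and the `wan21_i2v` arch added above. This is an illustrative snippet, not part of the patch: the keys mirror what the diff reads from `ModelConfig` (`arch`, `name_or_path`, `extras_name_or_path`, plus the pre-existing `low_vram` flag referenced in `wan21_i2v.py`); the repo id and local path are placeholders.

# Hypothetical sketch, not part of the diff: keep the transformer on a per-variant
# checkpoint while the shared text encoder / VAE / CLIP image encoder come from
# extras_name_or_path, so they are not re-downloaded for every model variant.
from toolkit.config_modules import ModelConfig

model_config = ModelConfig(
    arch="wan21_i2v",                                    # new arch registered in toolkit/util/get_model.py
    name_or_path="/path/to/wan2.1-i2v-variant",          # placeholder transformer checkpoint
    extras_name_or_path="Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",  # assumed repo id for the shared extras
    low_vram=True,                                       # existing flag; routes sampling through AggressiveWanI2VUnloadPipeline
)
# If extras_name_or_path is omitted it falls back to name_or_path, preserving the old behavior.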
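
The `batch` plumbing added in SDTrainer.py, base_model.py, and stable_diffusion_model.py is opt-in: `predict_noise` only forwards `batch` when the model's `get_noise_prediction` declares it, which is how `Wan21I2V` reaches the first frame without touching other models. A minimal sketch of that gate, assuming the existing `signatures` check is equivalent to inspecting the method signature (stand-in names, not toolkit classes):

# Illustrative only: ToyModel and predict_noise are stand-ins for a BaseModel subclass
# and BaseModel.predict_noise; the real toolkit builds `signatures` itself.
import inspect
import torch

class ToyModel:
    def get_noise_prediction(self, latent_model_input, timestep, text_embeddings, batch=None, **kwargs):
        # Declaring `batch` here is what opts a model in (as Wan21I2V does for its
        # first-frame conditioning); models that do not declare it never see the kwarg.
        return latent_model_input

def predict_noise(model, latent_model_input, timestep, text_embeddings, batch=None, **kwargs):
    signatures = inspect.signature(model.get_noise_prediction).parameters
    if 'batch' in signatures:
        kwargs['batch'] = batch
    return model.get_noise_prediction(
        latent_model_input=latent_model_input,
        timestep=timestep,
        text_embeddings=text_embeddings,
        **kwargs,
    )

pred = predict_noise(ToyModel(), torch.zeros(1, 16, 1, 8, 8), torch.tensor([500]), None, batch=None)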
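
The densest part of `Wan21I2V.get_noise_prediction` is the first-frame conditioning tensor. A shape walk-through with dummy tensors, assuming Wan's usual 16 latent channels and temporal VAE factor of 4 (illustrative numbers, not values read from a config):

# Mirrors the mask/latent construction in get_noise_prediction with placeholder tensors.
import torch

bs, num_frames, lat_h, lat_w = 1, 81, 60, 104            # 81 frames at 480x832 -> 60x104 latents
vae_scale_factor_temporal, z_dim = 4, 16
latent_frames = 1 + (num_frames - 1) // vae_scale_factor_temporal   # 21

mask = torch.ones(bs, 1, num_frames, lat_h, lat_w)
mask[:, :, 1:] = 0                                        # only the first pixel-space frame is kept
first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)          # (1, 1, 84, lat_h, lat_w)
mask = mask.view(bs, -1, vae_scale_factor_temporal, lat_h, lat_w).transpose(1, 2)  # (1, 4, 21, lat_h, lat_w)

latent_condition = torch.zeros(bs, z_dim, latent_frames, lat_h, lat_w)  # stand-in for the encoded first-frame video
noisy_latents = torch.zeros(bs, z_dim, latent_frames, lat_h, lat_w)     # stand-in for latent_model_input

first_frame_condition = torch.cat([mask, latent_condition], dim=1)      # (1, 20, 21, lat_h, lat_w)
conditioned = torch.cat([noisy_latents, first_frame_condition], dim=1)
print(conditioned.shape)  # torch.Size([1, 36, 21, 60, 104]) -- the 36 input channels the I2V transformer expects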