From ca3ce0f34c51be43d3b585e89748e6193025b08b Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 13 Apr 2025 13:49:13 -0600
Subject: [PATCH] Make it easier to designate LoRA blocks for new models.
 Improve i2v adapter speed. Fix an issue where the i2v adapter cached a
 torch tensor with the wrong value range.

---
 toolkit/lora_special.py           | 55 ++++++++++++++------------
 toolkit/models/base_model.py      |  6 ++-
 toolkit/models/i2v_adapter.py     |  2 +-
 toolkit/models/wan21/wan21.py     | 65 ++++++++++++-------------------
 toolkit/stable_diffusion_model.py |  6 ++-
 5 files changed, 67 insertions(+), 67 deletions(-)

diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index c6691dec..e213face 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -339,32 +339,39 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             # see if it is over threshold
             if count_parameters(child_module) < parameter_threshold:
                 skip = True
-
-            if self.transformer_only and self.is_pixart and is_unet:
-                if "transformer_blocks" not in lora_name:
-                    skip = True
-            if self.transformer_only and self.is_flux and is_unet:
-                if "transformer_blocks" not in lora_name:
-                    skip = True
-            if self.transformer_only and self.is_lumina2 and is_unet:
-                if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name:
-                    skip = True
-            if self.transformer_only and self.is_v3 and is_unet:
-                if "transformer_blocks" not in lora_name:
-                    skip = True
-            # handle custom models
-            if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
-                if "transformer_blocks" not in lora_name:
-                    skip = True
+            if self.transformer_only and is_unet:
+                transformer_block_names = base_model.get_transformer_block_names()
+
+                if transformer_block_names is not None:
+                    if not any(name in lora_name for name in transformer_block_names):
+                        skip = True
+                else:
+                    if self.is_pixart:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
+                    if self.is_flux:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
+                    if self.is_lumina2:
+                        if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name:
+                            skip = True
+                    if self.is_v3:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
 
-            if self.transformer_only and is_unet and hasattr(root_module, 'blocks'):
-                if "blocks" not in lora_name:
-                    skip = True
-
-            if self.transformer_only and is_unet and hasattr(root_module, 'single_blocks'):
-                if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
-                    skip = True
+                    # handle custom models
+                    if hasattr(root_module, 'transformer_blocks'):
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
+
+                    if hasattr(root_module, 'blocks'):
+                        if "blocks" not in lora_name:
+                            skip = True
+
+                    if hasattr(root_module, 'single_blocks'):
+                        if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
+                            skip = True
 
         if (is_linear or is_conv2d) and not skip:
diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index 7c8a294b..ddc8ce8b 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -5,7 +5,7 @@ import json
 import random
 import shutil
 import typing
-from typing import Union, List, Literal
+from typing import Optional, Union, List, Literal
 import os
 from collections import OrderedDict
 import copy
@@ -1478,3 +1478,7 @@ class BaseModel:
     def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'):
         # can be overridden in child classes to condition latents before noise prediction
         return latents
+
+    def get_transformer_block_names(self) -> Optional[List[str]]:
+        # override in child classes to get transformer block names for lora targeting
+        return None
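With this hook in place, a new model only has to advertise its transformer
block names once instead of growing the per-architecture checks in
lora_special.py. A minimal sketch of a child-class override, assuming the
module path from this patch (the class and block names below are
illustrative, not part of the commit):

    from typing import List, Optional

    from toolkit.models.base_model import BaseModel

    class MyVideoModel(BaseModel):
        def get_transformer_block_names(self) -> Optional[List[str]]:
            # LoRA module names containing any of these substrings are kept
            # when transformer_only is set; every other module is skipped.
            return ['transformer_blocks', 'single_blocks', 'double_blocks']

Returning None (the default) falls back to the legacy
is_pixart/is_flux/is_lumina2/is_v3 and hasattr checks above.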
diff --git a/toolkit/models/i2v_adapter.py b/toolkit/models/i2v_adapter.py
index d51efc69..f9615eac 100644
--- a/toolkit/models/i2v_adapter.py
+++ b/toolkit/models/i2v_adapter.py
@@ -411,7 +411,7 @@ class I2VAdapter(torch.nn.Module):
         network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
 
         if hasattr(sd, 'target_lora_modules'):
-            network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
+            network_kwargs['target_lin_modules'] = sd.target_lora_modules
 
         if 'ignore_if_contains' not in network_kwargs:
             network_kwargs['ignore_if_contains'] = []
diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py
index fb7edd66..fce77dca 100644
--- a/toolkit/models/wan21/wan21.py
+++ b/toolkit/models/wan21/wan21.py
@@ -36,7 +36,7 @@ from toolkit.accelerator import unwrap_model
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
 from torchvision.transforms import Resize, ToPILImage
 from tqdm import tqdm
-
+import torch.nn.functional as F
 from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
 from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
 # from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -588,50 +588,37 @@ class Wan21(BaseModel):
         if dtype is None:
             dtype = self.vae_torch_dtype
 
-        # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
             self.vae.to(device)
 
         self.vae.eval()
         self.vae.requires_grad_(False)
 
-        # move to device and dtype
+
         image_list = [image.to(device, dtype=dtype) for image in image_list]
-
-        # We need to detect video if we have it.
-        # videos come in (num_frames, channels, height, width)
-        # images come in (channels, height, width)
-        # we need to add a frame dimension to images and remap the video to (channels, num_frames, height, width)
-
-        if len(image_list[0].shape) == 3:
-            image_list = [image.unsqueeze(1) for image in image_list]
-        elif len(image_list[0].shape) == 4:
-            image_list = [image.permute(1, 0, 2, 3) for image in image_list]
-        else:
-            raise ValueError(f"Image shape is not correct, got {list(image_list[0].shape)}")
 
-        VAE_SCALE_FACTOR = 8
+        # Normalize shapes
+        norm_images = []
+        for image in image_list:
+            if image.ndim == 3:
+                # (C, H, W) -> (C, 1, H, W)
+                norm_images.append(image.unsqueeze(1))
+            elif image.ndim == 4:
+                # (T, C, H, W) -> (C, T, H, W)
+                norm_images.append(image.permute(1, 0, 2, 3))
+            else:
+                raise ValueError(f"Invalid image shape: {image.shape}")
 
-        # resize images if not divisible by 8
-        # now we need to resize considering the shape (channels, num_frames, height, width)
-        for i in range(len(image_list)):
-            image = image_list[i]
-            if image.shape[2] % VAE_SCALE_FACTOR != 0 or image.shape[3] % VAE_SCALE_FACTOR != 0:
-                # Create resized frames by handling each frame separately
-                c, f, h, w = image.shape
-                target_h = h // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
-                target_w = w // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
-
-                # We need to process each frame separately
-                resized_frames = []
-                for frame_idx in range(f):
-                    frame = image[:, frame_idx, :, :]  # Extract single frame (channels, height, width)
-                    resized_frame = Resize((target_h, target_w))(frame)
-                    resized_frames.append(resized_frame.unsqueeze(1))  # Add frame dimension back
-
-                # Concatenate all frames back together along the frame dimension
-                image_list[i] = torch.cat(resized_frames, dim=1)
+        # Stack to (B, C, T, H, W)
+        images = torch.stack(norm_images)
+        B, C, T, H, W = images.shape
+
+        # Resize if needed, batching all frames as (B * T, C, H, W)
+        if H % 8 != 0 or W % 8 != 0:
+            target_h = H // 8 * 8
+            target_w = W // 8 * 8
+            images = images.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
+            images = F.interpolate(images, size=(target_h, target_w), mode='bilinear', align_corners=False)
+            images = images.view(B, T, C, target_h, target_w).permute(0, 2, 1, 3, 4)
 
-        images = torch.stack(image_list)
-        # images = images.unsqueeze(2) # adds frame dimension so (bs, ch, h, w) -> (bs, ch, 1, h, w)
         latents = self.vae.encode(images).latent_dist.sample()
 
         latents_mean = (
@@ -644,9 +631,7 @@ class Wan21(BaseModel):
         )
         latents = (latents - latents_mean) * latents_std
 
-        latents = latents.to(device, dtype=dtype)
-
-        return latents
+        return latents.to(device, dtype=dtype)
 
     def get_model_has_grad(self):
         return self.model.proj_out.weight.requires_grad
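The rewrite above replaces the per-frame torchvision Resize loop with a
single batched F.interpolate call over all frames, which is where the i2v
adapter speedup comes from. A standalone sanity check of the
permute/reshape round trip (shapes assumed for illustration):

    import torch
    import torch.nn.functional as F

    B, C, T, H, W = 2, 3, 5, 252, 252  # H and W deliberately not divisible by 8
    images = torch.randn(B, C, T, H, W)

    target_h, target_w = H // 8 * 8, W // 8 * 8
    flat = images.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)  # one frame per row
    flat = F.interpolate(flat, size=(target_h, target_w), mode='bilinear', align_corners=False)
    out = flat.view(B, T, C, target_h, target_w).permute(0, 2, 1, 3, 4)

    assert out.shape == (B, C, T, target_h, target_w)  # frame order preserved

Bilinear interpolation here stands in for Resize's default bilinear
resampling, so the output should match the old path up to antialiasing
differences.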
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 8a05a908..5f66db51 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -4,7 +4,7 @@ import json
 import random
 import shutil
 import typing
-from typing import Union, List, Literal, Iterator
+from typing import Optional, Union, List, Literal, Iterator
 import sys
 import os
 from collections import OrderedDict
@@ -3079,3 +3079,7 @@ class StableDiffusion:
     def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'):
         # can be overridden in child classes to condition latents before noise prediction
         return latents
+
+    def get_transformer_block_names(self) -> Optional[List[str]]:
+        # override in child classes to get transformer block names for lora targeting
+        return None
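Both base classes now share the same default, so LoRASpecialNetwork can call
get_transformer_block_names() regardless of which hierarchy a model derives
from, and the match itself reduces to a plain substring test. Illustrative
only (the module name is made up; '$$' is the separator already visible in
the lumina2 checks):

    lora_name = 'transformer$$blocks$$0$$attn$$to_q'  # hypothetical module name
    transformer_block_names = ['blocks']              # e.g. from a video model override

    keep = any(name in lora_name for name in transformer_block_names)
    assert keep  # a name without 'blocks', e.g. 'patch_embedding', would be skipped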