Make it easier to designate LoRA blocks for new models. Improve I2V adapter speed. Fix an issue with the I2V adapter where a cached torch tensor was in the wrong range.

Author: Jaret Burkett
Date: 2025-04-13 13:49:13 -06:00
Parent: 6fb44db6a0
Commit: ca3ce0f34c
5 changed files with 67 additions and 67 deletions


@@ -339,32 +339,39 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 # see if it is over threshold
                 if count_parameters(child_module) < parameter_threshold:
                     skip = True
-                if self.transformer_only and self.is_pixart and is_unet:
-                    if "transformer_blocks" not in lora_name:
-                        skip = True
-                if self.transformer_only and self.is_flux and is_unet:
-                    if "transformer_blocks" not in lora_name:
-                        skip = True
-                if self.transformer_only and self.is_lumina2 and is_unet:
-                    if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name:
-                        skip = True
-                if self.transformer_only and self.is_v3 and is_unet:
-                    if "transformer_blocks" not in lora_name:
-                        skip = True
-                # handle custom models
-                if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
-                    if "transformer_blocks" not in lora_name:
-                        skip = True
+                if self.transformer_only and is_unet:
+                    transformer_block_names = base_model.get_transformer_block_names()
+                    if transformer_block_names is not None:
+                        if not any([name in lora_name for name in transformer_block_names]):
+                            skip = True
+                    else:
+                        if self.is_pixart:
+                            if "transformer_blocks" not in lora_name:
+                                skip = True
+                        if self.is_flux:
+                            if "transformer_blocks" not in lora_name:
+                                skip = True
+                        if self.is_lumina2:
+                            if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name:
+                                skip = True
+                        if self.is_v3:
+                            if "transformer_blocks" not in lora_name:
+                                skip = True
-                if self.transformer_only and is_unet and hasattr(root_module, 'blocks'):
-                    if "blocks" not in lora_name:
-                        skip = True
-                if self.transformer_only and is_unet and hasattr(root_module, 'single_blocks'):
-                    if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
-                        skip = True
+                        # handle custom models
+                        if hasattr(root_module, 'transformer_blocks'):
+                            if "transformer_blocks" not in lora_name:
+                                skip = True
+                        if hasattr(root_module, 'blocks'):
+                            if "blocks" not in lora_name:
+                                skip = True
+                        if hasattr(root_module, 'single_blocks'):
+                            if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
+                                skip = True
                 if (is_linear or is_conv2d) and not skip:
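
The net effect of this hunk: the per-architecture checks become fallbacks behind a single hook, and a module is trained only when its LoRA name contains one of the block names the model reports. A minimal standalone sketch of that gate (the function and the example names are illustrative, not from the repo; the $$ separator style follows the lumina2 names in the diff above):

from typing import List, Optional

def should_skip(lora_name: str, transformer_block_names: Optional[List[str]]) -> bool:
    # Hypothetical distillation of the new transformer_only check.
    if transformer_block_names is None:
        return False  # fall through to the per-architecture branches
    # train the module only if its name mentions a known transformer block
    return not any(name in lora_name for name in transformer_block_names)

# With a model reporting ["blocks"]:
# should_skip("transformer$$blocks$$0$$attn$$to_q", ["blocks"])  -> False (trained)
# should_skip("transformer$$patch_embedding", ["blocks"])        -> True  (skipped)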


@@ -5,7 +5,7 @@ import json
 import random
 import shutil
 import typing
-from typing import Union, List, Literal
+from typing import Optional, Union, List, Literal
 import os
 from collections import OrderedDict
 import copy
@@ -1478,3 +1478,7 @@ class BaseModel:
     def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'):
         # can be overridden in child classes to condition latents before noise prediction
         return latents
+
+    def get_transformer_block_names(self) -> Optional[List[str]]:
+        # override in child classes to get transformer block names for lora targeting
+        return None
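
With the hook in place, adding transformer-only LoRA support for a new model becomes a one-method override rather than another hard-coded branch in LoRASpecialNetwork. A hedged sketch of what a child class might look like (the class and block names are hypothetical; it assumes the BaseModel above):

from typing import List, Optional

class MyVideoModel(BaseModel):  # hypothetical custom model
    def get_transformer_block_names(self) -> Optional[List[str]]:
        # when the network is transformer_only, LoRA modules whose names do not
        # contain one of these substrings are skipped
        return ["transformer_blocks", "single_transformer_blocks"]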


@@ -411,7 +411,7 @@ class I2VAdapter(torch.nn.Module):
         network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
         if hasattr(sd, 'target_lora_modules'):
-            network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
+            network_kwargs['target_lin_modules'] = sd.target_lora_modules
         if 'ignore_if_contains' not in network_kwargs:
             network_kwargs['ignore_if_contains'] = []
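
The one-token fix above matters because the model to wrap arrives as the sd argument, which is not necessarily the same object as the self.sd captured when the adapter was constructed, so reading self.sd could target the wrong instance's modules. A repo-independent illustration of the bug class (names invented for the example):

class Adapter:
    def __init__(self, sd):
        self.sd = sd  # model captured at construction time

    def build_network_kwargs(self, sd):  # may receive a different model instance
        kwargs = {}
        if hasattr(sd, 'target_lora_modules'):
            # buggy: reads the construction-time model
            #   kwargs['target_lin_modules'] = self.sd.target_lora_modules
            # fixed: reads the model actually passed in
            kwargs['target_lin_modules'] = sd.target_lora_modules
        return kwargs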


@@ -36,7 +36,7 @@ from toolkit.accelerator import unwrap_model
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
-from torchvision.transforms import Resize, ToPILImage
 from tqdm import tqdm
+import torch.nn.functional as F
 from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
 from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
 # from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -588,50 +588,37 @@ class Wan21(BaseModel):
         if dtype is None:
             dtype = self.vae_torch_dtype
         # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
             self.vae.to(device)
         self.vae.eval()
         self.vae.requires_grad_(False)
         # move to device and dtype
         image_list = [image.to(device, dtype=dtype) for image in image_list]
-        # We need to detect video if we have it.
-        # videos come in (num_frames, channels, height, width)
-        # images come in (channels, height, width)
-        # we need to add a frame dimension to images and remap the video to (channels, num_frames, height, width)
-        if len(image_list[0].shape) == 3:
-            image_list = [image.unsqueeze(1) for image in image_list]
-        elif len(image_list[0].shape) == 4:
-            image_list = [image.permute(1, 0, 2, 3) for image in image_list]
-        else:
-            raise ValueError(f"Image shape is not correct, got {list(image_list[0].shape)}")
-        VAE_SCALE_FACTOR = 8
+        # Normalize shapes
+        norm_images = []
+        for image in image_list:
+            if image.ndim == 3:
+                # (C, H, W) -> (C, 1, H, W)
+                norm_images.append(image.unsqueeze(1))
+            elif image.ndim == 4:
+                # (T, C, H, W) -> (C, T, H, W)
+                norm_images.append(image.permute(1, 0, 2, 3))
+            else:
+                raise ValueError(f"Invalid image shape: {image.shape}")
-        # resize images if not divisible by 8
-        # now we need to resize considering the shape (channels, num_frames, height, width)
-        for i in range(len(image_list)):
-            image = image_list[i]
-            if image.shape[2] % VAE_SCALE_FACTOR != 0 or image.shape[3] % VAE_SCALE_FACTOR != 0:
-                # Create resized frames by handling each frame separately
-                c, f, h, w = image.shape
-                target_h = h // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
-                target_w = w // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
-                # We need to process each frame separately
-                resized_frames = []
-                for frame_idx in range(f):
-                    frame = image[:, frame_idx, :, :]  # Extract single frame (channels, height, width)
-                    resized_frame = Resize((target_h, target_w))(frame)
-                    resized_frames.append(resized_frame.unsqueeze(1))  # Add frame dimension back
-                # Concatenate all frames back together along the frame dimension
-                image_list[i] = torch.cat(resized_frames, dim=1)
+        # Stack to (B, C, T, H, W)
+        images = torch.stack(norm_images)
+        B, C, T, H, W = images.shape
+        # Resize if needed (B * T, C, H, W)
+        if H % 8 != 0 or W % 8 != 0:
+            target_h = H // 8 * 8
+            target_w = W // 8 * 8
+            images = images.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
+            images = F.interpolate(images, size=(target_h, target_w), mode='bilinear', align_corners=False)
+            images = images.view(B, T, C, target_h, target_w).permute(0, 2, 1, 3, 4)
-        images = torch.stack(image_list)
         # images = images.unsqueeze(2) # adds frame dimension so (bs, ch, h, w) -> (bs, ch, 1, h, w)
         latents = self.vae.encode(images).latent_dist.sample()
         latents_mean = (
@@ -644,9 +631,7 @@ class Wan21(BaseModel):
         )
         latents = (latents - latents_mean) * latents_std
-        latents = latents.to(device, dtype=dtype)
-        return latents
+        return latents.to(device, dtype=dtype)

     def get_model_has_grad(self):
         return self.model.proj_out.weight.requires_grad
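
The speed gain in this file comes from replacing a Python loop of per-frame torchvision Resize calls with a single batched F.interpolate over all frames, which is why the torchvision import is dropped and torch.nn.functional is imported instead. A self-contained sketch of the old and new resize paths (shapes invented for illustration; pixel values may differ slightly between the two since Resize and plain bilinear interpolation can use different antialias settings):

import torch
import torch.nn.functional as F
from torchvision.transforms import Resize

video = torch.randn(1, 3, 16, 270, 480)  # (B, C, T, H, W); 270 is not divisible by 8
B, C, T, H, W = video.shape
target_h, target_w = H // 8 * 8, W // 8 * 8

# old path: resize each frame separately
frames = [Resize((target_h, target_w))(video[0, :, t]) for t in range(T)]
old = torch.stack(frames, dim=1).unsqueeze(0)  # back to (B, C, T, H', W')

# new path: one batched kernel over (B * T, C, H, W)
flat = video.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
flat = F.interpolate(flat, size=(target_h, target_w), mode='bilinear', align_corners=False)
new = flat.view(B, T, C, target_h, target_w).permute(0, 2, 1, 3, 4)

assert old.shape == new.shape == (1, 3, 16, 264, 480)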


@@ -4,7 +4,7 @@ import json
 import random
 import shutil
 import typing
-from typing import Union, List, Literal, Iterator
+from typing import Optional, Union, List, Literal, Iterator
 import sys
 import os
 from collections import OrderedDict
@@ -3079,3 +3079,7 @@ class StableDiffusion:
     def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'):
         # can be overridden in child classes to condition latents before noise prediction
         return latents
+
+    def get_transformer_block_names(self) -> Optional[List[str]]:
+        # override in child classes to get transformer block names for lora targeting
+        return None