Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Make it easier to designate LoRA blocks for new models. Improve i2v adapter speed. Fix an issue with the i2v adapter where the cached torch tensor was in the wrong range.
@@ -339,32 +339,39 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             # see if it is over threshold
             if count_parameters(child_module) < parameter_threshold:
                 skip = True
 
-            if self.transformer_only and self.is_pixart and is_unet:
-                if "transformer_blocks" not in lora_name:
-                    skip = True
-            if self.transformer_only and self.is_flux and is_unet:
-                if "transformer_blocks" not in lora_name:
-                    skip = True
-            if self.transformer_only and self.is_lumina2 and is_unet:
-                if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name:
-                    skip = True
-            if self.transformer_only and self.is_v3 and is_unet:
-                if "transformer_blocks" not in lora_name:
-                    skip = True
-
-            # handle custom models
-            if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
-                if "transformer_blocks" not in lora_name:
-                    skip = True
-
-            if self.transformer_only and is_unet and hasattr(root_module, 'blocks'):
-                if "blocks" not in lora_name:
-                    skip = True
-
-            if self.transformer_only and is_unet and hasattr(root_module, 'single_blocks'):
-                if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
-                    skip = True
+            if self.transformer_only and is_unet:
+                transformer_block_names = base_model.get_transformer_block_names()
+
+                if transformer_block_names is not None:
+                    if not any([name in lora_name for name in transformer_block_names]):
+                        skip = True
+                else:
+                    if self.is_pixart:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
+                    if self.is_flux:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
+                    if self.is_lumina2:
+                        if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name:
+                            skip = True
+                    if self.is_v3:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
+
+                    # handle custom models
+                    if hasattr(root_module, 'transformer_blocks'):
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
+
+                    if hasattr(root_module, 'blocks'):
+                        if "blocks" not in lora_name:
+                            skip = True
+
+                    if hasattr(root_module, 'single_blocks'):
+                        if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
+                            skip = True
 
             if (is_linear or is_conv2d) and not skip:
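Note: the hunk above collapses the per-architecture transformer-only checks into one lookup. When the model reports its block names through get_transformer_block_names(), the filter is a single substring test; the pixart/flux/lumina2/v3 branches and the hasattr() probes survive only as a fallback for models that report nothing. A hypothetical illustration of the new test (module names below are made up, not from a real model):

    transformer_block_names = ['transformer_blocks', 'single_transformer_blocks']

    def keep_module(lora_name: str) -> bool:
        # Mirrors the any([...]) check in the hunk: keep the module only
        # if its lora_name mentions one of the reported block names.
        return any(name in lora_name for name in transformer_block_names)

    print(keep_module('transformer_blocks$$0$$attn$$to_q'))  # True: gets a LoRA
    print(keep_module('patch_embed$$proj'))                  # False: skipped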
@@ -5,7 +5,7 @@ import json
 import random
 import shutil
 import typing
-from typing import Union, List, Literal
+from typing import Optional, Union, List, Literal
 import os
 from collections import OrderedDict
 import copy
@@ -1478,3 +1478,7 @@ class BaseModel:
     def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'):
         # can be overridden in child classes to condition latents before noise prediction
         return latents
+
+    def get_transformer_block_names(self) -> Optional[List[str]]:
+        # override in child classes to get transformer block names for lora targeting
+        return None
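Note: with this hook on BaseModel, a new architecture only has to override get_transformer_block_names(); LoRASpecialNetwork picks the names up automatically and no per-model flag is needed. A minimal sketch, using a hypothetical subclass:

    from typing import List, Optional

    class MyDiTModel(BaseModel):  # hypothetical model class, not from this commit
        def get_transformer_block_names(self) -> Optional[List[str]]:
            # When transformer_only is set, LoRA modules are created only for
            # layers whose lora_name contains one of these substrings.
            return ['transformer_blocks']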
@@ -411,7 +411,7 @@ class I2VAdapter(torch.nn.Module):
 
         network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
         if hasattr(sd, 'target_lora_modules'):
-            network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
+            network_kwargs['target_lin_modules'] = sd.target_lora_modules
 
         if 'ignore_if_contains' not in network_kwargs:
             network_kwargs['ignore_if_contains'] = []
@@ -36,7 +36,7 @@ from toolkit.accelerator import unwrap_model
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
 from torchvision.transforms import Resize, ToPILImage
 from tqdm import tqdm
-
+import torch.nn.functional as F
 from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
 from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
 # from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -588,50 +588,37 @@ class Wan21(BaseModel):
         if dtype is None:
             dtype = self.vae_torch_dtype
 
         # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
             self.vae.to(device)
         self.vae.eval()
         self.vae.requires_grad_(False)
         # move to device and dtype
         image_list = [image.to(device, dtype=dtype) for image in image_list]
 
-        # We need to detect video if we have it.
-        # videos come in (num_frames, channels, height, width)
-        # images come in (channels, height, width)
-        # we need to add a frame dimension to images and remap the video to (channels, num_frames, height, width)
-
-        if len(image_list[0].shape) == 3:
-            image_list = [image.unsqueeze(1) for image in image_list]
-        elif len(image_list[0].shape) == 4:
-            image_list = [image.permute(1, 0, 2, 3) for image in image_list]
-        else:
-            raise ValueError(f"Image shape is not correct, got {list(image_list[0].shape)}")
-
-        VAE_SCALE_FACTOR = 8
-
-        # resize images if not divisible by 8
-        # now we need to resize considering the shape (channels, num_frames, height, width)
-        for i in range(len(image_list)):
-            image = image_list[i]
-            if image.shape[2] % VAE_SCALE_FACTOR != 0 or image.shape[3] % VAE_SCALE_FACTOR != 0:
-                # Create resized frames by handling each frame separately
-                c, f, h, w = image.shape
-                target_h = h // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
-                target_w = w // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
-
-                # We need to process each frame separately
-                resized_frames = []
-                for frame_idx in range(f):
-                    frame = image[:, frame_idx, :, :]  # Extract single frame (channels, height, width)
-                    resized_frame = Resize((target_h, target_w))(frame)
-                    resized_frames.append(resized_frame.unsqueeze(1))  # Add frame dimension back
-
-                # Concatenate all frames back together along the frame dimension
-                image_list[i] = torch.cat(resized_frames, dim=1)
-
-        images = torch.stack(image_list)
-        # images = images.unsqueeze(2) # adds frame dimension so (bs, ch, h, w) -> (bs, ch, 1, h, w)
+        # Normalize shapes
+        norm_images = []
+        for image in image_list:
+            if image.ndim == 3:
+                # (C, H, W) -> (C, 1, H, W)
+                norm_images.append(image.unsqueeze(1))
+            elif image.ndim == 4:
+                # (T, C, H, W) -> (C, T, H, W)
+                norm_images.append(image.permute(1, 0, 2, 3))
+            else:
+                raise ValueError(f"Invalid image shape: {image.shape}")
+
+        # Stack to (B, C, T, H, W)
+        images = torch.stack(norm_images)
+        B, C, T, H, W = images.shape
+
+        # Resize if needed (B * T, C, H, W)
+        if H % 8 != 0 or W % 8 != 0:
+            target_h = H // 8 * 8
+            target_w = W // 8 * 8
+            images = images.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
+            images = F.interpolate(images, size=(target_h, target_w), mode='bilinear', align_corners=False)
+            images = images.view(B, T, C, target_h, target_w).permute(0, 2, 1, 3, 4)
 
         latents = self.vae.encode(images).latent_dist.sample()
 
         latents_mean = (
@@ -644,9 +631,7 @@ class Wan21(BaseModel):
         )
         latents = (latents - latents_mean) * latents_std
 
-        latents = latents.to(device, dtype=dtype)
-
-        return latents
+        return latents.to(device, dtype=dtype)
 
     def get_model_has_grad(self):
         return self.model.proj_out.weight.requires_grad
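Note: the Wan21 rewrite above is where the i2v adapter speedup comes from. Instead of looping over frames and resizing each one with torchvision's Resize, all frames are folded into the batch dimension and resized with a single F.interpolate call. A standalone sketch of the trick (tensor sizes below are made up):

    import torch
    import torch.nn.functional as F

    images = torch.randn(2, 3, 5, 270, 482)  # (B, C, T, H, W); 270x482 is not divisible by 8
    B, C, T, H, W = images.shape
    target_h, target_w = H // 8 * 8, W // 8 * 8  # snap down to 264x480

    # Fold time into batch so every frame is resized in one kernel call
    flat = images.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
    flat = F.interpolate(flat, size=(target_h, target_w), mode='bilinear', align_corners=False)
    images = flat.view(B, T, C, target_h, target_w).permute(0, 2, 1, 3, 4)

    assert images.shape == (B, C, T, target_h, target_w)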
@@ -4,7 +4,7 @@ import json
 import random
 import shutil
 import typing
-from typing import Union, List, Literal, Iterator
+from typing import Optional, Union, List, Literal, Iterator
 import sys
 import os
 from collections import OrderedDict
@@ -3079,3 +3079,7 @@ class StableDiffusion:
     def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'):
         # can be overridden in child classes to condition latents before noise prediction
         return latents
+
+    def get_transformer_block_names(self) -> Optional[List[str]]:
+        # override in child classes to get transformer block names for lora targeting
+        return None