Various experiments and minor bug fixes for edge cases

This commit is contained in:
Jaret Burkett
2025-04-25 13:44:38 -06:00
parent 8ff85ba14f
commit 88b3fbae37
8 changed files with 170 additions and 122 deletions

View File

@@ -69,6 +69,7 @@ import transformers
import diffusers import diffusers
import hashlib import hashlib
from toolkit.util.blended_blur_noise import get_blended_blur_noise
from toolkit.util.get_model import get_model_class from toolkit.util.get_model import get_model_class
def flush(): def flush():
@@ -903,7 +904,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
return noise return noise
def get_noise(self, latents, batch_size, dtype=torch.float32, batch: 'DataLoaderBatchDTO' = None): def get_noise(
self,
latents,
batch_size,
dtype=torch.float32,
batch: 'DataLoaderBatchDTO' = None,
timestep=None,
):
if self.train_config.optimal_noise_pairing_samples > 1: if self.train_config.optimal_noise_pairing_samples > 1:
noise = self.get_optimal_noise(latents, dtype=dtype) noise = self.get_optimal_noise(latents, dtype=dtype)
elif self.train_config.force_consistent_noise: elif self.train_config.force_consistent_noise:
@@ -934,11 +942,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# add to noise # add to noise
noise += noise_shift noise += noise_shift
# standardize the noise if self.train_config.blended_blur_noise:
# shouldnt be needed? noise = get_blended_blur_noise(
# std = noise.std(dim=(2, 3), keepdim=True) latents, noise, timestep
# normalizer = 1 / (std + 1e-6) )
# noise = noise * normalizer
return noise return noise
@@ -1193,7 +1200,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
timesteps = torch.stack(timesteps, dim=0) timesteps = torch.stack(timesteps, dim=0)
# get noise # get noise
noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch) noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)
# add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
# this will negate any noise offsets # this will negate any noise offsets
@@ -1924,10 +1931,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
start_step_num = self.step_num start_step_num = self.step_num
did_first_flush = False did_first_flush = False
flush_next = False
for step in range(start_step_num, self.train_config.steps): for step in range(start_step_num, self.train_config.steps):
if self.train_config.do_paramiter_swapping: if self.train_config.do_paramiter_swapping:
self.optimizer.optimizer.swap_paramiters() self.optimizer.optimizer.swap_paramiters()
self.timer.start('train_loop') self.timer.start('train_loop')
if flush_next:
flush()
flush_next = False
if self.train_config.do_random_cfg: if self.train_config.do_random_cfg:
self.train_config.do_cfg = True self.train_config.do_cfg = True
self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale) self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
@@ -2089,6 +2100,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
print_acc(f"\nSaving at step {self.step_num}") print_acc(f"\nSaving at step {self.step_num}")
self.save(self.step_num) self.save(self.step_num)
self.ensure_params_requires_grad() self.ensure_params_requires_grad()
# clear any grads
optimizer.zero_grad()
flush()
flush_next = True
if self.progress_bar is not None: if self.progress_bar is not None:
self.progress_bar.unpause() self.progress_bar.unpause()

View File

@@ -460,6 +460,7 @@ class TrainConfig:
# forces same noise for the same image at a given size. # forces same noise for the same image at a given size.
self.force_consistent_noise = kwargs.get('force_consistent_noise', False) self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
self.blended_blur_noise = kwargs.get('blended_blur_noise', False)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']

View File

@@ -1160,12 +1160,12 @@ class BaseModel:
if self.model_config.ignore_if_contains is not None: if self.model_config.ignore_if_contains is not None:
# remove params that contain the ignore_if_contains from named params # remove params that contain the ignore_if_contains from named params
for key in list(named_params.keys()): for key in list(named_params.keys()):
if any([s in key for s in self.model_config.ignore_if_contains]): if any([s in f"transformer.{key}" for s in self.model_config.ignore_if_contains]):
del named_params[key] del named_params[key]
if self.model_config.only_if_contains is not None: if self.model_config.only_if_contains is not None:
# remove params that do not contain the only_if_contains from named params # remove params that do not contain the only_if_contains from named params
for key in list(named_params.keys()): for key in list(named_params.keys()):
if not any([s in key for s in self.model_config.only_if_contains]): if not any([s in f"transformer.{key}" for s in self.model_config.only_if_contains]):
del named_params[key] del named_params[key]
if refiner: if refiner:

View File

@@ -1,7 +1,6 @@
import math import math
import weakref import weakref
from toolkit.config_modules import AdapterConfig
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import TYPE_CHECKING, List, Dict, Any from typing import TYPE_CHECKING, List, Dict, Any
@@ -35,7 +34,6 @@ class MLP(nn.Module):
x = x + residual x = x + residual
return x return x
class LoRAGenerator(torch.nn.Module): class LoRAGenerator(torch.nn.Module):
def __init__( def __init__(
self, self,
@@ -60,8 +58,7 @@ class LoRAGenerator(torch.nn.Module):
self.lin_in = nn.Linear(input_size, hidden_size) self.lin_in = nn.Linear(input_size, hidden_size)
self.mlp_blocks = nn.Sequential(*[ self.mlp_blocks = nn.Sequential(*[
MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
range(num_mlp_layers)
]) ])
self.head = nn.Linear(hidden_size, head_size, bias=False) self.head = nn.Linear(hidden_size, head_size, bias=False)
self.norm = nn.LayerNorm(head_size) self.norm = nn.LayerNorm(head_size)
@@ -128,22 +125,15 @@ class InstantLoRAMidModule(torch.nn.Module):
self.lora_module_ref = weakref.ref(lora_module) self.lora_module_ref = weakref.ref(lora_module)
self.instant_lora_module_ref = weakref.ref(instant_lora_module) self.instant_lora_module_ref = weakref.ref(instant_lora_module)
self.do_up = instant_lora_module.config.ilora_up
self.do_down = instant_lora_module.config.ilora_down
self.do_mid = instant_lora_module.config.ilora_mid
self.down_dim = self.down_shape[1] if self.do_down else 0
self.mid_dim = self.up_shape[1] if self.do_mid else 0
self.out_dim = self.up_shape[0] if self.do_up else 0
self.embed = None self.embed = None
def down_forward(self, x, *args, **kwargs): def down_forward(self, x, *args, **kwargs):
if not self.do_down:
return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
# get the embed # get the embed
self.embed = self.instant_lora_module_ref().img_embeds[self.index] self.embed = self.instant_lora_module_ref().img_embeds[self.index]
down_weight = self.embed[:, :self.down_dim] if x.dtype != self.embed.dtype:
x = x.to(self.embed.dtype)
down_size = math.prod(self.down_shape)
down_weight = self.embed[:, :down_size]
batch_size = x.shape[0] batch_size = x.shape[0]
@@ -151,72 +141,7 @@ class InstantLoRAMidModule(torch.nn.Module):
if down_weight.shape[0] * 2 == batch_size: if down_weight.shape[0] * 2 == batch_size:
down_weight = torch.cat([down_weight] * 2, dim=0) down_weight = torch.cat([down_weight] * 2, dim=0)
try: weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
if len(x.shape) == 4:
# conv
down_weight = down_weight.view(batch_size, -1, 1, 1)
if x.shape[1] != down_weight.shape[1]:
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
elif len(x.shape) == 2:
down_weight = down_weight.view(batch_size, -1)
if x.shape[1] != down_weight.shape[1]:
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
else:
down_weight = down_weight.view(batch_size, 1, -1)
if x.shape[2] != down_weight.shape[2]:
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
x = x * down_weight
x = self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
except Exception as e:
print(e)
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
return x
def up_forward(self, x, *args, **kwargs):
# do mid here
x = self.mid_forward(x, *args, **kwargs)
if not self.do_up:
return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
# get the embed
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
up_weight = self.embed[:, -self.out_dim:]
batch_size = x.shape[0]
# unconditional
if up_weight.shape[0] * 2 == batch_size:
up_weight = torch.cat([up_weight] * 2, dim=0)
try:
if len(x.shape) == 4:
# conv
up_weight = up_weight.view(batch_size, -1, 1, 1)
elif len(x.shape) == 2:
up_weight = up_weight.view(batch_size, -1)
else:
up_weight = up_weight.view(batch_size, 1, -1)
x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
x = x * up_weight
except Exception as e:
print(e)
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
return x
def mid_forward(self, x, *args, **kwargs):
if not self.do_mid:
return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
batch_size = x.shape[0]
# get the embed
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
mid_weight = self.embed[:, self.down_dim:self.down_dim + self.mid_dim * self.mid_dim]
# unconditional
if mid_weight.shape[0] * 2 == batch_size:
mid_weight = torch.cat([mid_weight] * 2, dim=0)
weight_chunks = torch.chunk(mid_weight, batch_size, dim=0)
x_chunks = torch.chunk(x, batch_size, dim=0) x_chunks = torch.chunk(x, batch_size, dim=0)
x_out = [] x_out = []
@@ -224,11 +149,43 @@ class InstantLoRAMidModule(torch.nn.Module):
weight_chunk = weight_chunks[i] weight_chunk = weight_chunks[i]
x_chunk = x_chunks[i] x_chunk = x_chunks[i]
# reshape # reshape
if len(x_chunk.shape) == 4: weight_chunk = weight_chunk.view(self.down_shape)
# conv # check if is conv or linear
weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim, 1, 1) if len(weight_chunk.shape) == 4:
org_module = self.lora_module_ref().orig_module_ref()
stride = org_module.stride
padding = org_module.padding
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride)
else: else:
weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim) # run a simple linear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
x_out.append(x_chunk)
x = torch.cat(x_out, dim=0)
return x
def up_forward(self, x, *args, **kwargs):
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
if x.dtype != self.embed.dtype:
x = x.to(self.embed.dtype)
up_size = math.prod(self.up_shape)
up_weight = self.embed[:, -up_size:]
batch_size = x.shape[0]
# unconditional
if up_weight.shape[0] * 2 == batch_size:
up_weight = torch.cat([up_weight] * 2, dim=0)
weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
x_chunks = torch.chunk(x, batch_size, dim=0)
x_out = []
for i in range(batch_size):
weight_chunk = weight_chunks[i]
x_chunk = x_chunks[i]
# reshape
weight_chunk = weight_chunk.view(self.up_shape)
# check if is conv or linear # check if is conv or linear
if len(weight_chunk.shape) == 4: if len(weight_chunk.shape) == 4:
padding = 0 padding = 0
@@ -243,15 +200,17 @@ class InstantLoRAMidModule(torch.nn.Module):
return x return x
class InstantLoRAModule(torch.nn.Module): class InstantLoRAModule(torch.nn.Module):
def __init__( def __init__(
self, self,
vision_hidden_size: int, vision_hidden_size: int,
vision_tokens: int, vision_tokens: int,
head_dim: int, head_dim: int,
num_heads: int, # number of heads in the resampler num_heads: int, # number of heads in the resampler
sd: 'StableDiffusion', sd: 'StableDiffusion',
config: AdapterConfig config=None
): ):
super(InstantLoRAModule, self).__init__() super(InstantLoRAModule, self).__init__()
# self.linear = torch.nn.Linear(2, 1) # self.linear = torch.nn.Linear(2, 1)
@@ -262,8 +221,6 @@ class InstantLoRAModule(torch.nn.Module):
self.head_dim = head_dim self.head_dim = head_dim
self.num_heads = num_heads self.num_heads = num_heads
self.config: AdapterConfig = config
# stores the projection vector. Grabbed by modules # stores the projection vector. Grabbed by modules
self.img_embeds: List[torch.Tensor] = None self.img_embeds: List[torch.Tensor] = None
@@ -286,21 +243,11 @@ class InstantLoRAModule(torch.nn.Module):
self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]]) self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
# module_size = math.prod(down_shape) + math.prod(up_shape)
# module_size = math.prod(down_shape) + math.prod(up_shape)
# conv weight shape is (out_channels, in_channels, kernel_size, kernel_size)
# linear weight shape is (out_features, in_features)
# just doing in dim and out dim
in_dim = down_shape[1] if self.config.ilora_down else 0
mid_dim = down_shape[0] * down_shape[0] if self.config.ilora_mid else 0
out_dim = up_shape[0] if self.config.ilora_up else 0
module_size = in_dim + mid_dim + out_dim
output_size += module_size output_size += module_size
self.embed_lengths.append(module_size) self.embed_lengths.append(module_size)
# add a new mid module that will take the original forward and add a vector to it # add a new mid module that will take the original forward and add a vector to it
# this will be used to add the vector to the original forward # this will be used to add the vector to the original forward
instant_module = InstantLoRAMidModule( instant_module = InstantLoRAMidModule(
@@ -314,11 +261,10 @@ class InstantLoRAModule(torch.nn.Module):
self.ilora_modules.append(instant_module) self.ilora_modules.append(instant_module)
# replace the LoRA forwards # replace the LoRA forwards
lora_module.lora_down.orig_forward = lora_module.lora_down.forward
lora_module.lora_down.forward = instant_module.down_forward lora_module.lora_down.forward = instant_module.down_forward
lora_module.lora_up.orig_forward = lora_module.lora_up.forward
lora_module.lora_up.forward = instant_module.up_forward lora_module.lora_up.forward = instant_module.up_forward
self.output_size = output_size self.output_size = output_size
number_formatted_output_size = "{:,}".format(output_size) number_formatted_output_size = "{:,}".format(output_size)
@@ -378,6 +324,7 @@ class InstantLoRAModule(torch.nn.Module):
# print("No keymap found. Using default names") # print("No keymap found. Using default names")
# return # return
def forward(self, img_embeds): def forward(self, img_embeds):
# expand token rank if only rank 2 # expand token rank if only rank 2
if len(img_embeds.shape) == 2: if len(img_embeds.shape) == 2:
@@ -394,9 +341,10 @@ class InstantLoRAModule(torch.nn.Module):
# get all the slices # get all the slices
start = 0 start = 0
for length in self.embed_lengths: for length in self.embed_lengths:
self.img_embeds.append(img_embeds[:, start:start + length]) self.img_embeds.append(img_embeds[:, start:start+length])
start += length start += length
def get_additional_save_metadata(self) -> Dict[str, Any]: def get_additional_save_metadata(self) -> Dict[str, Any]:
# save the weight mapping # save the weight mapping
return { return {
@@ -406,7 +354,5 @@ class InstantLoRAModule(torch.nn.Module):
"head_dim": self.head_dim, "head_dim": self.head_dim,
"vision_tokens": self.vision_tokens, "vision_tokens": self.vision_tokens,
"output_size": self.output_size, "output_size": self.output_size,
"do_up": self.config.ilora_up,
"do_mid": self.config.ilora_mid,
"do_down": self.config.ilora_down,
} }

View File

@@ -65,8 +65,8 @@ class LLMAdapter(torch.nn.Module):
if tokenizer.pad_token is None: if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
# self.system_prompt = "" self.system_prompt = ""
self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> " # self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
# determine length of system prompt # determine length of system prompt
sys_prompt_tokenized = tokenizer( sys_prompt_tokenized = tokenizer(

View File

@@ -1401,8 +1401,7 @@ class StableDiffusion:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
self.adapter(conditional_clip_embeds) self.adapter(conditional_clip_embeds)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \ if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
and gen_config.adapter_image_path is not None:
# handle condition the prompts # handle condition the prompts
gen_config.prompt = self.adapter.condition_prompt( gen_config.prompt = self.adapter.condition_prompt(
gen_config.prompt, gen_config.prompt,
@@ -1456,7 +1455,7 @@ class StableDiffusion:
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False) conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
conditional_embeds = self.adapter.condition_encoded_embeds( conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=validation_image, tensors_0_1=validation_image,
prompt_embeds=conditional_embeds, prompt_embeds=conditional_embeds,

View File

@@ -0,0 +1,84 @@
import torch
cached_multipier = None
def get_multiplier(timesteps, num_timesteps=1000):
global cached_multipier
if cached_multipier is None:
# creates a bell curve
x = torch.arange(num_timesteps, dtype=torch.float32)
y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)
# Shift minimum to 0
y_shifted = y - y.min()
# Scale to make mean 1
cached_multipier = y_shifted * (num_timesteps / y_shifted.sum())
scale_list = []
# get the idx multiplier for each timestep
for i in range(timesteps.shape[0]):
idx = min(int(timesteps[i].item()) - 1, 0)
scale_list.append(cached_multipier[idx:idx + 1])
scales = torch.cat(scale_list, dim=0)
batch_multiplier = scales.view(-1, 1, 1, 1)
return batch_multiplier
def get_blended_blur_noise(latents, noise, timestep):
latent_chunks = torch.chunk(latents, latents.shape[0], dim=0)
# timestep is 1000 to 0
# timestep = timestep.to(latents.device, dtype=latents.dtype)
# scale it so timestep 1000 is 0 and 0 is 2
# blur_strength = value_map(timestep, 1000, 0, 0, 1.0)
# blur_strength = timestep / 500.0
# blur_strength = blur_strength.view(-1, 1, 1, 1)
# scale to 2.0 max
# blur_strength = get_multiplier(timestep).to(
# latents.device, dtype=latents.dtype
# ) * 2.0
# blur_strength = 2.0
blurred_latent_chunks = []
for i in range(len(latent_chunks)):
latent_chunk = latent_chunks[i]
# get two random scalers 0.1 to 0.9
# scaler1 = random.uniform(0.2, 0.8)
scaler1 = 0.25
scaler2 = scaler1
# shrink latents by 1/4 and bring them back for blurring using interpolation
blur_latents = torch.nn.functional.interpolate(
latent_chunk,
size=(int(latents.shape[2] * scaler1), int(latents.shape[3] * scaler2)),
mode='bilinear',
align_corners=False
)
blur_latents = torch.nn.functional.interpolate(
blur_latents,
size=(latents.shape[2], latents.shape[3]),
mode='bilinear',
align_corners=False
)
# only the difference of the blur from ground truth
blur_latents = blur_latents - latent_chunk
blurred_latent_chunks.append(blur_latents)
blur_latents = torch.cat(blurred_latent_chunks, dim=0)
# make random strength along batch 0 to 1
blur_strength = torch.rand((latents.shape[0], 1, 1, 1), device=latents.device, dtype=latents.dtype) * 2
blur_latents = blur_latents * blur_strength
noise = noise + blur_latents
return noise

View File

@@ -42,6 +42,9 @@ def shuffle_tensor_along_axis(tensor, axis=0, seed=None):
# Apply the shuffle # Apply the shuffle
shuffled_tensor = tensor[slices] shuffled_tensor = tensor[slices]
except Exception as e:
raise RuntimeError(f"Error during shuffling: {e}")
finally: finally:
# Restore original random states # Restore original random states
torch.set_rng_state(torch_state) torch.set_rng_state(torch_state)