Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Commit: Various experiments and minor bug fixes for edge cases

@@ -69,6 +69,7 @@ import transformers
 import diffusers
 import hashlib
 
+from toolkit.util.blended_blur_noise import get_blended_blur_noise
 from toolkit.util.get_model import get_model_class
 
 def flush():
@@ -903,7 +904,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return noise
 
 
-    def get_noise(self, latents, batch_size, dtype=torch.float32, batch: 'DataLoaderBatchDTO' = None):
+    def get_noise(
+        self,
+        latents,
+        batch_size,
+        dtype=torch.float32,
+        batch: 'DataLoaderBatchDTO' = None,
+        timestep=None,
+    ):
         if self.train_config.optimal_noise_pairing_samples > 1:
             noise = self.get_optimal_noise(latents, dtype=dtype)
         elif self.train_config.force_consistent_noise:
@@ -933,12 +941,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         # add to noise
         noise += noise_shift
 
-        # standardize the noise
-        # shouldnt be needed?
-        # std = noise.std(dim=(2, 3), keepdim=True)
-        # normalizer = 1 / (std + 1e-6)
-        # noise = noise * normalizer
+        if self.train_config.blended_blur_noise:
+            noise = get_blended_blur_noise(
+                latents, noise, timestep
+            )
 
         return noise
 
@@ -1193,7 +1200,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         timesteps = torch.stack(timesteps, dim=0)
 
         # get noise
-        noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch)
+        noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)
 
         # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
         # this will negate any noise offsets
@@ -1924,10 +1931,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         start_step_num = self.step_num
         did_first_flush = False
+        flush_next = False
         for step in range(start_step_num, self.train_config.steps):
             if self.train_config.do_paramiter_swapping:
                 self.optimizer.optimizer.swap_paramiters()
             self.timer.start('train_loop')
+            if flush_next:
+                flush()
+                flush_next = False
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
                 self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
@@ -2089,6 +2100,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 print_acc(f"\nSaving at step {self.step_num}")
                 self.save(self.step_num)
                 self.ensure_params_requires_grad()
+                # clear any grads
+                optimizer.zero_grad()
+                flush()
+                flush_next = True
                 if self.progress_bar is not None:
                     self.progress_bar.unpause()
 
@@ -460,6 +460,7 @@ class TrainConfig:
 
         # forces same noise for the same image at a given size.
         self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
+        self.blended_blur_noise = kwargs.get('blended_blur_noise', False)
 
 
 ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
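
The new blended_blur_noise flag is read with kwargs.get(...) like the other TrainConfig options, so it is toggled from whatever dict or YAML builds the train config. A minimal sketch of switching it on (the import path and the other keyword values are assumptions for illustration, not part of this commit):

# sketch only: TrainConfig appears to accept plain kwargs, judging by the kwargs.get calls above
from toolkit.config_modules import TrainConfig  # module path assumed

train_config = TrainConfig(
    steps=2000,                   # placeholder value
    force_consistent_noise=False,
    blended_blur_noise=True,      # new flag from this commit, defaults to False
)
print(train_config.blended_blur_noise)  # True
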
@@ -1160,12 +1160,12 @@ class BaseModel:
         if self.model_config.ignore_if_contains is not None:
             # remove params that contain the ignore_if_contains from named params
             for key in list(named_params.keys()):
-                if any([s in key for s in self.model_config.ignore_if_contains]):
+                if any([s in f"transformer.{key}" for s in self.model_config.ignore_if_contains]):
                     del named_params[key]
         if self.model_config.only_if_contains is not None:
             # remove params that do not contain the only_if_contains from named params
             for key in list(named_params.keys()):
-                if not any([s in key for s in self.model_config.only_if_contains]):
+                if not any([s in f"transformer.{key}" for s in self.model_config.only_if_contains]):
                     del named_params[key]
 
         if refiner:
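
Note on the ignore_if_contains / only_if_contains change above: the substring test now runs against the key prefixed with "transformer.", so filter patterns written against full checkpoint-style names still match the transformer-relative keys in named_params. A small standalone sketch of the difference (the keys and patterns are made up for illustration):

# illustration only: old vs. new matching behaviour with hypothetical keys
named_params = {
    "single_blocks.0.attn.to_q.weight": None,
    "double_blocks.3.mlp.fc1.weight": None,
}
only_if_contains = ["transformer.single_blocks"]

old_keep = [k for k in named_params if any(s in k for s in only_if_contains)]
new_keep = [k for k in named_params if any(s in f"transformer.{k}" for s in only_if_contains)]

print(old_keep)  # [] -- the pattern never matched the relative key
print(new_keep)  # ['single_blocks.0.attn.to_q.weight']
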
@@ -1,7 +1,6 @@
 import math
 import weakref
 
-from toolkit.config_modules import AdapterConfig
 import torch
 import torch.nn as nn
 from typing import TYPE_CHECKING, List, Dict, Any
@@ -35,7 +34,6 @@ class MLP(nn.Module):
         x = x + residual
         return x
 
-
 class LoRAGenerator(torch.nn.Module):
     def __init__(
         self,
@@ -60,8 +58,7 @@ class LoRAGenerator(torch.nn.Module):
         self.lin_in = nn.Linear(input_size, hidden_size)
 
         self.mlp_blocks = nn.Sequential(*[
-            MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in
-            range(num_mlp_layers)
+            MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
         ])
         self.head = nn.Linear(hidden_size, head_size, bias=False)
         self.norm = nn.LayerNorm(head_size)
@@ -128,22 +125,15 @@ class InstantLoRAMidModule(torch.nn.Module):
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
 
-        self.do_up = instant_lora_module.config.ilora_up
-        self.do_down = instant_lora_module.config.ilora_down
-        self.do_mid = instant_lora_module.config.ilora_mid
-
-        self.down_dim = self.down_shape[1] if self.do_down else 0
-        self.mid_dim = self.up_shape[1] if self.do_mid else 0
-        self.out_dim = self.up_shape[0] if self.do_up else 0
-
         self.embed = None
 
     def down_forward(self, x, *args, **kwargs):
-        if not self.do_down:
-            return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
         # get the embed
         self.embed = self.instant_lora_module_ref().img_embeds[self.index]
-        down_weight = self.embed[:, :self.down_dim]
+        if x.dtype != self.embed.dtype:
+            x = x.to(self.embed.dtype)
+        down_size = math.prod(self.down_shape)
+        down_weight = self.embed[:, :down_size]
 
         batch_size = x.shape[0]
 
@@ -151,72 +141,7 @@ class InstantLoRAMidModule(torch.nn.Module):
         if down_weight.shape[0] * 2 == batch_size:
             down_weight = torch.cat([down_weight] * 2, dim=0)
 
-        try:
-            if len(x.shape) == 4:
-                # conv
-                down_weight = down_weight.view(batch_size, -1, 1, 1)
-                if x.shape[1] != down_weight.shape[1]:
-                    raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
-            elif len(x.shape) == 2:
-                down_weight = down_weight.view(batch_size, -1)
-                if x.shape[1] != down_weight.shape[1]:
-                    raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
-            else:
-                down_weight = down_weight.view(batch_size, 1, -1)
-                if x.shape[2] != down_weight.shape[2]:
-                    raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
-            x = x * down_weight
-            x = self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
-        except Exception as e:
-            print(e)
-            raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
-
-        return x
-
-    def up_forward(self, x, *args, **kwargs):
-        # do mid here
-        x = self.mid_forward(x, *args, **kwargs)
-        if not self.do_up:
-            return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
-        # get the embed
-        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
-        up_weight = self.embed[:, -self.out_dim:]
-
-        batch_size = x.shape[0]
-
-        # unconditional
-        if up_weight.shape[0] * 2 == batch_size:
-            up_weight = torch.cat([up_weight] * 2, dim=0)
-
-        try:
-            if len(x.shape) == 4:
-                # conv
-                up_weight = up_weight.view(batch_size, -1, 1, 1)
-            elif len(x.shape) == 2:
-                up_weight = up_weight.view(batch_size, -1)
-            else:
-                up_weight = up_weight.view(batch_size, 1, -1)
-            x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
-            x = x * up_weight
-        except Exception as e:
-            print(e)
-            raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
-
-        return x
-
-    def mid_forward(self, x, *args, **kwargs):
-        if not self.do_mid:
-            return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
-        batch_size = x.shape[0]
-        # get the embed
-        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
-        mid_weight = self.embed[:, self.down_dim:self.down_dim + self.mid_dim * self.mid_dim]
-
-        # unconditional
-        if mid_weight.shape[0] * 2 == batch_size:
-            mid_weight = torch.cat([mid_weight] * 2, dim=0)
-
-        weight_chunks = torch.chunk(mid_weight, batch_size, dim=0)
+        weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
         x_chunks = torch.chunk(x, batch_size, dim=0)
 
         x_out = []
@@ -224,11 +149,43 @@ class InstantLoRAMidModule(torch.nn.Module):
             weight_chunk = weight_chunks[i]
             x_chunk = x_chunks[i]
             # reshape
-            if len(x_chunk.shape) == 4:
-                # conv
-                weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim, 1, 1)
+            weight_chunk = weight_chunk.view(self.down_shape)
+            # check if is conv or linear
+            if len(weight_chunk.shape) == 4:
+                org_module = self.lora_module_ref().orig_module_ref()
+                stride = org_module.stride
+                padding = org_module.padding
+                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride)
             else:
-                weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim)
+                # run a simple linear layer with the down weight
+                x_chunk = x_chunk @ weight_chunk.T
+            x_out.append(x_chunk)
+        x = torch.cat(x_out, dim=0)
+        return x
+
+
+    def up_forward(self, x, *args, **kwargs):
+        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+        if x.dtype != self.embed.dtype:
+            x = x.to(self.embed.dtype)
+        up_size = math.prod(self.up_shape)
+        up_weight = self.embed[:, -up_size:]
+
+        batch_size = x.shape[0]
+
+        # unconditional
+        if up_weight.shape[0] * 2 == batch_size:
+            up_weight = torch.cat([up_weight] * 2, dim=0)
+
+        weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
+        x_chunks = torch.chunk(x, batch_size, dim=0)
+
+        x_out = []
+        for i in range(batch_size):
+            weight_chunk = weight_chunks[i]
+            x_chunk = x_chunks[i]
+            # reshape
+            weight_chunk = weight_chunk.view(self.up_shape)
             # check if is conv or linear
             if len(weight_chunk.shape) == 4:
                 padding = 0
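
The rewritten down_forward and up_forward above share one pattern: slice a per-sample weight out of the predicted embedding, view it to the stored down/up shape, and apply it to each batch element as either a conv2d (reusing the original module's stride and padding) or a plain matmul. A self-contained sketch of the linear case with dummy shapes (names and sizes are illustrative, not the module's real ones):

import torch

# each batch element gets its own (out_features, in_features) weight predicted from an embedding
batch_size, in_features, out_features = 2, 8, 4
x = torch.randn(batch_size, 16, in_features)                        # e.g. token features
flat_weights = torch.randn(batch_size, out_features * in_features)  # slice of the predicted embedding

x_chunks = torch.chunk(x, batch_size, dim=0)
weight_chunks = torch.chunk(flat_weights, batch_size, dim=0)

x_out = []
for x_chunk, w_chunk in zip(x_chunks, weight_chunks):
    w = w_chunk.view(out_features, in_features)  # linear weight shape is (out_features, in_features)
    x_out.append(x_chunk @ w.T)                  # 4-d weights would go through nn.functional.conv2d instead
x = torch.cat(x_out, dim=0)
print(x.shape)  # torch.Size([2, 16, 4])
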
@@ -243,15 +200,17 @@ class InstantLoRAMidModule(torch.nn.Module):
         return x
 
 
+
+
 class InstantLoRAModule(torch.nn.Module):
     def __init__(
         self,
         vision_hidden_size: int,
         vision_tokens: int,
         head_dim: int,
         num_heads: int, # number of heads in the resampler
         sd: 'StableDiffusion',
-        config: AdapterConfig
+        config=None
     ):
         super(InstantLoRAModule, self).__init__()
         # self.linear = torch.nn.Linear(2, 1)
@@ -262,8 +221,6 @@ class InstantLoRAModule(torch.nn.Module):
         self.head_dim = head_dim
         self.num_heads = num_heads
 
-        self.config: AdapterConfig = config
-
         # stores the projection vector. Grabbed by modules
         self.img_embeds: List[torch.Tensor] = None
 
@@ -286,21 +243,11 @@ class InstantLoRAModule(torch.nn.Module):
 
             self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
 
-            #
-            # module_size = math.prod(down_shape) + math.prod(up_shape)
-
-            # conv weight shape is (out_channels, in_channels, kernel_size, kernel_size)
-            # linear weight shape is (out_features, in_features)
-
-            # just doing in dim and out dim
-            in_dim = down_shape[1] if self.config.ilora_down else 0
-            mid_dim = down_shape[0] * down_shape[0] if self.config.ilora_mid else 0
-            out_dim = up_shape[0] if self.config.ilora_up else 0
-            module_size = in_dim + mid_dim + out_dim
+            module_size = math.prod(down_shape) + math.prod(up_shape)
 
             output_size += module_size
             self.embed_lengths.append(module_size)
 
 
             # add a new mid module that will take the original forward and add a vector to it
             # this will be used to add the vector to the original forward
             instant_module = InstantLoRAMidModule(
@@ -314,11 +261,10 @@ class InstantLoRAModule(torch.nn.Module):
             self.ilora_modules.append(instant_module)
 
             # replace the LoRA forwards
-            lora_module.lora_down.orig_forward = lora_module.lora_down.forward
             lora_module.lora_down.forward = instant_module.down_forward
-            lora_module.lora_up.orig_forward = lora_module.lora_up.forward
            lora_module.lora_up.forward = instant_module.up_forward
 
 
         self.output_size = output_size
 
         number_formatted_output_size = "{:,}".format(output_size)
@@ -378,6 +324,7 @@ class InstantLoRAModule(torch.nn.Module):
         # print("No keymap found. Using default names")
         # return
 
+
     def forward(self, img_embeds):
         # expand token rank if only rank 2
         if len(img_embeds.shape) == 2:
@@ -394,9 +341,10 @@ class InstantLoRAModule(torch.nn.Module):
         # get all the slices
         start = 0
         for length in self.embed_lengths:
-            self.img_embeds.append(img_embeds[:, start:start + length])
+            self.img_embeds.append(img_embeds[:, start:start+length])
             start += length
 
 
     def get_additional_save_metadata(self) -> Dict[str, Any]:
         # save the weight mapping
         return {
@@ -406,7 +354,5 @@ class InstantLoRAModule(torch.nn.Module):
             "head_dim": self.head_dim,
             "vision_tokens": self.vision_tokens,
             "output_size": self.output_size,
-            "do_up": self.config.ilora_up,
-            "do_mid": self.config.ilora_mid,
-            "do_down": self.config.ilora_down,
         }
 
@@ -65,8 +65,8 @@ class LLMAdapter(torch.nn.Module):
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        # self.system_prompt = ""
-        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
+        self.system_prompt = ""
+        # self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
 
         # determine length of system prompt
         sys_prompt_tokenized = tokenizer(
@@ -1401,8 +1401,7 @@ class StableDiffusion:
             conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
             self.adapter(conditional_clip_embeds)
 
-        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \
-                and gen_config.adapter_image_path is not None:
+        if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
             # handle condition the prompts
             gen_config.prompt = self.adapter.condition_prompt(
                 gen_config.prompt,
@@ -1456,7 +1455,7 @@ class StableDiffusion:
             conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
             unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
 
-        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
+        if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
             conditional_embeds = self.adapter.condition_encoded_embeds(
                 tensors_0_1=validation_image,
                 prompt_embeds=conditional_embeds,
toolkit/util/blended_blur_noise.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+import torch
+
+cached_multipier = None
+
+def get_multiplier(timesteps, num_timesteps=1000):
+    global cached_multipier
+    if cached_multipier is None:
+        # creates a bell curve
+        x = torch.arange(num_timesteps, dtype=torch.float32)
+        y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)
+
+        # Shift minimum to 0
+        y_shifted = y - y.min()
+
+        # Scale to make mean 1
+        cached_multipier = y_shifted * (num_timesteps / y_shifted.sum())
+
+    scale_list = []
+    # get the idx multiplier for each timestep
+    for i in range(timesteps.shape[0]):
+        idx = min(int(timesteps[i].item()) - 1, 0)
+        scale_list.append(cached_multipier[idx:idx + 1])
+
+    scales = torch.cat(scale_list, dim=0)
+
+    batch_multiplier = scales.view(-1, 1, 1, 1)
+
+    return batch_multiplier
+
+
+def get_blended_blur_noise(latents, noise, timestep):
+    latent_chunks = torch.chunk(latents, latents.shape[0], dim=0)
+
+    # timestep is 1000 to 0
+    # timestep = timestep.to(latents.device, dtype=latents.dtype)
+
+    # scale it so timestep 1000 is 0 and 0 is 2
+    # blur_strength = value_map(timestep, 1000, 0, 0, 1.0)
+    # blur_strength = timestep / 500.0
+    # blur_strength = blur_strength.view(-1, 1, 1, 1)
+
+    # scale to 2.0 max
+    # blur_strength = get_multiplier(timestep).to(
+    #     latents.device, dtype=latents.dtype
+    # ) * 2.0
+
+    # blur_strength = 2.0
+
+    blurred_latent_chunks = []
+    for i in range(len(latent_chunks)):
+        latent_chunk = latent_chunks[i]
+        # get two random scalers 0.1 to 0.9
+        # scaler1 = random.uniform(0.2, 0.8)
+        scaler1 = 0.25
+        scaler2 = scaler1
+
+        # shrink latents by 1/4 and bring them back for blurring using interpolation
+        blur_latents = torch.nn.functional.interpolate(
+            latent_chunk,
+            size=(int(latents.shape[2] * scaler1), int(latents.shape[3] * scaler2)),
+            mode='bilinear',
+            align_corners=False
+        )
+        blur_latents = torch.nn.functional.interpolate(
+            blur_latents,
+            size=(latents.shape[2], latents.shape[3]),
+            mode='bilinear',
+            align_corners=False
+        )
+        # only the difference of the blur from ground truth
+        blur_latents = blur_latents - latent_chunk
+        blurred_latent_chunks.append(blur_latents)
+
+    blur_latents = torch.cat(blurred_latent_chunks, dim=0)
+
+
+    # make random strength along batch 0 to 1
+    blur_strength = torch.rand((latents.shape[0], 1, 1, 1), device=latents.device, dtype=latents.dtype) * 2
+
+    blur_latents = blur_latents * blur_strength
+
+    noise = noise + blur_latents
+    return noise
+
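
In the new file above, get_multiplier builds a cached bell-curve weighting over timesteps, while get_blended_blur_noise only touches its timestep argument in the commented-out experiments; the active path adds a randomly scaled low-frequency residual (downscaled-then-upscaled latents minus the latents) to the noise. A minimal sketch of calling it with dummy tensors (shapes are arbitrary):

import torch
from toolkit.util.blended_blur_noise import get_blended_blur_noise  # path from the file header above

latents = torch.randn(4, 16, 64, 64)      # (batch, channels, height, width)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (4,))  # currently only read by the commented-out strength code

blended = get_blended_blur_noise(latents, noise, timesteps)
print(blended.shape)  # same shape as noise: torch.Size([4, 16, 64, 64])
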
@@ -41,6 +41,9 @@ def shuffle_tensor_along_axis(tensor, axis=0, seed=None):
 
         # Apply the shuffle
         shuffled_tensor = tensor[slices]
 
+    except Exception as e:
+        raise RuntimeError(f"Error during shuffling: {e}")
+
     finally:
         # Restore original random states