mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 01:39:20 +00:00
Various experiments and minor bug fixes for edge cases
This commit is contained in:
@@ -1160,12 +1160,12 @@ class BaseModel:
|
||||
if self.model_config.ignore_if_contains is not None:
|
||||
# remove params that contain the ignore_if_contains from named params
|
||||
for key in list(named_params.keys()):
|
||||
if any([s in key for s in self.model_config.ignore_if_contains]):
|
||||
if any([s in f"transformer.{key}" for s in self.model_config.ignore_if_contains]):
|
||||
del named_params[key]
|
||||
if self.model_config.only_if_contains is not None:
|
||||
# remove params that do not contain the only_if_contains from named params
|
||||
for key in list(named_params.keys()):
|
||||
if not any([s in key for s in self.model_config.only_if_contains]):
|
||||
if not any([s in f"transformer.{key}" for s in self.model_config.only_if_contains]):
|
||||
del named_params[key]
|
||||
|
||||
if refiner:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import math
|
||||
import weakref
|
||||
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, List, Dict, Any
|
||||
@@ -35,7 +34,6 @@ class MLP(nn.Module):
|
||||
x = x + residual
|
||||
return x
|
||||
|
||||
|
||||
class LoRAGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -60,8 +58,7 @@ class LoRAGenerator(torch.nn.Module):
|
||||
self.lin_in = nn.Linear(input_size, hidden_size)
|
||||
|
||||
self.mlp_blocks = nn.Sequential(*[
|
||||
MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in
|
||||
range(num_mlp_layers)
|
||||
MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
|
||||
])
|
||||
self.head = nn.Linear(hidden_size, head_size, bias=False)
|
||||
self.norm = nn.LayerNorm(head_size)
|
||||
@@ -128,22 +125,15 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
self.lora_module_ref = weakref.ref(lora_module)
|
||||
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
||||
|
||||
self.do_up = instant_lora_module.config.ilora_up
|
||||
self.do_down = instant_lora_module.config.ilora_down
|
||||
self.do_mid = instant_lora_module.config.ilora_mid
|
||||
|
||||
self.down_dim = self.down_shape[1] if self.do_down else 0
|
||||
self.mid_dim = self.up_shape[1] if self.do_mid else 0
|
||||
self.out_dim = self.up_shape[0] if self.do_up else 0
|
||||
|
||||
self.embed = None
|
||||
|
||||
def down_forward(self, x, *args, **kwargs):
|
||||
if not self.do_down:
|
||||
return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
down_weight = self.embed[:, :self.down_dim]
|
||||
if x.dtype != self.embed.dtype:
|
||||
x = x.to(self.embed.dtype)
|
||||
down_size = math.prod(self.down_shape)
|
||||
down_weight = self.embed[:, :down_size]
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
@@ -151,72 +141,7 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
if down_weight.shape[0] * 2 == batch_size:
|
||||
down_weight = torch.cat([down_weight] * 2, dim=0)
|
||||
|
||||
try:
|
||||
if len(x.shape) == 4:
|
||||
# conv
|
||||
down_weight = down_weight.view(batch_size, -1, 1, 1)
|
||||
if x.shape[1] != down_weight.shape[1]:
|
||||
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
|
||||
elif len(x.shape) == 2:
|
||||
down_weight = down_weight.view(batch_size, -1)
|
||||
if x.shape[1] != down_weight.shape[1]:
|
||||
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
|
||||
else:
|
||||
down_weight = down_weight.view(batch_size, 1, -1)
|
||||
if x.shape[2] != down_weight.shape[2]:
|
||||
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
|
||||
x = x * down_weight
|
||||
x = self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
def up_forward(self, x, *args, **kwargs):
|
||||
# do mid here
|
||||
x = self.mid_forward(x, *args, **kwargs)
|
||||
if not self.do_up:
|
||||
return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
up_weight = self.embed[:, -self.out_dim:]
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# unconditional
|
||||
if up_weight.shape[0] * 2 == batch_size:
|
||||
up_weight = torch.cat([up_weight] * 2, dim=0)
|
||||
|
||||
try:
|
||||
if len(x.shape) == 4:
|
||||
# conv
|
||||
up_weight = up_weight.view(batch_size, -1, 1, 1)
|
||||
elif len(x.shape) == 2:
|
||||
up_weight = up_weight.view(batch_size, -1)
|
||||
else:
|
||||
up_weight = up_weight.view(batch_size, 1, -1)
|
||||
x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
|
||||
x = x * up_weight
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
def mid_forward(self, x, *args, **kwargs):
|
||||
if not self.do_mid:
|
||||
return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
|
||||
batch_size = x.shape[0]
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
mid_weight = self.embed[:, self.down_dim:self.down_dim + self.mid_dim * self.mid_dim]
|
||||
|
||||
# unconditional
|
||||
if mid_weight.shape[0] * 2 == batch_size:
|
||||
mid_weight = torch.cat([mid_weight] * 2, dim=0)
|
||||
|
||||
weight_chunks = torch.chunk(mid_weight, batch_size, dim=0)
|
||||
weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
|
||||
x_chunks = torch.chunk(x, batch_size, dim=0)
|
||||
|
||||
x_out = []
|
||||
@@ -224,11 +149,43 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
weight_chunk = weight_chunks[i]
|
||||
x_chunk = x_chunks[i]
|
||||
# reshape
|
||||
if len(x_chunk.shape) == 4:
|
||||
# conv
|
||||
weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim, 1, 1)
|
||||
weight_chunk = weight_chunk.view(self.down_shape)
|
||||
# check if is conv or linear
|
||||
if len(weight_chunk.shape) == 4:
|
||||
org_module = self.lora_module_ref().orig_module_ref()
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride)
|
||||
else:
|
||||
weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim)
|
||||
# run a simple linear layer with the down weight
|
||||
x_chunk = x_chunk @ weight_chunk.T
|
||||
x_out.append(x_chunk)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
return x
|
||||
|
||||
|
||||
def up_forward(self, x, *args, **kwargs):
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
if x.dtype != self.embed.dtype:
|
||||
x = x.to(self.embed.dtype)
|
||||
up_size = math.prod(self.up_shape)
|
||||
up_weight = self.embed[:, -up_size:]
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# unconditional
|
||||
if up_weight.shape[0] * 2 == batch_size:
|
||||
up_weight = torch.cat([up_weight] * 2, dim=0)
|
||||
|
||||
weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
|
||||
x_chunks = torch.chunk(x, batch_size, dim=0)
|
||||
|
||||
x_out = []
|
||||
for i in range(batch_size):
|
||||
weight_chunk = weight_chunks[i]
|
||||
x_chunk = x_chunks[i]
|
||||
# reshape
|
||||
weight_chunk = weight_chunk.view(self.up_shape)
|
||||
# check if is conv or linear
|
||||
if len(weight_chunk.shape) == 4:
|
||||
padding = 0
|
||||
@@ -243,15 +200,17 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
class InstantLoRAModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vision_hidden_size: int,
|
||||
vision_tokens: int,
|
||||
head_dim: int,
|
||||
num_heads: int, # number of heads in the resampler
|
||||
num_heads: int, # number of heads in the resampler
|
||||
sd: 'StableDiffusion',
|
||||
config: AdapterConfig
|
||||
config=None
|
||||
):
|
||||
super(InstantLoRAModule, self).__init__()
|
||||
# self.linear = torch.nn.Linear(2, 1)
|
||||
@@ -262,8 +221,6 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.config: AdapterConfig = config
|
||||
|
||||
# stores the projection vector. Grabbed by modules
|
||||
self.img_embeds: List[torch.Tensor] = None
|
||||
|
||||
@@ -286,21 +243,11 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
|
||||
self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
|
||||
|
||||
#
|
||||
# module_size = math.prod(down_shape) + math.prod(up_shape)
|
||||
|
||||
# conv weight shape is (out_channels, in_channels, kernel_size, kernel_size)
|
||||
# linear weight shape is (out_features, in_features)
|
||||
|
||||
# just doing in dim and out dim
|
||||
in_dim = down_shape[1] if self.config.ilora_down else 0
|
||||
mid_dim = down_shape[0] * down_shape[0] if self.config.ilora_mid else 0
|
||||
out_dim = up_shape[0] if self.config.ilora_up else 0
|
||||
module_size = in_dim + mid_dim + out_dim
|
||||
|
||||
module_size = math.prod(down_shape) + math.prod(up_shape)
|
||||
output_size += module_size
|
||||
self.embed_lengths.append(module_size)
|
||||
|
||||
|
||||
# add a new mid module that will take the original forward and add a vector to it
|
||||
# this will be used to add the vector to the original forward
|
||||
instant_module = InstantLoRAMidModule(
|
||||
@@ -314,11 +261,10 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
self.ilora_modules.append(instant_module)
|
||||
|
||||
# replace the LoRA forwards
|
||||
lora_module.lora_down.orig_forward = lora_module.lora_down.forward
|
||||
lora_module.lora_down.forward = instant_module.down_forward
|
||||
lora_module.lora_up.orig_forward = lora_module.lora_up.forward
|
||||
lora_module.lora_up.forward = instant_module.up_forward
|
||||
|
||||
|
||||
self.output_size = output_size
|
||||
|
||||
number_formatted_output_size = "{:,}".format(output_size)
|
||||
@@ -378,6 +324,7 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
# print("No keymap found. Using default names")
|
||||
# return
|
||||
|
||||
|
||||
def forward(self, img_embeds):
|
||||
# expand token rank if only rank 2
|
||||
if len(img_embeds.shape) == 2:
|
||||
@@ -394,9 +341,10 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
# get all the slices
|
||||
start = 0
|
||||
for length in self.embed_lengths:
|
||||
self.img_embeds.append(img_embeds[:, start:start + length])
|
||||
self.img_embeds.append(img_embeds[:, start:start+length])
|
||||
start += length
|
||||
|
||||
|
||||
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||
# save the weight mapping
|
||||
return {
|
||||
@@ -406,7 +354,5 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
"head_dim": self.head_dim,
|
||||
"vision_tokens": self.vision_tokens,
|
||||
"output_size": self.output_size,
|
||||
"do_up": self.config.ilora_up,
|
||||
"do_mid": self.config.ilora_mid,
|
||||
"do_down": self.config.ilora_down,
|
||||
}
|
||||
|
||||
|
||||
@@ -65,8 +65,8 @@ class LLMAdapter(torch.nn.Module):
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# self.system_prompt = ""
|
||||
self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
|
||||
self.system_prompt = ""
|
||||
# self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
|
||||
|
||||
# determine length of system prompt
|
||||
sys_prompt_tokenized = tokenizer(
|
||||
|
||||
Reference in New Issue
Block a user