Added specialized scaler training to ip adapters

Jaret Burkett
2024-04-05 08:17:09 -06:00
parent 427847ac4c
commit 7284aab7c0
7 changed files with 182 additions and 29 deletions


@@ -5,6 +5,10 @@ import torch.nn as nn
 from typing import TYPE_CHECKING
 from toolkit.models.clip_fusion import ZipperBlock
 from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
+import sys
+from toolkit.paths import REPOS_ROOT
+sys.path.append(REPOS_ROOT)
+from ipadapter.ip_adapter.resampler import Resampler

 if TYPE_CHECKING:
     from toolkit.lora_special import LoRAModule
@@ -50,7 +54,7 @@ class InstantLoRAMidModule(torch.nn.Module):
             raise e
         # apply tanh to limit values to -1 to 1
         # scaler = torch.tanh(scaler)
-        return x * scaler
+        return x * (scaler + 1.0)


 class InstantLoRAModule(torch.nn.Module):
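
Why the `+ 1.0` offset apparently matters: the predicted scaler starts near zero, so multiplying by `(scaler + 1.0)` lets each mid module begin as a pass-through and learn deviations from identity, instead of zeroing out the LoRA output at the start of training. A minimal sketch of the effect (tensor shapes here are hypothetical, not taken from this repo):

# Illustrative sketch: with the +1.0 offset, a zero-initialized scaler
# leaves the module output unchanged at initialization.
import torch

x = torch.randn(2, 8)          # stand-in for a LoRA module's output
scaler = torch.zeros(2, 8)     # predicted scalers, near zero at init

out_old = x * scaler           # previous behavior: collapses to zero
out_new = x * (scaler + 1.0)   # new behavior: identity at init

assert torch.count_nonzero(out_old) == 0
assert torch.allclose(out_new, x)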
@@ -78,15 +82,30 @@ class InstantLoRAModule(torch.nn.Module):
         lora_modules = self.sd_ref().network.get_all_modules()

         # resample the output so each module gets one token with a size of its dim so we can multiply by that
-        self.resampler = ZipperResampler(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.dim,
-            out_tokens=len(lora_modules),
-            hidden_size=self.vision_hidden_size,
-            hidden_tokens=self.vision_tokens,
-            num_blocks=1,
-        )
+        # self.resampler = ZipperResampler(
+        #     in_size=self.vision_hidden_size,
+        #     in_tokens=self.vision_tokens,
+        #     out_size=self.dim,
+        #     out_tokens=len(lora_modules),
+        #     hidden_size=self.vision_hidden_size,
+        #     hidden_tokens=self.vision_tokens,
+        #     num_blocks=1,
+        # )
+        # heads = 20
+        heads = 12
+        dim = 1280
+        output_dim = self.dim
+        self.resampler = Resampler(
+            dim=dim,
+            depth=4,
+            dim_head=64,
+            heads=heads,
+            num_queries=len(lora_modules),
+            embedding_dim=self.vision_hidden_size,
+            max_seq_len=self.vision_tokens,
+            output_dim=output_dim,
+            ff_mult=4
+        )

         for idx, lora_module in enumerate(lora_modules):
             # add a new mid module that will take the original forward and add a vector to it
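
For context on the swap above: the IP-Adapter Resampler, like the ZipperResampler it replaces, maps the CLIP vision token sequence to one output token per LoRA module, which each mid module then uses as its scaler. A rough usage sketch under assumed values (257 vision tokens of width 1280, 300 LoRA modules, per-module dim 768 are all illustrative, and the forward signature is assumed to match the standard IP-Adapter resampler):

# Hypothetical numbers for illustration; only the keyword arguments
# mirror the diff above.
import torch
from ipadapter.ip_adapter.resampler import Resampler

num_lora_modules = 300       # assumed module count (len(lora_modules))
vision_hidden_size = 1280    # assumed CLIP vision width
vision_tokens = 257          # assumed CLIP vision sequence length
module_dim = 768             # assumed per-module scaler width (self.dim)

resampler = Resampler(
    dim=1280,                        # internal transformer width
    depth=4,
    dim_head=64,
    heads=12,
    num_queries=num_lora_modules,    # one learned query per LoRA module
    embedding_dim=vision_hidden_size,
    max_seq_len=vision_tokens,
    output_dim=module_dim,
    ff_mult=4,
)

img_embeds = torch.randn(1, vision_tokens, vision_hidden_size)
scaler_tokens = resampler(img_embeds)  # expected: (1, num_lora_modules, module_dim)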