Added specialized scaler training to ip adapters
@@ -5,6 +5,10 @@ import torch.nn as nn
 from typing import TYPE_CHECKING
 from toolkit.models.clip_fusion import ZipperBlock
 from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
+import sys
+from toolkit.paths import REPOS_ROOT
+sys.path.append(REPOS_ROOT)
+from ipadapter.ip_adapter.resampler import Resampler

 if TYPE_CHECKING:
     from toolkit.lora_special import LoRAModule
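The new imports make the vendored IP-Adapter checkout importable: REPOS_ROOT (from toolkit.paths) is appended to sys.path so that `ipadapter` resolves as a top-level package. A minimal sketch of the same pattern, with a stand-in path (not the repo's real value):

    import sys

    # Stand-in for REPOS_ROOT from toolkit.paths; the real value points at the
    # directory that holds vendored repo checkouts such as ipadapter/.
    REPOS_ROOT = "/path/to/repos"
    sys.path.append(REPOS_ROOT)

    # With the path in place, the vendored package imports normally:
    # from ipadapter.ip_adapter.resampler import Resampler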
@@ -50,7 +54,7 @@ class InstantLoRAMidModule(torch.nn.Module):
             raise e
         # apply tanh to limit values to -1 to 1
         # scaler = torch.tanh(scaler)
-        return x * scaler
+        return x * (scaler + 1.0)


 class InstantLoRAModule(torch.nn.Module):
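The scaler change above reads as a residual formulation: the predicted scaler is now an offset around identity, so a near-zero scaler leaves the module output untouched instead of zeroing it out at the start of training. A minimal standalone sketch of the difference (shapes are hypothetical, not taken from the repo):

    import torch

    # Hypothetical shapes: a module output and a zero-initialized scaler token.
    x = torch.randn(2, 77, 768)
    scaler = torch.zeros(2, 1, 768)

    # Old behavior: a zero scaler kills the output entirely.
    assert torch.all(x * scaler == 0)

    # New behavior: a zero scaler is the identity, so training starts from the
    # unscaled module output and learns deviations from 1.0.
    assert torch.allclose(x * (scaler + 1.0), x)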
@@ -78,15 +82,30 @@ class InstantLoRAModule(torch.nn.Module):
         lora_modules = self.sd_ref().network.get_all_modules()

         # resample the output so each module gets one token with a size of its dim so we can multiply by that
-        self.resampler = ZipperResampler(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.dim,
-            out_tokens=len(lora_modules),
-            hidden_size=self.vision_hidden_size,
-            hidden_tokens=self.vision_tokens,
-            num_blocks=1,
-        )
+        # self.resampler = ZipperResampler(
+        #     in_size=self.vision_hidden_size,
+        #     in_tokens=self.vision_tokens,
+        #     out_size=self.dim,
+        #     out_tokens=len(lora_modules),
+        #     hidden_size=self.vision_hidden_size,
+        #     hidden_tokens=self.vision_tokens,
+        #     num_blocks=1,
+        # )
+        # heads = 20
+        heads = 12
+        dim = 1280
+        output_dim = self.dim
+        self.resampler = Resampler(
+            dim=dim,
+            depth=4,
+            dim_head=64,
+            heads=heads,
+            num_queries=len(lora_modules),
+            embedding_dim=self.vision_hidden_size,
+            max_seq_len=self.vision_tokens,
+            output_dim=output_dim,
+            ff_mult=4
+        )

         for idx, lora_module in enumerate(lora_modules):
             # add a new mid module that will take the original forward and add a vector to it
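This hunk swaps the ZipperResampler for the IP-Adapter Resampler, configured with `num_queries=len(lora_modules)` so that the image embeddings are resampled into one learned token per LoRA module; each InstantLoRAMidModule can then index its own scaler of width `output_dim`. A rough, self-contained stand-in for that mapping — single-step attention pooling with hypothetical shapes, where the real Resampler uses `depth=4` transformer blocks:

    import torch

    # Assumed sizes, not taken from the repo.
    batch, vision_tokens, vision_hidden_size = 2, 257, 1280
    num_lora_modules, output_dim = 16, 768

    img_embeds = torch.randn(batch, vision_tokens, vision_hidden_size)

    # Learned queries, one per LoRA module; the real Resampler cross-attends
    # these against the image embeddings over several transformer blocks.
    queries = torch.nn.Parameter(torch.randn(num_lora_modules, vision_hidden_size))
    attn = torch.softmax(
        queries @ img_embeds.transpose(1, 2) / vision_hidden_size ** 0.5, dim=-1
    )
    pooled = attn @ img_embeds            # [batch, num_lora_modules, vision_hidden_size]
    proj = torch.nn.Linear(vision_hidden_size, output_dim)
    scaler_tokens = proj(pooled)          # [batch, num_lora_modules, output_dim]
    assert scaler_tokens.shape == (batch, num_lora_modules, output_dim)

    # Each mid module then picks its own token: scaler_i = scaler_tokens[:, i]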