Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-28 10:11:14 +00:00)
Bug fixes
@@ -80,13 +80,16 @@ class MLPProjModelClipFace(torch.nn.Module):
 
 
 class CustomIPAttentionProcessor(IPAttnProcessor2_0):
-    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False):
+    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, full_token_scaler=False):
         super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.train_scaler = train_scaler
         if train_scaler:
-            # self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.9999)
-            self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
+            if full_token_scaler:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
+            else:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
+            # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
             self.ip_scaler.requires_grad_(True)
 
     def __call__(
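The only structural change in this hunk is the shape of the learnable scale: with full_token_scaler=True the processor learns one scale per IP token, otherwise a single scalar shared by all tokens, both initialized just below 1.0 so training starts near identity. The body of __call__ is not part of this diff, so how ip_scaler is actually applied there is not visible; the sketch below only illustrates the two parameter shapes and the broadcast each would need, with the multiply itself as an assumption rather than the repo's code.

import torch

num_tokens, batch, dim = 4, 2, 8

# full_token_scaler=True: one learnable scale per IP token, shape (num_tokens,).
per_token = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
# full_token_scaler=False: a single scalar shared by all tokens, shape (1,).
scalar = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)

# Hypothetical IP hidden states: (batch, num_tokens, dim).
ip_hidden_states = torch.randn(batch, num_tokens, dim)

# The per-token scale must be reshaped to broadcast over batch and channels;
# the scalar broadcasts as-is. (Assumed application, not shown in the hunk.)
scaled_per_token = ip_hidden_states * per_token.view(1, -1, 1)
scaled_scalar = ip_hidden_states * scalar

print(scaled_per_token.shape, scaled_scalar.shape)  # both torch.Size([2, 4, 8])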
@@ -514,7 +517,9 @@ class IPAdapter(torch.nn.Module):
             scale=1.0,
             num_tokens=self.config.num_tokens,
             adapter=self,
-            train_scaler=self.config.train_scaler or self.config.merge_scaler
+            train_scaler=self.config.train_scaler or self.config.merge_scaler,
+            # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler
+            full_token_scaler=False
         )
         if self.sd_ref().is_pixart:
             # pixart is much more sensitive
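The commented-out line explains why full_token_scaler is pinned to False here: a full-token scaler "cannot be merged in", so it is only useful when training an actual scaler. The merge code itself is not part of this diff, so the following is an assumption about what merging means: folding the learned scale into a token-shared projection weight at export time. Under that assumption, a scalar commutes with the projection exactly, while a per-token vector does not, because it acts on the token axis and the weight acts on the channel axis.

import torch

torch.manual_seed(0)
num_tokens, dim = 4, 8
W = torch.randn(dim, dim)          # shared projection weight, applied to every token
x = torch.randn(num_tokens, dim)   # hypothetical IP tokens

# Scalar case: s * (x @ W.T) == x @ (s * W).T for every input x,
# so s can be baked into W once and the extra parameter dropped.
s = 0.999
assert torch.allclose(s * (x @ W.T), x @ (s * W).T)

# Per-token case: t scales rows (tokens), W mixes columns (channels).
# No single rescaled weight c * W reproduces this for all inputs unless
# every entry of t is identical, hence "full token cannot be merged in".
t = torch.tensor([0.9, 1.0, 1.1, 1.2]).view(-1, 1)
per_token_out = t * (x @ W.T)
assert not torch.allclose(per_token_out, x @ (s * W).T)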