Bug fixes

Jaret Burkett
2024-04-16 03:48:13 -06:00
parent 7284aab7c0
commit 2d0a1be59d
2 changed files with 16 additions and 6 deletions


@@ -80,13 +80,16 @@ class MLPProjModelClipFace(torch.nn.Module):
 class CustomIPAttentionProcessor(IPAttnProcessor2_0):
-    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False):
+    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, full_token_scaler=False):
         super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.train_scaler = train_scaler
         if train_scaler:
-            # self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.9999)
-            self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
+            if full_token_scaler:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
+            else:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
+                # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
             self.ip_scaler.requires_grad_(True)

     def __call__(
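
The practical difference between the two parameter shapes is only how they broadcast over the IP tokens. Below is a minimal, hypothetical sketch of that; the tensor names, shapes, and the point where the scaler is applied inside __call__ are assumptions for illustration, not taken from this commit.

import torch

# Assumed shape for the image-prompt token states an IP-adapter attention
# processor would typically handle: [batch, num_tokens, dim].
batch, num_tokens, dim = 2, 4, 768
ip_hidden_states = torch.randn(batch, num_tokens, dim)

scalar_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)              # one shared weight
per_token_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)  # one weight per IP token

# Both broadcast cleanly along the token axis:
scaled_shared = ip_hidden_states * scalar_scaler                        # same factor for every token
scaled_per_token = ip_hidden_states * per_token_scaler.view(1, -1, 1)   # each token scaled independently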
@@ -514,7 +517,9 @@ class IPAdapter(torch.nn.Module):
                 scale=1.0,
                 num_tokens=self.config.num_tokens,
                 adapter=self,
-                train_scaler=self.config.train_scaler or self.config.merge_scaler
+                train_scaler=self.config.train_scaler or self.config.merge_scaler,
+                # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler
+                full_token_scaler=False
             )
             if self.sd_ref().is_pixart:
                 # pixart is much more sensitive
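
The commented-out line records the reasoning behind passing full_token_scaler=False here: a single scalar can later be folded ("merged") into the IP key/value projection weights, while a per-token vector cannot, so the full token scaler is only useful when the scaler itself is kept as a trained parameter. A hedged sketch of that argument, assuming the scaler multiplies the IP image tokens before a linear projection such as to_k_ip (the layer name and shapes are illustrative, not confirmed by this commit):

import torch

torch.manual_seed(0)
num_tokens, dim = 4, 8
to_k_ip = torch.nn.Linear(dim, dim, bias=False)   # stand-in for an IP key projection (assumed name)
tokens = torch.randn(1, num_tokens, dim)

# Scalar case: scaling the inputs is identical to scaling the weight matrix,
# so the scalar can be merged into the projection and dropped at inference.
s = 0.999
merged = torch.nn.Linear(dim, dim, bias=False)
merged.weight.data = to_k_ip.weight.data * s
assert torch.allclose(to_k_ip(tokens * s), merged(tokens), atol=1e-6)

# Per-token case: each token gets its own factor, so no single rescaled weight
# matrix reproduces the effect; the extra parameter would have to be kept.
v = torch.linspace(0.9, 1.0, num_tokens).view(1, -1, 1)
per_token_out = to_k_ip(tokens * v)   # differs per token; cannot be folded into one weight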