mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-04 12:39:58 +00:00)
Bug fixes
@@ -164,10 +164,10 @@ class SDTrainer(BaseSDTrainProcess):
     timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
     single_step_timestep_schedule = [timesteps_item.squeeze().item()]
     # extract the sigma idx for our midpoint timestep
-    sigmas = train_sigmas[timestep_idx:timestep_idx + 1]
+    sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch)

     end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
-    end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1]
+    end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch)

     # add noise to our target

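Note on the hunk above: train_sigmas is presumably created on the CPU (like a scheduler's sigma table), so slicing it returns a CPU tensor, and using that slice against tensors on the training device later raises a device-mismatch RuntimeError. Appending .to(self.device_torch) moves each slice onto the training device up front. Below is a minimal standalone sketch of the failure mode and the fix; the device pick, shapes, and variable names are illustrative, not taken from the trainer.

import torch

# illustrative device choice; the trainer keeps its device in self.device_torch
device = "cuda" if torch.cuda.is_available() else "cpu"

train_sigmas = torch.linspace(1.0, 0.0, 1000)        # sigma table built on CPU
latents = torch.randn(1, 4, 64, 64, device=device)   # model tensors live on the training device

timestep_idx = 500
sigma_cpu = train_sigmas[timestep_idx:timestep_idx + 1]             # still on CPU
sigma_dev = train_sigmas[timestep_idx:timestep_idx + 1].to(device)  # the fix: move the slice

# noised = latents * sigma_cpu  # device-mismatch RuntimeError whenever device != "cpu"
noised = latents * sigma_dev.reshape(-1, 1, 1, 1)     # broadcasts cleanly on one device
print(noised.shape)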
@@ -352,6 +352,11 @@ class SDTrainer(BaseSDTrainProcess):
     if self.train_config.do_prior_divergence and prior_pred is not None:
         loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)

+    if self.train_config.train_turbo:
+        mask_multiplier = mask_multiplier[:, 3:, :, :]
+        # resize to the size of the loss
+        mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
+
     # multiply by our mask
     loss = loss * mask_multiplier

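The new train_turbo branch assumes mask_multiplier carries the actual mask in channel 3 onward and that it may not match the spatial size of the unreduced loss, so it slices that channel out and nearest-resizes it to the prediction's height and width before the elementwise multiply. A self-contained sketch with made-up shapes (the real shapes come from the trainer's latents and loss):

import torch
import torch.nn.functional as F

# Illustrative shapes only; names mirror the diff but the values are made up.
pred = torch.randn(2, 4, 32, 32)              # prediction; the unreduced loss shares its spatial size
loss = torch.randn(2, 4, 32, 32).abs()        # e.g. an unreduced MSE map
mask_multiplier = torch.rand(2, 4, 64, 64)    # extra channels plus a mask channel, at a larger resolution

# keep only the channels from index 3 onward, as the turbo branch does
mask = mask_multiplier[:, 3:, :, :]           # shape (2, 1, 64, 64)

# nearest-neighbor resize to the loss resolution so the multiply broadcasts
mask = F.interpolate(mask, size=(pred.shape[2], pred.shape[3]), mode='nearest')

masked_loss = loss * mask                     # (2, 4, 32, 32) * (2, 1, 32, 32) broadcasts over channels
print(masked_loss.mean())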
@@ -80,13 +80,16 @@ class MLPProjModelClipFace(torch.nn.Module):


 class CustomIPAttentionProcessor(IPAttnProcessor2_0):
-    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False):
+    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, full_token_scaler=False):
         super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.train_scaler = train_scaler
         if train_scaler:
-            # self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.9999)
-            self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
+            if full_token_scaler:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
+            else:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
+                # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
             self.ip_scaler.requires_grad_(True)

     def __call__(
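Exactly where the processor applies ip_scaler is outside this hunk; the sketch below only illustrates what the new full_token_scaler flag changes, namely the parameter's shape: one learnable scale per image-prompt token versus a single shared scalar, both initialized just below 1.0. ScalerDemo is a hypothetical stand-in, not the repository's attention processor.

import torch

class ScalerDemo(torch.nn.Module):
    def __init__(self, num_tokens=4, full_token_scaler=False):
        super().__init__()
        if full_token_scaler:
            # one learnable scale per image-prompt token
            self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
        else:
            # a single learnable scalar shared by all tokens
            self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)

    def forward(self, ip_tokens):  # ip_tokens: (batch, num_tokens, dim)
        # broadcast the scale over the token axis: (num_tokens,) or (1,) -> (1, -1, 1)
        return ip_tokens * self.ip_scaler.reshape(1, -1, 1)

demo = ScalerDemo(num_tokens=4, full_token_scaler=True)
print(demo(torch.randn(2, 4, 768)).shape)  # torch.Size([2, 4, 768])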
@@ -514,7 +517,9 @@ class IPAdapter(torch.nn.Module):
             scale=1.0,
             num_tokens=self.config.num_tokens,
             adapter=self,
-            train_scaler=self.config.train_scaler or self.config.merge_scaler
+            train_scaler=self.config.train_scaler or self.config.merge_scaler,
+            # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler
+            full_token_scaler=False
         )
         if self.sd_ref().is_pixart:
             # pixart is much more sensitive
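The call site keeps full_token_scaler=False because, per the inline comment, a per-token scaler cannot be merged into the adapter weights. The underlying reason is general: a single scalar applied after a linear layer can be folded into that layer's weight and bias, whereas a scale that differs per input token cannot be expressed by any one fixed weight matrix. A minimal sketch of folding a scalar, illustrative only and not the repository's merge code:

import torch

lin = torch.nn.Linear(8, 8)
x = torch.randn(3, 8)
s = 0.999  # a single scalar scale, like a merged ip_scaler

scaled = lin(x) * s

# fold the scalar into the layer: (s * W) x + s * b == s * (W x + b)
merged = torch.nn.Linear(8, 8)
with torch.no_grad():
    merged.weight.copy_(lin.weight * s)
    merged.bias.copy_(lin.bias * s)

print(torch.allclose(scaled, merged(x), atol=1e-6))  # True

# A per-token vector scale multiplies each token's output differently,
# so no single (weight, bias) pair can reproduce it for all inputs.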