Bug fixes

This commit is contained in:
Jaret Burkett
2024-04-16 03:48:13 -06:00
parent 7284aab7c0
commit 2d0a1be59d
2 changed files with 16 additions and 6 deletions

View File

@@ -164,10 +164,10 @@ class SDTrainer(BaseSDTrainProcess):
timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
single_step_timestep_schedule = [timesteps_item.squeeze().item()]
# extract the sigma idx for our midpoint timestep
sigmas = train_sigmas[timestep_idx:timestep_idx + 1]
sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch)
end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1]
end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch)
# add noise to our target
@@ -352,6 +352,11 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.do_prior_divergence and prior_pred is not None:
loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)
if self.train_config.train_turbo:
mask_multiplier = mask_multiplier[:, 3:, :, :]
# resize to the size of the loss
mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
# multiply by our mask
loss = loss * mask_multiplier

View File

@@ -80,13 +80,16 @@ class MLPProjModelClipFace(torch.nn.Module):
class CustomIPAttentionProcessor(IPAttnProcessor2_0):
def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False):
def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, full_token_scaler=False):
super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.train_scaler = train_scaler
if train_scaler:
# self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.9999)
self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
if full_token_scaler:
self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
else:
self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
# self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
self.ip_scaler.requires_grad_(True)
def __call__(
@@ -514,7 +517,9 @@ class IPAdapter(torch.nn.Module):
scale=1.0,
num_tokens=self.config.num_tokens,
adapter=self,
train_scaler=self.config.train_scaler or self.config.merge_scaler
train_scaler=self.config.train_scaler or self.config.merge_scaler,
# full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler
full_token_scaler=False
)
if self.sd_ref().is_pixart:
# pixart is much more sensitive