From 2d0a1be59dacce4c5b7a6f7472cbd93faec8bb51 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 16 Apr 2024 03:48:13 -0600
Subject: [PATCH] Bug fixes

---
 extensions_built_in/sd_trainer/SDTrainer.py |  9 +++++++--
 toolkit/ip_adapter.py                       | 13 +++++++++----
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index be35198..1daf377 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -164,10 +164,10 @@ class SDTrainer(BaseSDTrainProcess):
             timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
             single_step_timestep_schedule = [timesteps_item.squeeze().item()]
             # extract the sigma idx for our midpoint timestep
-            sigmas = train_sigmas[timestep_idx:timestep_idx + 1]
+            sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch)

             end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
-            end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1]
+            end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch)


             # add noise to our target
@@ -352,6 +352,11 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.do_prior_divergence and prior_pred is not None:
             loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)

+        if self.train_config.train_turbo:
+            mask_multiplier = mask_multiplier[:, 3:, :, :]
+            # resize to the size of the loss
+            mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
+
         # multiply by our mask
         loss = loss * mask_multiplier

diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 5588bf6..928f8eb 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -80,13 +80,16 @@ class MLPProjModelClipFace(torch.nn.Module):


 class CustomIPAttentionProcessor(IPAttnProcessor2_0):
-    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False):
+    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, full_token_scaler=False):
         super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.train_scaler = train_scaler
         if train_scaler:
-            # self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.9999)
-            self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
+            if full_token_scaler:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
+            else:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
+            # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
             self.ip_scaler.requires_grad_(True)

     def __call__(
@@ -514,7 +517,9 @@ class IPAdapter(torch.nn.Module):
                 scale=1.0,
                 num_tokens=self.config.num_tokens,
                 adapter=self,
-                train_scaler=self.config.train_scaler or self.config.merge_scaler
+                train_scaler=self.config.train_scaler or self.config.merge_scaler,
+                # full_token_scaler=self.config.train_scaler  # full token cannot be merged in, only use if training an actual scaler
+                full_token_scaler=False
             )
             if self.sd_ref().is_pixart:
                 # pixart is much more sensitive