Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-29 02:31:17 +00:00)
Added some split prompting starter code, adamw8bit, replacement improvements, learnable snr gos. A lot of good stuff.
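Not part of the commit, but for orientation: judging from the code below, "gos" stands for the three learned scalars gamma, offset, and scale. A quick sketch of the weighting the new code applies, evaluated at the commit's initial parameter values (the SNR inputs here are made up):

    import torch

    # initial values from LearnableSNRGamma below; the SNR inputs are hypothetical
    gamma, offset, scale = 1.0, 1.0, 0.001
    snr = torch.tensor([0.1, 1.0, 10.0, 100.0])
    weight = torch.abs(gamma / (snr * scale + offset))
    # -> tensor([0.9999, 0.9990, 0.9901, 0.9091]); near 1.0 at init, then the
    # optimizer reshapes gamma/offset/scale to balance loss across timesteps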
@@ -683,6 +683,68 @@ def get_all_snr(noise_scheduler, device):
     all_snr.requires_grad = False
     return all_snr.to(device)
 
 
+class LearnableSNRGamma:
+    """
+    Trainer for a learnable SNR gamma. It adapts to the dataset, attempting to
+    adjust the SNR multiplier so the loss stays balanced over the timesteps.
+    """
+    def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
+        self.device = device
+        self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
+        self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
+        self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
+        self.buffer = []
+        self.max_buffer_size = 100
+
+    def forward(self, loss, timesteps):
+        # run our train loop for lsnr here and return our values detached
+        loss = loss.detach()
+        with torch.no_grad():
+            # keep a rolling buffer of recent per-sample losses; its mean is the target
+            loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
+            for loss_chunk in loss_chunks:
+                self.buffer.append(loss_chunk.mean().detach())
+                if len(self.buffer) > self.max_buffer_size:
+                    self.buffer.pop(0)
+        all_snr = get_all_snr(self.noise_scheduler, loss.device)
+        snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
+        base_snrs = snr.clone().detach()
+        snr.requires_grad = True
+        snr = snr * self.scale + self.offset
+
+        gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
+        snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+        snr_adjusted_loss = loss * snr_weight
+        with torch.no_grad():
+            target = torch.mean(torch.stack(self.buffer)).detach()
+
+        # regress the weighted loss toward the running mean so no timestep dominates
+        # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
+        squared_differences = (snr_adjusted_loss - target) ** 2
+        local_loss = torch.mean(squared_differences)
+        local_loss.backward()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
+
+
+def apply_learnable_snr_gos(
+        loss,
+        timesteps,
+        learnable_snr_trainer: LearnableSNRGamma
+):
+    snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
+    snr = snr * scale + offset
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+    snr_adjusted_loss = loss * snr_weight
+    return snr_adjusted_loss
+
+
 def apply_snr_weight(
         loss,
@@ -700,5 +762,6 @@ def apply_snr_weight(
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
     else:
         snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
-    loss = loss * snr_weight
-    return loss
+    snr_adjusted_loss = loss * snr_weight
+
+    return snr_adjusted_loss
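A minimal usage sketch, not from this commit: it assumes diffusers' DDPMScheduler and a per-sample loss of shape [batch]; `unet`, `latents`, and `training_step` are hypothetical stand-ins.

    import torch
    import torch.nn.functional as F
    from diffusers import DDPMScheduler

    noise_scheduler = DDPMScheduler()
    lsnr_trainer = LearnableSNRGamma(noise_scheduler, device='cuda')

    def training_step(unet, latents):
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
        )
        noise = torch.randn_like(latents)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        pred = unet(noisy_latents, timesteps)
        # keep the loss per-sample so each timestep gets its own weight
        loss = F.mse_loss(pred.float(), noise.float(), reduction='none').mean(dim=[1, 2, 3])
        # reweights the loss and, internally, takes one AdamW step on gamma/offset/scale
        loss = apply_learnable_snr_gos(loss, timesteps, lsnr_trainer)
        return loss.mean()  # backprop this through the model as usual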