mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 17:51:41 +00:00
Added refiner fine tuning. Works, but needs some polish.
This commit is contained in:
@@ -691,10 +691,11 @@ class LearnableSNRGamma:
|
||||
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
|
||||
self.device = device
|
||||
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
|
||||
self.offset = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
|
||||
self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
|
||||
self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
|
||||
self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
|
||||
self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
|
||||
self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)
|
||||
self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
|
||||
self.buffer = []
|
||||
self.max_buffer_size = 20
|
||||
|
||||
@@ -711,7 +712,7 @@ class LearnableSNRGamma:
|
||||
snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
|
||||
base_snrs = snr.clone().detach()
|
||||
snr.requires_grad = True
|
||||
snr = snr * self.scale + self.offset
|
||||
snr = (snr + self.offset_1) * self.scale + self.offset_2
|
||||
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
|
||||
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
|
||||
@@ -726,18 +727,18 @@ class LearnableSNRGamma:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
|
||||
return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()
|
||||
|
||||
|
||||
def apply_learnable_snr_gos(
|
||||
loss,
|
||||
timesteps,
|
||||
learnable_snr_trainer:LearnableSNRGamma
|
||||
learnable_snr_trainer: LearnableSNRGamma
|
||||
):
|
||||
|
||||
snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
|
||||
snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)
|
||||
|
||||
snr = snr * scale + offset
|
||||
snr = (snr + offset_1) * scale + offset_2
|
||||
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
|
||||
|
||||
Reference in New Issue
Block a user