Added refiner fine-tuning. Works, but needs some polish.

Jaret Burkett
2023-11-05 17:15:03 -07:00
parent 8a9e8f708f
commit 93ea955d7c
14 changed files with 4541 additions and 128 deletions


@@ -691,10 +691,11 @@ class LearnableSNRGamma:
     def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
         self.device = device
         self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
-        self.offset = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
+        self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
+        self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
         self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
         self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
-        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)
+        self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
         self.buffer = []
         self.max_buffer_size = 20
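For reference, the change above replaces the single learned offset with two: offset_1 is added to the raw SNR before the learned scale, offset_2 after it. Below is a minimal standalone sketch of the resulting weight; the function is illustrative only and not part of the commit, with initial values taken from __init__ above.

import torch

def transformed_snr_weight(snr, offset_1=0.0, offset_2=0.777, scale=4.14, gamma=2.03):
    # Affine transform of the per-timestep SNR, then the gamma-over-SNR weight.
    snr = (snr + offset_1) * scale + offset_2
    return torch.abs(gamma / snr)

# Example: weights for SNR values ranging from very noisy to nearly clean timesteps
print(transformed_snr_weight(torch.tensor([0.1, 1.0, 10.0, 100.0])))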
@@ -711,7 +712,7 @@ class LearnableSNRGamma:
         snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
         base_snrs = snr.clone().detach()
         snr.requires_grad = True
-        snr = snr * self.scale + self.offset
+        snr = (snr + self.offset_1) * self.scale + self.offset_2
         gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
         snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
@@ -726,18 +727,18 @@ class LearnableSNRGamma:
         self.optimizer.step()
         self.optimizer.zero_grad()
-        return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
+        return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()

 def apply_learnable_snr_gos(
         loss,
         timesteps,
-        learnable_snr_trainer:LearnableSNRGamma
+        learnable_snr_trainer: LearnableSNRGamma
 ):
-    snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
+    snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)
-    snr = snr * scale + offset
+    snr = (snr + offset_1) * scale + offset_2
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
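The rest of apply_learnable_snr_gos falls outside this hunk; the usual pattern for this style of SNR weighting is to multiply the per-sample loss by snr_weight. Below is a hedged sketch of that pattern, where the reduction over non-batch dimensions is an assumption for illustration, not a claim about what the file actually does.

import torch

def apply_snr_weight_sketch(loss, snr_weight):
    # loss: (batch, ...) per-element loss; snr_weight: (batch,) from gamma over SNR.
    per_sample = loss.mean(dim=list(range(1, loss.dim())))  # collapse to one value per sample
    return per_sample * snr_weight

# Example with dummy shapes: a batch of 4 latents and 4 per-timestep weights
loss = torch.rand(4, 4, 64, 64)
snr_weight = torch.tensor([0.5, 1.0, 1.5, 2.0])
print(apply_snr_weight_sketch(loss, snr_weight).shape)  # torch.Size([4])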