Added some split prompting started code, adamw8bit, replacements improving, learnable snr gos. A lot of good stuff.

This commit is contained in:
Jaret Burkett
2023-11-01 06:52:21 -06:00
parent 436a09430e
commit a899ec91c8
9 changed files with 149 additions and 18 deletions

View File

@@ -683,6 +683,68 @@ def get_all_snr(noise_scheduler, device):
all_snr.requires_grad = False
return all_snr.to(device)
class LearnableSNRGamma:
"""
This is a trainer for learnable snr gamma
It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps
"""
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
self.device = device
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
self.buffer = []
self.max_buffer_size = 100
def forward(self, loss, timesteps):
# do a our train loop for lsnr here and return our values detached
loss = loss.detach()
with torch.no_grad():
loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
for loss_chunk in loss_chunks:
self.buffer.append(loss_chunk.mean().detach())
if len(self.buffer) > self.max_buffer_size:
self.buffer.pop(0)
all_snr = get_all_snr(self.noise_scheduler, loss.device)
snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
base_snrs = snr.clone().detach()
snr.requires_grad = True
snr = snr * self.scale + self.offset
gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
snr_adjusted_loss = loss * snr_weight
with torch.no_grad():
target = torch.mean(torch.stack(self.buffer)).detach()
# local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
squared_differences = (snr_adjusted_loss - target) ** 2
local_loss = torch.mean(squared_differences)
local_loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
def apply_learnable_snr_gos(
loss,
timesteps,
learnable_snr_trainer:LearnableSNRGamma
):
snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
snr = snr * scale + offset
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
snr_adjusted_loss = loss * snr_weight
return snr_adjusted_loss
def apply_snr_weight(
loss,
@@ -700,5 +762,6 @@ def apply_snr_weight(
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
else:
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
loss = loss * snr_weight
return loss
snr_adjusted_loss = loss * snr_weight
return snr_adjusted_loss