Added masking to slider training. Something is still weird, though.

This commit is contained in:
Jaret Burkett
2023-11-01 14:51:29 -06:00
parent a899ec91c8
commit 7d707b2fe6
6 changed files with 97 additions and 25 deletions

View File

@@ -691,12 +691,12 @@ class LearnableSNRGamma:
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
self.device = device
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
self.offset = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)
self.buffer = []
self.max_buffer_size = 100
self.max_buffer_size = 20
def forward(self, loss, timesteps):
# do a our train loop for lsnr here and return our values detached