Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-25 16:59:22 +00:00
Added masking to slider training. Something is still weird though
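The commit message refers to masking in slider training, but the hunk below only shows the LearnableSNRGamma changes. As context, here is a hedged, minimal sketch of how a spatial loss mask is commonly applied in diffusion training; the function name, the mask shape, and the wiring into slider training are assumptions for illustration and are not taken from this commit.

import torch
import torch.nn.functional as F

def masked_diffusion_loss(noise_pred: torch.Tensor,
                          noise_target: torch.Tensor,
                          mask: torch.Tensor) -> torch.Tensor:
    # Unreduced per-element loss so the mask can re-weight it.
    loss = F.mse_loss(noise_pred, noise_target, reduction="none")
    # Assumed mask shape (B, 1, H, W); broadcast it over the channel dimension.
    mask = mask.expand_as(loss)
    loss = loss * mask
    # Normalize by the mask sum so masked-out regions do not dilute the signal.
    return loss.sum() / mask.sum().clamp(min=1e-8)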
@@ -691,12 +691,12 @@ class LearnableSNRGamma:
     def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
         self.device = device
         self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
-        self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
-        self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
-        self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
-        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
+        self.offset = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
+        self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
+        self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
+        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)
         self.buffer = []
-        self.max_buffer_size = 100
+        self.max_buffer_size = 20

     def forward(self, loss, timesteps):
         # do a our train loop for lsnr here and return our values detached
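The forward body is not part of this hunk, so the following is only a minimal sketch of what the fit step could look like, assuming the buffer holds recent (loss, SNR) observations and the parameters are fit so that offset + scale * snr ** (-gamma) tracks the observed loss. Nothing here beyond the attribute names shown in the diff above is taken from the commit.

import torch

# Intended as the body of LearnableSNRGamma.forward(self, loss, timesteps).
def learnable_snr_forward(self, loss: torch.Tensor, timesteps: torch.Tensor):
    # SNR of the sampled timesteps: alpha_bar / (1 - alpha_bar), from the
    # scheduler's cumulative alphas (a standard DDPM quantity).
    alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device)[timesteps]
    snr = alphas_cumprod / (1.0 - alphas_cumprod)

    # Keep a rolling window of recent observations, capped at max_buffer_size.
    self.buffer.append((loss.detach().flatten(), snr.detach().flatten()))
    self.buffer = self.buffer[-self.max_buffer_size:]

    losses = torch.cat([l for l, _ in self.buffer])
    snrs = torch.cat([s for _, s in self.buffer])

    # One AdamW step fitting the parametric curve to the observed losses.
    pred = self.offset + self.scale * snrs.clamp(min=1e-8) ** (-self.gamma)
    fit_loss = torch.nn.functional.mse_loss(pred, losses)
    self.optimizer.zero_grad()
    fit_loss.backward()
    self.optimizer.step()

    # Return the fitted values detached, as the original comment describes.
    return self.offset.detach(), self.scale.detach(), self.gamma.detach()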