mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-12 14:09:49 +00:00
Added ability to pair samples with a closer noise with optimal_noise_pairing_samples
This commit is contained in:
@@ -778,15 +778,37 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
def get_optimal_noise(self, latents, dtype=torch.float32):
|
||||
batch_num = latents.shape[0]
|
||||
chunks = torch.chunk(latents, batch_num, dim=0)
|
||||
noise_chunks = []
|
||||
for chunk in chunks:
|
||||
noise_samples = [torch.randn_like(chunk, device=chunk.device, dtype=dtype) for _ in range(self.train_config.optimal_noise_pairing_samples)]
|
||||
# find the one most similar to the chunk
|
||||
lowest_loss = 999999999999
|
||||
best_noise = None
|
||||
for noise in noise_samples:
|
||||
loss = torch.nn.functional.mse_loss(chunk, noise)
|
||||
if loss < lowest_loss:
|
||||
lowest_loss = loss
|
||||
best_noise = noise
|
||||
noise_chunks.append(best_noise)
|
||||
noise = torch.cat(noise_chunks, dim=0)
|
||||
return noise
|
||||
|
||||
|
||||
def get_noise(self, latents, batch_size, dtype=torch.float32):
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
if self.train_config.optimal_noise_pairing_samples > 1:
|
||||
noise = self.get_optimal_noise(latents, dtype=dtype)
|
||||
else:
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
if self.train_config.random_noise_shift > 0.0:
|
||||
# get random noise -1 to 1
|
||||
|
||||
@@ -403,6 +403,9 @@ class TrainConfig:
|
||||
|
||||
# diffusion feature extractor
|
||||
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
|
||||
|
||||
# optimal noise pairing
|
||||
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
|
||||
Reference in New Issue
Block a user