Added ability to pair each sample with the closest of several random noise candidates, controlled by optimal_noise_pairing_samples

This commit is contained in:
Jaret Burkett
2025-01-21 18:30:10 -07:00
parent 29122b1a54
commit 89dd041b97
2 changed files with 32 additions and 7 deletions

View File

@@ -778,15 +778,37 @@ class BaseSDTrainProcess(BaseTrainProcess):
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def get_optimal_noise(self, latents, dtype=torch.float32):
    """Pick, for every latent in the batch, the closest of several random noises.

    For each batch entry, ``self.train_config.optimal_noise_pairing_samples``
    Gaussian noise candidates of the same shape are drawn, and the candidate
    with the lowest mean squared error to that latent is kept.

    Args:
        latents: Tensor whose first dimension is the batch dimension.
        dtype: dtype of the generated noise.

    Returns:
        A tensor with the same shape and device as ``latents`` and the
        requested ``dtype``, holding the selected noise for every entry.
    """
    batch_num = latents.shape[0]
    if batch_num == 0:
        # torch.chunk rejects a chunk count of zero; there is nothing to pair.
        return torch.empty_like(latents, dtype=dtype)

    # Always draw at least one candidate so a selection exists even when the
    # setting is misconfigured (0, negative, or a float).
    num_candidates = max(1, int(self.train_config.optimal_noise_pairing_samples))

    noise_chunks = []
    # Choosing a candidate is bookkeeping; keep it out of the autograd graph.
    with torch.no_grad():
        for chunk in torch.chunk(latents, batch_num, dim=0):
            candidates = [
                torch.randn_like(chunk, device=chunk.device, dtype=dtype)
                for _ in range(num_candidates)
            ]
            # Score in float32 so latents and noise of different (possibly
            # half-precision) dtypes are compared consistently.
            reference = chunk.float()
            losses = torch.stack([
                torch.nn.functional.mse_loss(reference, candidate.float())
                for candidate in candidates
            ])
            # argmin + index_select pick the winner as tensor ops instead of
            # converting every comparison into a Python bool.
            best_index = torch.argmin(losses).reshape(1)
            best_noise = torch.cat(candidates, dim=0).index_select(0, best_index)
            noise_chunks.append(best_noise)
    return torch.cat(noise_chunks, dim=0)
def get_noise(self, latents, batch_size, dtype=torch.float32):
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if self.train_config.optimal_noise_pairing_samples > 1:
noise = self.get_optimal_noise(latents, dtype=dtype)
else:
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1

View File

@@ -403,6 +403,9 @@ class TrainConfig:
# diffusion feature extractor
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
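# optimal noise pairing: number of random noise candidates drawn per sample;
# the candidate with the lowest MSE to the latent is used.
# Pairing is only active when this value is greater than 1.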
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
class ModelConfig: