Added new experimental time step weighing that should solve a lot of issues with distribution. Updated example. Removed a warning

This commit is contained in:
Jaret Burkett
2024-08-13 12:02:11 -06:00
parent 9ee1ef2a0a
commit 418f5f7e8c
3 changed files with 19 additions and 20 deletions

View File

@@ -12,29 +12,23 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# Bell-Shaped Mean-Normalized Timestep Weighting
# bsmntw? need a better name
# generate the multiplier based on cosmap loss weighing
# this is only used on linear timesteps for now
x = torch.arange(num_timesteps, dtype=torch.float32)
y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)
# cosine map weighing is higher in the middle and lower at the ends
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
# cosmap_weighing = 2 / (math.pi * bot)
# Shift minimum to 0
y_shifted = y - y.min()
# sigma sqrt weighing is significantly higher at the end and lower at the beginning
sigma_sqrt_weighing = (self.sigmas ** -2.0).float()
# clip at 1e4 (1e6 is too high)
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
# bring to a mean of 1
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
# Scale to make mean 1
bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
self.linear_timesteps = timesteps
# self.linear_timesteps_weights = cosmap_weighing
self.linear_timesteps_weights = sigma_sqrt_weighing
# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
self.linear_timesteps_weights = bsmntw_weighing
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:

View File

@@ -640,8 +640,8 @@ def add_all_snr_to_noise_scheduler(noise_scheduler, device):
all_snr.requires_grad = False
noise_scheduler.all_snr = all_snr.to(device)
except Exception as e:
print(e)
print("Failed to add all_snr to noise_scheduler")
# just move on
pass
def get_all_snr(noise_scheduler, device):