mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-21 06:49:08 +00:00
Removed old code for fixing multistep sampler that is no longer needed
This commit is contained in:
@@ -834,42 +834,7 @@ class StableDiffusion:
|
||||
noisy_latents_chunks = []
|
||||
|
||||
for idx in range(original_samples.shape[0]):
|
||||
|
||||
# the add noise for ddpm solver is broken, do it ourselves
|
||||
noise_timesteps = timesteps_chunks[idx]
|
||||
if scheduler_class_name == 'DPMSolverMultistepScheduler':
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.noise_scheduler.sigmas.to(device=original_samples_chunks[idx].device,
|
||||
dtype=original_samples_chunks[idx].dtype)
|
||||
if original_samples_chunks[idx].device.type == "mps" and torch.is_floating_point(noise_timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.noise_scheduler.timesteps.to(original_samples_chunks[idx].device,
|
||||
dtype=torch.float32)
|
||||
noise_timesteps = noise_timesteps.to(original_samples_chunks[idx].device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.noise_scheduler.timesteps.to(original_samples_chunks[idx].device)
|
||||
noise_timesteps = noise_timesteps.to(original_samples_chunks[idx].device)
|
||||
|
||||
step_indices = []
|
||||
for t in noise_timesteps:
|
||||
for i, st in enumerate(schedule_timesteps):
|
||||
if st == t:
|
||||
step_indices.append(i)
|
||||
break
|
||||
|
||||
# find only first match. There can be double here, this breaks
|
||||
# step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
alpha_t, sigma_t = self.noise_scheduler._sigma_to_alpha_sigma_t(sigma)
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise_chunks[idx]
|
||||
noisy_latents = noisy_samples
|
||||
else:
|
||||
noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx],
|
||||
noise_timesteps)
|
||||
noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], timesteps_chunks[idx])
|
||||
noisy_latents_chunks.append(noisy_latents)
|
||||
|
||||
noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user