restore some legacy code from webui

#933
layerdiffusion
2024-08-05 11:25:05 -07:00
parent 47ed1bacb2
commit d291672a30
2 changed files with 35 additions and 1 deletion


@@ -31,6 +31,10 @@ def pad_cond(tensor, repeats, empty):
         return tensor
 
 
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+
 class CFGDenoiser(torch.nn.Module):
     """
     Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
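The restored append_zero helper follows the k-diffusion convention that a sigma schedule handed to a sampler ends at exactly 0, so the final step lands on the fully denoised image. A minimal standalone illustration (the example tensor values are made up):

import torch

def append_zero(x):
    # Append a single 0.0 so the schedule terminates at sigma = 0.
    return torch.cat([x, x.new_zeros([1])])

sigmas = torch.tensor([14.6, 7.0, 3.1, 0.9])
print(append_zero(sigmas))
# tensor([14.6000,  7.0000,  3.1000,  0.9000,  0.0000])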
@@ -66,6 +70,36 @@ class CFGDenoiser(torch.nn.Module):
         self.classic_ddim_eps_estimation = False
+        self.sigmas = model.forge_objects.unet.model.predictor.sigmas
+        self.log_sigmas = self.sigmas.log()
+
+    def get_sigmas(self, n=None):
+        if n is None:
+            return append_zero(self.sigmas.flip(0))
+        t_max = len(self.sigmas) - 1
+        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
+        return append_zero(self.t_to_sigma(t))
+
+    def sigma_to_t(self, sigma, quantize=None):
+        quantize = self.quantize if quantize is None else quantize
+        log_sigma = sigma.log()
+        dists = log_sigma - self.log_sigmas[:, None]
+        if quantize:
+            return dists.abs().argmin(dim=0).view(sigma.shape)
+        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
+        w = (low - log_sigma) / (low - high)
+        w = w.clamp(0, 1)
+        t = (1 - w) * low_idx + w * high_idx
+        return t.view(sigma.shape)
+
+    def t_to_sigma(self, t):
+        t = t.float()
+        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp()
+
     def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
         denoised_uncond = x_out[-uncond.shape[0]:]
         denoised = torch.clone(denoised_uncond)
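The three restored methods reimplement k-diffusion's discrete-schedule logic on top of the model's own sigma table: t_to_sigma maps a (possibly fractional) timestep to a sigma by interpolating linearly in log-sigma space between table entries, and sigma_to_t inverts that mapping (or snaps to the nearest integer timestep when quantize is set). A self-contained sketch, with a made-up log-linear table standing in for model.forge_objects.unet.model.predictor.sigmas (the real table is not evenly spaced):

import torch

# Made-up, log-linear sigma table; index = timestep, ascending order.
sigmas = torch.exp(torch.linspace(-2.0, 2.0, 10))
log_sigmas = sigmas.log()

def t_to_sigma(t):
    # Fractional timestep -> sigma: linear interpolation in log-sigma space.
    t = t.float()
    low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
    return ((1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]).exp()

def sigma_to_t(sigma):
    # Inverse: find the bracketing table entries, then interpolate the index.
    log_sigma = sigma.log()
    dists = log_sigma - log_sigmas[:, None]
    low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = ((low - log_sigma) / (low - high)).clamp(0, 1)
    return ((1 - w) * low_idx + w * high_idx).view(sigma.shape)

t = torch.tensor([3.25, 7.5])
print(sigma_to_t(t_to_sigma(t)))  # tensor([3.2500, 7.5000]): round trip recovers t

Because both directions interpolate in log space, sigma_to_t(t_to_sigma(t)) recovers fractional timesteps exactly, which is what lets get_sigmas sample the schedule at an arbitrary step count.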


@@ -85,7 +85,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         if p.sampler_noise_scheduler_override:
             sigmas = p.sampler_noise_scheduler_override(steps)
         elif scheduler is None or scheduler.function is None:
-            raise ValueError('Wrong scheduler!')
+            sigmas = self.model_wrap.get_sigmas(steps)
         else:
             sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
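The second file swaps the hard failure for the legacy webui fallback: when no scheduler function is available, the sampler now asks the wrapped model for its native schedule via get_sigmas(steps) instead of raising. A rough sketch of what that fallback produces, with a hypothetical FakeModelWrap standing in for self.model_wrap (in forge, the restored CFGDenoiser methods above play this role):

import torch

def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])

# Hypothetical stand-in for self.model_wrap; sigma table values are made up.
class FakeModelWrap:
    def __init__(self, sigmas):
        self.sigmas = sigmas           # ascending table, index = timestep
        self.log_sigmas = sigmas.log()

    def t_to_sigma(self, t):
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        return ((1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]).exp()

    def get_sigmas(self, n=None):
        # n is None: full schedule, highest sigma first, terminated by 0.
        if n is None:
            return append_zero(self.sigmas.flip(0))
        # Otherwise: n evenly spaced fractional timesteps from t_max down to 0.
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return append_zero(self.t_to_sigma(t))

model_wrap = FakeModelWrap(torch.exp(torch.linspace(-2.0, 2.0, 1000)))
sigmas = model_wrap.get_sigmas(5)  # the fallback path taken in the diff
print(sigmas)  # 6 descending values: 5 sigmas plus the trailing 0.0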