mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-04 04:29:50 +00:00
predictor is a better name
This commit is contained in:
@@ -12,11 +12,11 @@ class KModel(torch.nn.Module):
|
||||
self.computation_dtype = computation_dtype
|
||||
|
||||
self.diffusion_model = huggingface_components['unet']
|
||||
self.prediction = k_prediction_from_diffusers_scheduler(huggingface_components['scheduler'])
|
||||
self.predictor = k_prediction_from_diffusers_scheduler(huggingface_components['scheduler'])
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.prediction.calculate_input(sigma, x)
|
||||
xc = self.predictor.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
|
||||
@@ -24,7 +24,7 @@ class KModel(torch.nn.Module):
|
||||
dtype = self.computation_dtype
|
||||
|
||||
xc = xc.to(dtype)
|
||||
t = self.prediction.timestep(t).float()
|
||||
t = self.predictor.timestep(t).float()
|
||||
context = context.to(dtype)
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
@@ -35,7 +35,7 @@ class KModel(torch.nn.Module):
|
||||
extra_conds[o] = extra
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.prediction.calculate_denoised(sigma, model_output, x)
|
||||
return self.predictor.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
|
||||
@@ -154,7 +154,7 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
|
||||
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
cond, mask = super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
|
||||
sigma_max = process.sd_model.forge_objects.unet.model.prediction.sigma_max
|
||||
sigma_max = process.sd_model.forge_objects.unet.model.predictor.sigma_max
|
||||
original_noise = kwargs['noise']
|
||||
process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise)
|
||||
return cond, mask
|
||||
|
||||
@@ -67,8 +67,8 @@ class PreprocessorReference(Preprocessor):
|
||||
gen_cpu = torch.Generator().manual_seed(gen_seed)
|
||||
|
||||
unet = process.sd_model.forge_objects.unet.clone()
|
||||
sigma_max = unet.model.prediction.percent_to_sigma(start_percent)
|
||||
sigma_min = unet.model.prediction.percent_to_sigma(end_percent)
|
||||
sigma_max = unet.model.predictor.percent_to_sigma(start_percent)
|
||||
sigma_min = unet.model.predictor.percent_to_sigma(end_percent)
|
||||
|
||||
self.recorded_attn1 = {}
|
||||
self.recorded_h = {}
|
||||
|
||||
@@ -43,7 +43,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
|
||||
latent = self.register_latent(process, cond)
|
||||
|
||||
unet = process.sd_model.forge_objects.unet.clone()
|
||||
sigma_data = process.sd_model.forge_objects.unet.model.prediction.sigma_data
|
||||
sigma_data = process.sd_model.forge_objects.unet.model.predictor.sigma_data
|
||||
|
||||
if getattr(process, 'is_hr_pass', False):
|
||||
k = int(self.variation * 2)
|
||||
|
||||
@@ -38,7 +38,7 @@ class DynamicThresholdingNode:
|
||||
cond = input - args["cond"]
|
||||
uncond = input - args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
time_step = model.model.prediction.timestep(args["sigma"])
|
||||
time_step = model.model.predictor.timestep(args["sigma"])
|
||||
time_step = time_step[0].item()
|
||||
dynamic_thresh.step = 999 - time_step
|
||||
|
||||
|
||||
@@ -102,8 +102,8 @@ class FooocusInpaintPatcher(ControlModelPatcher):
|
||||
if not_patched_count > 0:
|
||||
print(f"[Fooocus Patch Loader] Failed to load {not_patched_count} keys")
|
||||
|
||||
sigma_start = unet.model.prediction.percent_to_sigma(self.start_percent)
|
||||
sigma_end = unet.model.prediction.percent_to_sigma(self.end_percent)
|
||||
sigma_start = unet.model.predictor.percent_to_sigma(self.start_percent)
|
||||
sigma_end = unet.model.predictor.percent_to_sigma(self.end_percent)
|
||||
|
||||
def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
|
||||
if timestep > sigma_start or timestep < sigma_end:
|
||||
|
||||
@@ -760,8 +760,8 @@ class IPAdapterApply:
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(self.device)
|
||||
|
||||
sigma_start = model.model.prediction.percent_to_sigma(start_at)
|
||||
sigma_end = model.model.prediction.percent_to_sigma(end_at)
|
||||
sigma_start = model.model.predictor.percent_to_sigma(start_at)
|
||||
sigma_end = model.model.predictor.percent_to_sigma(end_at)
|
||||
|
||||
patch_kwargs = {
|
||||
"number": 0,
|
||||
|
||||
@@ -919,10 +919,10 @@ class ModelSamplerLatentMegaModifier:
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
timestep = model.model.prediction.timestep(args["timestep"])
|
||||
timestep = model.model.predictor.timestep(args["timestep"])
|
||||
sigma = args["sigma"]
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
|
||||
#print(model.model.prediction.timestep(timestep))
|
||||
#print(model.model.predictor.timestep(timestep))
|
||||
|
||||
x = x_input / (sigma * sigma + 1.0)
|
||||
cond = ((x - (x_input - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
|
||||
|
||||
@@ -285,7 +285,7 @@ class ControlNet(ControlBase):
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.model_sampling_current = model.prediction
|
||||
self.model_sampling_current = model.predictor
|
||||
|
||||
def cleanup(self):
|
||||
self.model_sampling_current = None
|
||||
|
||||
@@ -108,7 +108,7 @@ def sampling_prepare(unet, x):
|
||||
|
||||
real_model = unet.model
|
||||
|
||||
percent_to_timestep_function = lambda p: real_model.prediction.percent_to_sigma(p)
|
||||
percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
|
||||
|
||||
for cnet in unet.list_controlnets():
|
||||
cnet.pre_run(real_model, percent_to_timestep_function)
|
||||
|
||||
Reference in New Issue
Block a user