mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-28 18:21:48 +00:00
i
This commit is contained in:
@@ -38,11 +38,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
    """Run the SDXL diffusion model on latent ``x`` at timestep ``t``.

    When the UNet takes 9 input channels, the extra concat-conditioning
    tensors in ``cond['c_concat']`` are appended to ``x`` along the channel
    axis before the forward call.  Any additional positional/keyword
    arguments are forwarded untouched to the wrapped model.

    NOTE(review): 9 channels presumably identifies an inpainting-style
    UNet — confirm against the model definitions.
    """
    model_input = x
    if self.model.diffusion_model.in_channels == 9:
        # Channel-concatenate the conditioning tensors onto the latent input.
        model_input = torch.cat([model_input] + cond['c_concat'], dim=1)

    return self.model(model_input, t, cond, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||||
|
|||||||
@@ -77,10 +77,13 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
if "sampler_cfg_function" in model_options or "sampler_post_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options or "sampler_post_cfg_function" in model_options:
|
||||||
cond_scale = float(cond_scale)
|
cond_scale = float(cond_scale)
|
||||||
model = self.inner_model.inner_model.forge_objects.unet
|
model = self.inner_model.inner_model.forge_objects.unet.model
|
||||||
x = x_in[-uncond.shape[0]:]
|
x = x_in[-uncond.shape[0]:]
|
||||||
uncond_pred = denoised_uncond
|
uncond_pred = denoised_uncond
|
||||||
cond_pred = ((denoised - uncond_pred) / cond_scale) + uncond_pred
|
cond_pred = ((denoised - uncond_pred) / cond_scale) + uncond_pred
|
||||||
|
timestep = timestep[-uncond.shape[0]:]
|
||||||
|
|
||||||
|
from modules_forge.forge_util import cond_from_a1111_to_patched_ldm
|
||||||
|
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale,
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale,
|
||||||
@@ -93,10 +96,15 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
# sanity_check = torch.allclose(cfg_result, denoised)
|
# sanity_check = torch.allclose(cfg_result, denoised)
|
||||||
|
|
||||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||||
args = {"denoised": cfg_result, "cond": cond,
|
args = {"denoised": cfg_result,
|
||||||
"uncond": uncond, "model": model,
|
"cond": cond_from_a1111_to_patched_ldm(cond),
|
||||||
"uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
"uncond": cond_from_a1111_to_patched_ldm(uncond),
|
||||||
"sigma": timestep, "model_options": model_options, "input": x}
|
"model": model,
|
||||||
|
"uncond_denoised": uncond_pred,
|
||||||
|
"cond_denoised": cond_pred,
|
||||||
|
"sigma": timestep,
|
||||||
|
"model_options": model_options,
|
||||||
|
"input": x}
|
||||||
cfg_result = fn(args)
|
cfg_result = fn(args)
|
||||||
else:
|
else:
|
||||||
cfg_result = denoised
|
cfg_result = denoised
|
||||||
|
|||||||
17
modules_forge/forge_util.py
Normal file
17
modules_forge/forge_util.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
|
||||||
|
|
||||||
|
|
||||||
|
def cond_from_a1111_to_patched_ldm(cond):
    """Convert an A1111-style conditioning dict to the patched-ldm format.

    ``cond`` is expected to provide a ``'crossattn'`` tensor and a pooled
    ``'vector'`` tensor.  Both are exposed directly and also wrapped in
    ``CONDCrossAttn`` / ``CONDRegular`` under ``model_conds``, and the
    result is returned as a single-element list, which is the shape the
    patched ldm sampling code consumes.
    """
    attn = cond['crossattn']
    pooled = cond['vector']

    converted = {
        'cross_attn': attn,
        'pooled_output': pooled,
        'model_conds': {
            'c_crossattn': CONDCrossAttn(attn),
            'y': CONDRegular(pooled),
        },
    }

    return [converted]
|
||||||
Reference in New Issue
Block a user