mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 03:01:15 +00:00
i
This commit is contained in:
@@ -82,7 +82,8 @@ class PreprocessorReference(Preprocessor):
|
|||||||
self.recorded_h[location] = h
|
self.recorded_h[location] = h
|
||||||
return h
|
return h
|
||||||
else:
|
else:
|
||||||
cond_mark = transformer_options['cond_mark'][:, None, None, None] # cond is 0
|
cond_indices = transformer_options['cond_indices']
|
||||||
|
uncond_indices = transformer_options['uncond_indices']
|
||||||
rh = self.recorded_h[location]
|
rh = self.recorded_h[location]
|
||||||
return h
|
return h
|
||||||
|
|
||||||
@@ -101,11 +102,35 @@ class PreprocessorReference(Preprocessor):
|
|||||||
self.recorded_attn1[location] = (k, v)
|
self.recorded_attn1[location] = (k, v)
|
||||||
return sdp(q, k, v, transformer_options)
|
return sdp(q, k, v, transformer_options)
|
||||||
else:
|
else:
|
||||||
cond_mark = transformer_options['cond_mark'][:, None, None, None] # cond is 0
|
cond_indices = transformer_options['cond_indices']
|
||||||
rk, rv = self.recorded_attn1[location]
|
uncond_indices = transformer_options['uncond_indices']
|
||||||
rk = torch.cat([k, rk], dim=1)
|
cond_or_uncond = transformer_options['cond_or_uncond']
|
||||||
rv = torch.cat([v, rv], dim=1)
|
|
||||||
return sdp(q, k, v, transformer_options)
|
q_c = q[cond_indices]
|
||||||
|
q_uc = q[uncond_indices]
|
||||||
|
|
||||||
|
k_c = k[cond_indices]
|
||||||
|
k_uc = k[uncond_indices]
|
||||||
|
|
||||||
|
v_c = v[cond_indices]
|
||||||
|
v_uc = v[uncond_indices]
|
||||||
|
|
||||||
|
k_r, v_r = self.recorded_attn1[location]
|
||||||
|
|
||||||
|
o_c = sdp(q_c, torch.cat([k_c, k_r], dim=1), torch.cat([v_c, v_r], dim=1), transformer_options)
|
||||||
|
o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options)
|
||||||
|
o_uc_weak = sdp(q_uc, torch.cat([k_uc, k_r], dim=1), torch.cat([v_uc, v_r], dim=1), transformer_options)
|
||||||
|
o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity
|
||||||
|
|
||||||
|
recon = []
|
||||||
|
for cx in cond_or_uncond:
|
||||||
|
if cx == 0:
|
||||||
|
recon.append(o_c)
|
||||||
|
else:
|
||||||
|
recon.append(o_uc)
|
||||||
|
|
||||||
|
o = torch.cat(recon, dim=0)
|
||||||
|
return o
|
||||||
|
|
||||||
unet.add_block_modifier(block_proc)
|
unet.add_block_modifier(block_proc)
|
||||||
unet.add_conditioning_modifier(conditioning_modifier)
|
unet.add_conditioning_modifier(conditioning_modifier)
|
||||||
|
|||||||
@@ -39,6 +39,20 @@ def compute_cond_mark(cond_or_uncond, sigmas):
|
|||||||
return cond_mark
|
return cond_mark
|
||||||
|
|
||||||
|
|
||||||
|
def compute_cond_indices(cond_or_uncond, sigmas):
|
||||||
|
cl = int(sigmas.shape[0])
|
||||||
|
|
||||||
|
cond_indices = []
|
||||||
|
uncond_indices = []
|
||||||
|
for i, cx in enumerate(cond_or_uncond):
|
||||||
|
if cx == 0:
|
||||||
|
cond_indices += list(range(i * cl, (i + 1) * cl))
|
||||||
|
else:
|
||||||
|
uncond_indices += list(range(i * cl, (i + 1) * cl))
|
||||||
|
|
||||||
|
return cond_indices, uncond_indices
|
||||||
|
|
||||||
|
|
||||||
def generate_random_filename(extension=".txt"):
|
def generate_random_filename(extension=".txt"):
|
||||||
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
||||||
random_string = ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))
|
random_string = ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from ldm_patched.modules.controlnet import ControlBase
|
|||||||
from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
|
from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
|
||||||
from ldm_patched.modules import model_management
|
from ldm_patched.modules import model_management
|
||||||
from modules_forge.controlnet import compute_controlnet_weighting
|
from modules_forge.controlnet import compute_controlnet_weighting
|
||||||
from modules_forge.forge_util import compute_cond_mark
|
from modules_forge.forge_util import compute_cond_mark, compute_cond_indices
|
||||||
|
|
||||||
|
|
||||||
def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
|
def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||||
@@ -155,8 +155,8 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op
|
|||||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||||
transformer_options["sigmas"] = timestep
|
transformer_options["sigmas"] = timestep
|
||||||
|
|
||||||
cond_mark = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
|
transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
|
||||||
transformer_options["cond_mark"] = cond_mark
|
transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)
|
||||||
|
|
||||||
c['transformer_options'] = transformer_options
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user