initial commit
extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py
213 lines added (new executable file)

@@ -0,0 +1,213 @@
import torch

from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules_forge.shared import add_supported_preprocessor
from backend.sampling.sampling_function import sampling_function_inner
from backend import attention
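

# Attention helper: guards against zero-length batches (the cond/uncond splits
# below can be empty) and otherwise delegates to the backend attention function.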
def sdp(q, k, v, transformer_options):
    if q.shape[0] == 0:
        return q

    return attention.attention_function(q, k, v, heads=transformer_options["n_heads"], mask=None)
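

# Adaptive instance normalization: move x's per-channel spatial statistics onto
# the recorded reference statistics, i.e.
#   adain(x) = ((x - mean(x)) / std(x)) * target_std + target_mean
# with mean/std taken over the spatial dims (2, 3) of each sample.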
def adain(x, target_std, target_mean):
    if x.shape[0] == 0:
        return x

    std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0)
    return (((x - mean) / std) * target_std) + target_mean
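

# torch.cat that tolerates an empty operand by returning the other one, so the
# cond/uncond recombination below works even when one branch is empty.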
def zero_cat(a, b, dim):
    if a.shape[0] == 0:
        return b
    if b.shape[0] == 0:
        return a
    return torch.cat([a, b], dim=dim)
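

# Model-free "reference" control, in the spirit of the classic A1111
# reference-only preprocessor: an extra model pass over the noised reference
# latent records self-attention K/V ("attn") and per-block feature statistics
# ("adain"), which the real denoising pass then injects into the generation.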
class PreprocessorReference(Preprocessor):
    def __init__(self, name, use_attn=True, use_adain=True, priority=0):
        super().__init__()
        self.name = name
        self.use_attn = use_attn
        self.use_adain = use_adain
        self.sorting_priority = priority
        self.tags = ['Reference']
        self.slider_resolution = PreprocessorParameter(visible=False)
        self.slider_1 = PreprocessorParameter(label='Style Fidelity', value=0.5, minimum=0.0, maximum=1.0, step=0.01, visible=True)
        self.show_control_mode = False
        self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = False
        self.do_not_need_model = True
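
        # Transient recording state, reset before every sampling run: the
        # recorded reference self-attention K/V and per-block feature stats.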
        self.is_recording_style = False
        self.recorded_attn1 = {}
        self.recorded_h = {}
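
    # Called by Forge before each sampling run: reads the unit settings,
    # encodes the reference image (passed in as cond), and patches a cloned
    # UNet with the hooks defined below.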
    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        unit = kwargs['unit']
        weight = float(unit.weight)
        style_fidelity = float(unit.threshold_a)
        start_percent = float(unit.guidance_start)
        end_percent = float(unit.guidance_end)

        if process.sd_model.is_sdxl:
            style_fidelity = style_fidelity ** 3.0  # SDXL is very sensitive to reference, so we lower the style fidelity

        vae = process.sd_model.forge_objects.vae
        # This is a powerful VAE with integrated memory management, bf16, and tiled fallback.

        latent_image = vae.encode(cond.movedim(1, -1))
        latent_image = vae.first_stage_model.process_in(latent_image)

        gen_seed = process.seeds[0] + 1
        gen_cpu = torch.Generator().manual_seed(gen_seed)
        unet = process.sd_model.forge_objects.unet.clone()
        sigma_max = unet.model.predictor.percent_to_sigma(start_percent)
        sigma_min = unet.model.predictor.percent_to_sigma(end_percent)

        self.recorded_attn1 = {}
        self.recorded_h = {}
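
        # Per-step hook: while the current sigma is inside the guidance
        # window, noise the reference latent to that same sigma and run one
        # extra model pass with recording enabled (cond_scale=1), so the hooks
        # below capture reference K/V and statistics before the real step.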
        def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
            sigma = timestep[0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return model, x, timestep, uncond, cond, cond_scale, model_options, seed

            self.is_recording_style = True

            xt = latent_image.to(x) + torch.randn(x.size(), dtype=x.dtype, generator=gen_cpu).to(x) * sigma
            sampling_function_inner(model, xt, timestep, uncond, cond, 1, model_options, seed)

            self.is_recording_style = False

            return model, x, timestep, uncond, cond, cond_scale, model_options, seed
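
        # AdaIN hook, applied after each block: during the recording pass it
        # stores the block output statistics; during the real pass it
        # renormalizes the cond features onto the recorded ones, blending the
        # uncond branch between stylized (weak) and untouched (strong) results
        # by 'Style Fidelity'. The channel threshold ties injection depth to
        # the unit weight: higher weight lowers the threshold, so more blocks
        # are affected.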
        def block_proc(h, flag, transformer_options):
            if not self.use_adain:
                return h

            if flag != 'after':
                return h

            location = transformer_options['block']

            sigma = transformer_options["sigmas"][0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return h

            channel = int(h.shape[1])
            minimal_channel = 1500 - 1000 * weight

            if channel < minimal_channel:
                return h

            if self.is_recording_style:
                self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0)
                return h
            else:
                cond_indices = transformer_options['cond_indices']
                uncond_indices = transformer_options['uncond_indices']
                cond_or_uncond = transformer_options['cond_or_uncond']
                r_std, r_mean = self.recorded_h[location]

                h_c = h[cond_indices]
                h_uc = h[uncond_indices]

                o_c = adain(h_c, r_std, r_mean)
                o_uc_strong = h_uc
                o_uc_weak = adain(h_uc, r_std, r_mean)
                o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity

                recon = []
                for cx in cond_or_uncond:
                    if cx == 0:
                        recon.append(o_c)
                    else:
                        recon.append(o_uc)

                o = torch.cat(recon, dim=0)
                return o
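
        # Self-attention hook: during the recording pass it stores K/V per
        # block; during the real pass it lets cond queries also attend over
        # the recorded reference K/V (concatenated along the token axis),
        # again blending the uncond branch by 'Style Fidelity'.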
        def attn1_proc(q, k, v, transformer_options):
            if not self.use_attn:
                return sdp(q, k, v, transformer_options)

            sigma = transformer_options["sigmas"][0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return sdp(q, k, v, transformer_options)

            location = (transformer_options['block'][0], transformer_options['block'][1],
                        transformer_options['block_index'])

            channel = int(q.shape[2])
            minimal_channel = 1500 - 1280 * weight

            if channel < minimal_channel:
                return sdp(q, k, v, transformer_options)

            if self.is_recording_style:
                self.recorded_attn1[location] = (k, v)
                return sdp(q, k, v, transformer_options)
            else:
                cond_indices = transformer_options['cond_indices']
                uncond_indices = transformer_options['uncond_indices']
                cond_or_uncond = transformer_options['cond_or_uncond']

                q_c = q[cond_indices]
                q_uc = q[uncond_indices]

                k_c = k[cond_indices]
                k_uc = k[uncond_indices]

                v_c = v[cond_indices]
                v_uc = v[uncond_indices]

                k_r, v_r = self.recorded_attn1[location]

                o_c = sdp(q_c, zero_cat(k_c, k_r, dim=1), zero_cat(v_c, v_r, dim=1), transformer_options)
                o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options)
                o_uc_weak = sdp(q_uc, zero_cat(k_uc, k_r, dim=1), zero_cat(v_uc, v_r, dim=1), transformer_options)
                o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity

                recon = []
                for cx in cond_or_uncond:
                    if cx == 0:
                        recon.append(o_c)
                    else:
                        recon.append(o_uc)

                o = torch.cat(recon, dim=0)
                return o
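
        # Register all hooks on the cloned UNet so the patches stay local to
        # this sampling run.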
        unet.add_block_modifier(block_proc)
        unet.add_conditioning_modifier(conditioning_modifier)
        unet.set_model_replace_all(attn1_proc, 'attn1')

        process.sd_model.forge_objects.unet = unet

        return cond, mask
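

# The three UI variants: attention-only, AdaIN-only, and both combined.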
add_supported_preprocessor(PreprocessorReference(
    name='reference_only',
    use_attn=True,
    use_adain=False,
    priority=100
))

add_supported_preprocessor(PreprocessorReference(
    name='reference_adain',
    use_attn=False,
    use_adain=True
))

add_supported_preprocessor(PreprocessorReference(
    name='reference_adain+attn',
    use_attn=True,
    use_adain=True
))