mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-01 19:19:52 +00:00
rework sample function
This commit is contained in:
@@ -1,123 +1,3 @@
|
||||
import torch
|
||||
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
|
||||
from ldm_patched.modules.samplers import sampling_function
|
||||
from ldm_patched.modules import model_management
|
||||
from ldm_patched.modules.ops import cleanup_cache
|
||||
from backend.sampling.sampling_function import *
|
||||
|
||||
|
||||
def cond_from_a1111_to_patched_ldm(cond):
|
||||
if isinstance(cond, torch.Tensor):
|
||||
result = dict(
|
||||
cross_attn=cond,
|
||||
model_conds=dict(
|
||||
c_crossattn=CONDCrossAttn(cond),
|
||||
)
|
||||
)
|
||||
return [result, ]
|
||||
|
||||
cross_attn = cond['crossattn']
|
||||
pooled_output = cond['vector']
|
||||
|
||||
result = dict(
|
||||
cross_attn=cross_attn,
|
||||
pooled_output=pooled_output,
|
||||
model_conds=dict(
|
||||
c_crossattn=CONDCrossAttn(cross_attn),
|
||||
y=CONDRegular(pooled_output)
|
||||
)
|
||||
)
|
||||
|
||||
return [result, ]
|
||||
|
||||
|
||||
def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
|
||||
transposed = list(map(list, zip(*weights)))
|
||||
results = []
|
||||
|
||||
for cond_pre in transposed:
|
||||
current_indices = []
|
||||
current_weight = 0
|
||||
for i, w in cond_pre:
|
||||
current_indices.append(i)
|
||||
current_weight = w
|
||||
|
||||
if hasattr(cond, 'advanced_indexing'):
|
||||
feed = cond.advanced_indexing(current_indices)
|
||||
else:
|
||||
feed = cond[current_indices]
|
||||
|
||||
h = cond_from_a1111_to_patched_ldm(feed)
|
||||
h[0]['strength'] = current_weight
|
||||
results += h
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
|
||||
model = self.inner_model.inner_model.forge_objects.unet.model
|
||||
control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
|
||||
extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
|
||||
x = denoiser_params.x
|
||||
timestep = denoiser_params.sigma
|
||||
uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
|
||||
cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
|
||||
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
|
||||
seed = self.p.seeds[0]
|
||||
|
||||
if extra_concat_condition is not None:
|
||||
image_cond_in = extra_concat_condition
|
||||
else:
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
|
||||
if isinstance(image_cond_in, torch.Tensor):
|
||||
if image_cond_in.shape[0] == x.shape[0] \
|
||||
and image_cond_in.shape[2] == x.shape[2] \
|
||||
and image_cond_in.shape[3] == x.shape[3]:
|
||||
for i in range(len(uncond)):
|
||||
uncond[i]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
|
||||
for i in range(len(cond)):
|
||||
cond[i]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
|
||||
|
||||
if control is not None:
|
||||
for h in cond + uncond:
|
||||
h['control'] = control
|
||||
|
||||
for modifier in model_options.get('conditioning_modifiers', []):
|
||||
model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
|
||||
|
||||
denoised = sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
|
||||
return denoised
|
||||
|
||||
|
||||
def sampling_prepare(unet, x):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required)
|
||||
|
||||
unet_inference_memory = memory_estimation_function([B * 2, C, H, W])
|
||||
additional_inference_memory = unet.extra_preserved_memory_during_sampling
|
||||
additional_model_patchers = unet.extra_model_patchers_during_sampling
|
||||
|
||||
if unet.controlnet_linked_list is not None:
|
||||
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
|
||||
additional_model_patchers += unet.controlnet_linked_list.get_models()
|
||||
|
||||
model_management.load_models_gpu(
|
||||
models=[unet] + additional_model_patchers,
|
||||
memory_required=unet_inference_memory + additional_inference_memory)
|
||||
|
||||
real_model = unet.model
|
||||
|
||||
percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
|
||||
|
||||
for cnet in unet.list_controlnets():
|
||||
cnet.pre_run(real_model, percent_to_timestep_function)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def sampling_cleanup(unet):
|
||||
for cnet in unet.list_controlnets():
|
||||
cnet.cleanup()
|
||||
cleanup_cache()
|
||||
return
|
||||
forge_sample = sampling_function
|
||||
|
||||
Reference in New Issue
Block a user