mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-20 22:39:09 +00:00
better sampler
This commit is contained in:
49
modules_forge/forge_sampler.py
Normal file
49
modules_forge/forge_sampler.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
|
||||
from ldm_patched.modules.samplers import sampling_function
|
||||
|
||||
|
||||
def cond_from_a1111_to_patched_ldm(cond):
|
||||
if isinstance(cond, torch.Tensor):
|
||||
result = dict(
|
||||
cross_attn=cond,
|
||||
model_conds=dict(
|
||||
c_crossattn=CONDCrossAttn(cond),
|
||||
)
|
||||
)
|
||||
return [result, ]
|
||||
|
||||
cross_attn = cond['crossattn']
|
||||
pooled_output = cond['vector']
|
||||
|
||||
result = dict(
|
||||
cross_attn=cross_attn,
|
||||
pooled_output=pooled_output,
|
||||
model_conds=dict(
|
||||
c_crossattn=CONDCrossAttn(cross_attn),
|
||||
y=CONDRegular(pooled_output)
|
||||
)
|
||||
)
|
||||
|
||||
return [result, ]
|
||||
|
||||
|
||||
def forge_sample(self, denoiser_params, cond_scale):
|
||||
model = self.inner_model.inner_model.forge_objects.unet.model
|
||||
x = denoiser_params.x
|
||||
timestep = denoiser_params.sigma
|
||||
uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
|
||||
cond = cond_from_a1111_to_patched_ldm(denoiser_params.text_cond)
|
||||
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
|
||||
seed = self.p.seeds[0]
|
||||
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
if isinstance(image_cond_in, torch.Tensor):
|
||||
if image_cond_in.shape[0] == x.shape[0] \
|
||||
and image_cond_in.shape[2] == x.shape[2] \
|
||||
and image_cond_in.shape[3] == x.shape[3]:
|
||||
uncond[0]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
|
||||
cond[0]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
|
||||
|
||||
denoised = sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
|
||||
return denoised
|
||||
Reference in New Issue
Block a user