commit 31a1ee761c (parent f998a26cdb)
Author: lllyasviel
Date: 2024-01-26 15:14:55 -08:00


@@ -6,6 +6,7 @@ import modules.shared as shared
 from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
 from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
 from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+from ldm_patched.modules import model_management


 def catenate_conds(conds):
@@ -209,14 +210,14 @@ class CFGDenoiser(torch.nn.Module):
                 uncond = pad_cond(uncond, num_repeats, empty)
                 self.padded_cond_uncond = True

-        unet_dtype = self.inner_model.inner_model.forge_objects.unet.model.model_config.unet_config['dtype']
+        unet_input_dtype = torch.float16 if model_management.should_use_fp16() else torch.float32
         x_input_dtype = x_in.dtype

-        x_in = x_in.to(unet_dtype)
-        sigma_in = sigma_in.to(unet_dtype)
-        image_cond_in = image_cond_in.to(unet_dtype)
-        tensor = tensor.to(unet_dtype)
-        uncond = uncond.to(unet_dtype)
+        x_in = x_in.to(unet_input_dtype)
+        sigma_in = sigma_in.to(unet_input_dtype)
+        image_cond_in = image_cond_in.to(unet_input_dtype)
+        tensor = tensor.to(unet_input_dtype)
+        uncond = uncond.to(unet_input_dtype)

         self.inner_model.inner_model.current_sigmas = sigma_in
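
For readers skimming the diff: the sketch below restates the new casting logic as a standalone helper. The change is to pick the UNet input dtype from runtime capability (fp16 when the backend supports it) instead of reading it from the model config's unet_config['dtype']. It assumes this repository's ldm_patched.modules.model_management.should_use_fp16(); the cast_denoiser_inputs wrapper itself is hypothetical and only illustrates the pattern, which in the actual commit lives inline in CFGDenoiser.

# Illustrative sketch of the dtype selection introduced in this commit.
# The helper function is hypothetical; only the casting pattern mirrors the diff.
import torch
from ldm_patched.modules import model_management


def cast_denoiser_inputs(x_in, sigma_in, image_cond_in, tensor, uncond):
    # Choose the UNet input dtype from runtime capability rather than the model config.
    unet_input_dtype = torch.float16 if model_management.should_use_fp16() else torch.float32

    # Remember the caller's dtype so the denoised result can be cast back afterwards.
    x_input_dtype = x_in.dtype

    x_in = x_in.to(unet_input_dtype)
    sigma_in = sigma_in.to(unet_input_dtype)
    image_cond_in = image_cond_in.to(unet_input_dtype)
    tensor = tensor.to(unet_input_dtype)
    uncond = uncond.to(unet_input_dtype)

    return (x_in, sigma_in, image_cond_in, tensor, uncond), x_input_dtype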