diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 21d5006e..8aa26982 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,6 +6,7 @@ import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback +from ldm_patched.modules import model_management def catenate_conds(conds): @@ -209,14 +210,14 @@ class CFGDenoiser(torch.nn.Module): uncond = pad_cond(uncond, num_repeats, empty) self.padded_cond_uncond = True - unet_dtype = self.inner_model.inner_model.forge_objects.unet.model.model_config.unet_config['dtype'] + unet_input_dtype = torch.float16 if model_management.should_use_fp16() else torch.float32 x_input_dtype = x_in.dtype - x_in = x_in.to(unet_dtype) - sigma_in = sigma_in.to(unet_dtype) - image_cond_in = image_cond_in.to(unet_dtype) - tensor = tensor.to(unet_dtype) - uncond = uncond.to(unet_dtype) + x_in = x_in.to(unet_input_dtype) + sigma_in = sigma_in.to(unet_input_dtype) + image_cond_in = image_cond_in.to(unet_input_dtype) + tensor = tensor.to(unet_input_dtype) + uncond = uncond.to(unet_input_dtype) self.inner_model.inner_model.current_sigmas = sigma_in