diff --git a/modules/sd_models.py b/modules/sd_models.py index 9c5909168..7d4ab0fd8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -733,6 +733,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config + # ldm's Unet is using self.dtype to cast input tensor. If we do not overwrite + # UnetModel.dtype, it will be the default dtype from config. + # sgm's Unet is not using dtype for casting. The value will be ignored. + sd_model.model.diffusion_model.dtype = devices.dtype_unet timer.record("create model")