diff --git a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py index 3ef4f72c..3bf7760d 100644 --- a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py +++ b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py @@ -25,7 +25,7 @@ def extra_options_to_module_prefix(extra_options): return module_pfx -def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent): +def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent, model_dtype): # calculate start and end step start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0 end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps @@ -63,6 +63,7 @@ def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, st num_steps=num_steps, start_step=start_step, end_step=end_step, + dtype=model_dtype ) info = module.load_state_dict(weights) modules[module_name] = module @@ -125,6 +126,7 @@ class LLLiteModule(torch.nn.Module): num_steps: int, start_step: int, end_step: int, + dtype: torch.dtype = torch.float32 ): super().__init__() self.name = name @@ -181,6 +183,7 @@ class LLLiteModule(torch.nn.Module): self.cond_image = None self.cond_emb = None self.current_step = 0 + self.to(dtype=dtype) # @torch.inference_mode() def set_cond_image(self, cond_image): @@ -260,7 +263,8 @@ class LLLiteLoader: # cond_image is b,h,w,3, 0-1 model_lllite = model.clone() - patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent) + patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent, + model.model.diffusion_model.computation_dtype) if patch is not None: model_lllite.set_model_attn1_patch(patch) model_lllite.set_model_attn2_patch(patch)