diff --git a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py index 3ef4f72c..9e917c4f 100644 --- a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py +++ b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py @@ -25,7 +25,7 @@ def extra_options_to_module_prefix(extra_options): return module_pfx -def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent): +def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent, model_dtype): # calculate start and end step start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0 end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps @@ -63,6 +63,7 @@ def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, st num_steps=num_steps, start_step=start_step, end_step=end_step, + dtype=model_dtype ) info = module.load_state_dict(weights) modules[module_name] = module @@ -125,6 +126,7 @@ class LLLiteModule(torch.nn.Module): num_steps: int, start_step: int, end_step: int, + dtype: torch.dtype = torch.float32 ): super().__init__() self.name = name @@ -134,6 +136,7 @@ class LLLiteModule(torch.nn.Module): self.start_step = start_step self.end_step = end_step self.is_first = False + self.dtype = dtype modules = [] modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2 @@ -181,6 +184,7 @@ class LLLiteModule(torch.nn.Module): self.cond_image = None self.cond_emb = None self.current_step = 0 + self.to(dtype=dtype) # @torch.inference_mode() def set_cond_image(self, cond_image): @@ -210,6 +214,8 @@ class LLLiteModule(torch.nn.Module): if self.cond_emb is None: # print(f"cond_emb is None, {self.name}") + # test bad idea + #self.cond_image = self.cond_image.view(dtype=x.dtype) cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype)) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c @@ -260,7 +266,8 @@ class LLLiteLoader: # cond_image is b,h,w,3, 0-1 model_lllite = model.clone() - patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent) + patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent, + model_lllite.model.storage_dtype) if patch is not None: model_lllite.set_model_attn1_patch(patch) model_lllite.set_model_attn2_patch(patch)