From 350018eff542c9ade13159450be23d59f6ec3dcd Mon Sep 17 00:00:00 2001 From: wcole3 Date: Mon, 2 Sep 2024 09:52:13 -0400 Subject: [PATCH 1/4] Set controlllite module dtype to model dtype --- .../lib_controllllite/lib_controllllite.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py index 3ef4f72c..9e917c4f 100644 --- a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py +++ b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py @@ -25,7 +25,7 @@ def extra_options_to_module_prefix(extra_options): return module_pfx -def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent): +def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent, model_dtype): # calculate start and end step start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0 end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps @@ -63,6 +63,7 @@ def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, st num_steps=num_steps, start_step=start_step, end_step=end_step, + dtype=model_dtype ) info = module.load_state_dict(weights) modules[module_name] = module @@ -125,6 +126,7 @@ class LLLiteModule(torch.nn.Module): num_steps: int, start_step: int, end_step: int, + dtype: torch.dtype = torch.float32 ): super().__init__() self.name = name @@ -134,6 +136,7 @@ class LLLiteModule(torch.nn.Module): self.start_step = start_step self.end_step = end_step self.is_first = False + self.dtype = dtype modules = [] modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2 @@ -181,6 +184,7 @@ class LLLiteModule(torch.nn.Module): self.cond_image = None self.cond_emb = None self.current_step = 0 + self.to(dtype=dtype) # @torch.inference_mode() def set_cond_image(self, cond_image): @@ -210,6 +214,8 @@ class LLLiteModule(torch.nn.Module): if self.cond_emb is None: # print(f"cond_emb is None, {self.name}") + # test bad idea + #self.cond_image = self.cond_image.view(dtype=x.dtype) cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype)) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c @@ -260,7 +266,8 @@ class LLLiteLoader: # cond_image is b,h,w,3, 0-1 model_lllite = model.clone() - patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent) + patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent, + model_lllite.model.storage_dtype) if patch is not None: model_lllite.set_model_attn1_patch(patch) model_lllite.set_model_attn2_patch(patch) From dee7a3c287aacb770e6f7c9b6b4fcd3f8bd9950b Mon Sep 17 00:00:00 2001 From: wcole3 Date: Mon, 2 Sep 2024 12:23:17 -0400 Subject: [PATCH 2/4] Use computation_dtype instead of storage_dtype --- .../lib_controllllite/lib_controllllite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py index 9e917c4f..1aae0bad 100644 --- a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py +++ b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py @@ -267,7 +267,7 @@ class LLLiteLoader: model_lllite = model.clone() patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent, - model_lllite.model.storage_dtype) + model.model.diffusion_model.computation_dtype) if patch is not None: model_lllite.set_model_attn1_patch(patch) model_lllite.set_model_attn2_patch(patch) From ef112ae8d28bd2a03124e4763a1ec89339b269fa Mon Sep 17 00:00:00 2001 From: wcole3 Date: Mon, 2 Sep 2024 12:38:31 -0400 Subject: [PATCH 3/4] Cleanup --- .../lib_controllllite/lib_controllllite.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py index 1aae0bad..c1586424 100644 --- a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py +++ b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py @@ -214,8 +214,6 @@ class LLLiteModule(torch.nn.Module): if self.cond_emb is None: # print(f"cond_emb is None, {self.name}") - # test bad idea - #self.cond_image = self.cond_image.view(dtype=x.dtype) cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype)) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c From 05bdc3e44f34232cd7d5ae3c9cf3021b7816a2a6 Mon Sep 17 00:00:00 2001 From: wcole3 Date: Mon, 2 Sep 2024 13:06:30 -0400 Subject: [PATCH 4/4] Remove dtype field --- .../sd_forge_controlllite/lib_controllllite/lib_controllllite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py index c1586424..3bf7760d 100644 --- a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py +++ b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py @@ -136,7 +136,6 @@ class LLLiteModule(torch.nn.Module): self.start_step = start_step self.end_step = end_step self.is_first = False - self.dtype = dtype modules = [] modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2