Set controlllite module dtype to model dtype

This commit is contained in:
wcole3
2024-09-02 09:52:13 -04:00
parent 4f64f6daa4
commit 350018eff5

View File

@@ -25,7 +25,7 @@ def extra_options_to_module_prefix(extra_options):
return module_pfx
def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent):
def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent, model_dtype):
# calculate start and end step
start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0
end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps
@@ -63,6 +63,7 @@ def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, st
num_steps=num_steps,
start_step=start_step,
end_step=end_step,
dtype=model_dtype
)
info = module.load_state_dict(weights)
modules[module_name] = module
@@ -125,6 +126,7 @@ class LLLiteModule(torch.nn.Module):
num_steps: int,
start_step: int,
end_step: int,
dtype: torch.dtype = torch.float32
):
super().__init__()
self.name = name
@@ -134,6 +136,7 @@ class LLLiteModule(torch.nn.Module):
self.start_step = start_step
self.end_step = end_step
self.is_first = False
self.dtype = dtype
modules = []
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
@@ -181,6 +184,7 @@ class LLLiteModule(torch.nn.Module):
self.cond_image = None
self.cond_emb = None
self.current_step = 0
self.to(dtype=dtype)
# @torch.inference_mode()
def set_cond_image(self, cond_image):
@@ -210,6 +214,8 @@ class LLLiteModule(torch.nn.Module):
if self.cond_emb is None:
# print(f"cond_emb is None, {self.name}")
# test bad idea
#self.cond_image = self.cond_image.view(dtype=x.dtype)
cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c
@@ -260,7 +266,8 @@ class LLLiteLoader:
# cond_image is b,h,w,3, 0-1
model_lllite = model.clone()
patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent)
patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent,
model_lllite.model.storage_dtype)
if patch is not None:
model_lllite.set_model_attn1_patch(patch)
model_lllite.set_model_attn2_patch(patch)