Merge pull request #1665 from wcole3/controlllite-dtype

Set controlllite module dtype to model dtype
This commit is contained in:
DenOfEquity
2024-09-02 18:09:51 +01:00
committed by GitHub

View File

@@ -25,7 +25,7 @@ def extra_options_to_module_prefix(extra_options):
return module_pfx
def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent):
def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent, model_dtype):
# calculate start and end step
start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0
end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps
@@ -63,6 +63,7 @@ def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, st
num_steps=num_steps,
start_step=start_step,
end_step=end_step,
dtype=model_dtype
)
info = module.load_state_dict(weights)
modules[module_name] = module
@@ -125,6 +126,7 @@ class LLLiteModule(torch.nn.Module):
num_steps: int,
start_step: int,
end_step: int,
dtype: torch.dtype = torch.float32
):
super().__init__()
self.name = name
@@ -181,6 +183,7 @@ class LLLiteModule(torch.nn.Module):
self.cond_image = None
self.cond_emb = None
self.current_step = 0
self.to(dtype=dtype)
# @torch.inference_mode()
def set_cond_image(self, cond_image):
@@ -260,7 +263,8 @@ class LLLiteLoader:
# cond_image is b,h,w,3, 0-1
model_lllite = model.clone()
patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent)
patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent,
model.model.diffusion_model.computation_dtype)
if patch is not None:
model_lllite.set_model_attn1_patch(patch)
model_lllite.set_model_attn2_patch(patch)