Merge pull request #1665 from wcole3/controlllite-dtype

Set controlllite module dtype to model dtype
This commit is contained in:
DenOfEquity
2024-09-02 18:09:51 +01:00
committed by GitHub

View File

@@ -25,7 +25,7 @@ def extra_options_to_module_prefix(extra_options):
return module_pfx return module_pfx
def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent): def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent, model_dtype):
# calculate start and end step # calculate start and end step
start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0 start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0
end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps
@@ -63,6 +63,7 @@ def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, st
num_steps=num_steps, num_steps=num_steps,
start_step=start_step, start_step=start_step,
end_step=end_step, end_step=end_step,
dtype=model_dtype
) )
info = module.load_state_dict(weights) info = module.load_state_dict(weights)
modules[module_name] = module modules[module_name] = module
@@ -125,6 +126,7 @@ class LLLiteModule(torch.nn.Module):
num_steps: int, num_steps: int,
start_step: int, start_step: int,
end_step: int, end_step: int,
dtype: torch.dtype = torch.float32
): ):
super().__init__() super().__init__()
self.name = name self.name = name
@@ -181,6 +183,7 @@ class LLLiteModule(torch.nn.Module):
self.cond_image = None self.cond_image = None
self.cond_emb = None self.cond_emb = None
self.current_step = 0 self.current_step = 0
self.to(dtype=dtype)
# @torch.inference_mode() # @torch.inference_mode()
def set_cond_image(self, cond_image): def set_cond_image(self, cond_image):
@@ -260,7 +263,8 @@ class LLLiteLoader:
# cond_image is b,h,w,3, 0-1 # cond_image is b,h,w,3, 0-1
model_lllite = model.clone() model_lllite = model.clone()
patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent) patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent,
model.model.diffusion_model.computation_dtype)
if patch is not None: if patch is not None:
model_lllite.set_model_attn1_patch(patch) model_lllite.set_model_attn1_patch(patch)
model_lllite.set_model_attn2_patch(patch) model_lllite.set_model_attn2_patch(patch)