mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
Set controlllite module dtype to model dtype
This commit is contained in:
@@ -25,7 +25,7 @@ def extra_options_to_module_prefix(extra_options):
|
||||
return module_pfx
|
||||
|
||||
|
||||
def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent):
|
||||
def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent, model_dtype):
|
||||
# calculate start and end step
|
||||
start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0
|
||||
end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps
|
||||
@@ -63,6 +63,7 @@ def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, st
|
||||
num_steps=num_steps,
|
||||
start_step=start_step,
|
||||
end_step=end_step,
|
||||
dtype=model_dtype
|
||||
)
|
||||
info = module.load_state_dict(weights)
|
||||
modules[module_name] = module
|
||||
@@ -125,6 +126,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
num_steps: int,
|
||||
start_step: int,
|
||||
end_step: int,
|
||||
dtype: torch.dtype = torch.float32
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
@@ -134,6 +136,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
self.start_step = start_step
|
||||
self.end_step = end_step
|
||||
self.is_first = False
|
||||
self.dtype = dtype
|
||||
|
||||
modules = []
|
||||
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
|
||||
@@ -181,6 +184,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
self.cond_image = None
|
||||
self.cond_emb = None
|
||||
self.current_step = 0
|
||||
self.to(dtype=dtype)
|
||||
|
||||
# @torch.inference_mode()
|
||||
def set_cond_image(self, cond_image):
|
||||
@@ -210,6 +214,8 @@ class LLLiteModule(torch.nn.Module):
|
||||
|
||||
if self.cond_emb is None:
|
||||
# print(f"cond_emb is None, {self.name}")
|
||||
# test bad idea
|
||||
#self.cond_image = self.cond_image.view(dtype=x.dtype)
|
||||
cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
|
||||
if not self.is_conv2d:
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
@@ -260,7 +266,8 @@ class LLLiteLoader:
|
||||
# cond_image is b,h,w,3, 0-1
|
||||
|
||||
model_lllite = model.clone()
|
||||
patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent)
|
||||
patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent,
|
||||
model_lllite.model.storage_dtype)
|
||||
if patch is not None:
|
||||
model_lllite.set_model_attn1_patch(patch)
|
||||
model_lllite.set_model_attn2_patch(patch)
|
||||
|
||||
Reference in New Issue
Block a user