mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 03:01:15 +00:00
ControlNet batch fix (#113)
* cn forward patcher * simplify * use args instead of kwargs * postpond moving cond_hint to gpu * also do this for t2i adapter * use a1111's code to load files in a batch * revert * patcher for batch images * patcher for batch images * remove cn fn wrapper dupl * remove shit * use unit getattr instead of unet patcher * fix bug * small changte
This commit is contained in:
@@ -92,9 +92,9 @@ def compute_controlnet_weighting(control, cnet):
|
||||
|
||||
if isinstance(advanced_mask_weighting, torch.Tensor):
|
||||
if advanced_mask_weighting.shape[0] != 1:
|
||||
k = int(control_signal.shape[0] // advanced_mask_weighting.shape[0])
|
||||
if control_signal.shape[0] == k * advanced_mask_weighting.shape[0]:
|
||||
advanced_mask_weighting = advanced_mask_weighting.repeat(k, 1, 1, 1)
|
||||
k_ = int(control_signal.shape[0] // advanced_mask_weighting.shape[0])
|
||||
if control_signal.shape[0] == k_ * advanced_mask_weighting.shape[0]:
|
||||
advanced_mask_weighting = advanced_mask_weighting.repeat(k_, 1, 1, 1)
|
||||
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear')
|
||||
|
||||
control[k][i] = control_signal * final_weight[:, None, None, None]
|
||||
@@ -244,7 +244,7 @@ class ControlNet(ControlBase):
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
||||
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
@@ -264,7 +264,7 @@ class ControlNet(ControlBase):
|
||||
wrapper_args['inner_model'] = self.control_model
|
||||
control = controlnet_model_function_wrapper(**wrapper_args)
|
||||
else:
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint.to(self.device), timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
@@ -549,7 +549,7 @@ class T2IAdapter(ControlBase):
|
||||
self.control_input = None
|
||||
self.cond_hint = None
|
||||
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
|
||||
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
|
||||
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float()
|
||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
@@ -567,7 +567,7 @@ class T2IAdapter(ControlBase):
|
||||
wrapper_args['inner_t2i_model'] = self.t2i_model
|
||||
self.control_input = controlnet_model_function_wrapper(**wrapper_args)
|
||||
else:
|
||||
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
||||
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy))
|
||||
|
||||
self.t2i_model.cpu()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user