ControlNet batch fix (#113)

* cn forward patcher

* simplify

* use args instead of kwargs

* postpone moving cond_hint to gpu

* also do this for t2i adapter

* use a1111's code to load files in a batch

* revert

* patcher for batch images

* patcher for batch images

* remove cn fn wrapper duplication

* remove shit

* use unit getattr instead of unet patcher

* fix bug

* small change
This commit is contained in:
Chengsong Zhang
2024-02-09 21:55:29 -06:00
committed by GitHub
parent d6f2e5bdd9
commit ee565b337c
2 changed files with 16 additions and 12 deletions

View File

@@ -164,17 +164,21 @@ class ControlNetForForgeOfficial(scripts.Script):
if unit.input_mode == external_code.InputMode.BATCH:
image_list = []
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
for idx, filename in enumerate(os.listdir(unit.batch_image_dir)):
batch_image_files = shared.listfiles(unit.batch_image_dir)
for batch_modifier in getattr(unit, 'batch_modifiers', []):
batch_image_files = batch_modifier(batch_image_files, p)
for idx, filename in enumerate(batch_image_files):
if any(filename.lower().endswith(ext) for ext in image_extensions):
img_path = os.path.join(unit.batch_image_dir, filename)
logger.info(f'Try to read image: {img_path}')
img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy()
mask = None
if len(unit.batch_mask_dir) > 0:
if len(unit.batch_mask_dir) >= len(unit.batch_image_dir):
mask_path = unit.batch_mask_dir[idx]
if unit.batch_mask_dir:
batch_mask_files = shared.listfiles(unit.batch_mask_dir)
if len(batch_mask_files) >= len(batch_image_files):
mask_path = batch_mask_files[idx]
else:
mask_path = unit.batch_mask_dir[0]
mask_path = batch_mask_files[0]
mask_path = os.path.join(unit.batch_mask_dir, mask_path)
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
if img is not None:

View File

@@ -92,9 +92,9 @@ def compute_controlnet_weighting(control, cnet):
if isinstance(advanced_mask_weighting, torch.Tensor):
if advanced_mask_weighting.shape[0] != 1:
k = int(control_signal.shape[0] // advanced_mask_weighting.shape[0])
if control_signal.shape[0] == k * advanced_mask_weighting.shape[0]:
advanced_mask_weighting = advanced_mask_weighting.repeat(k, 1, 1, 1)
k_ = int(control_signal.shape[0] // advanced_mask_weighting.shape[0])
if control_signal.shape[0] == k_ * advanced_mask_weighting.shape[0]:
advanced_mask_weighting = advanced_mask_weighting.repeat(k_, 1, 1, 1)
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear')
control[k][i] = control_signal * final_weight[:, None, None, None]
@@ -244,7 +244,7 @@ class ControlNet(ControlBase):
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -264,7 +264,7 @@ class ControlNet(ControlBase):
wrapper_args['inner_model'] = self.control_model
control = controlnet_model_function_wrapper(**wrapper_args)
else:
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint.to(self.device), timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
@@ -549,7 +549,7 @@ class T2IAdapter(ControlBase):
self.control_input = None
self.cond_hint = None
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float()
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
@@ -567,7 +567,7 @@ class T2IAdapter(ControlBase):
wrapper_args['inner_t2i_model'] = self.t2i_model
self.control_input = controlnet_model_function_wrapper(**wrapper_args)
else:
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy))
self.t2i_model.cpu()