diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
index 3ca5f927..5cc50697 100644
--- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
+++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
@@ -164,17 +164,21 @@ class ControlNetForForgeOfficial(scripts.Script):
         if unit.input_mode == external_code.InputMode.BATCH:
             image_list = []
             image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
-            for idx, filename in enumerate(os.listdir(unit.batch_image_dir)):
+            batch_image_files = shared.listfiles(unit.batch_image_dir)
+            for batch_modifier in getattr(unit, 'batch_modifiers', []):
+                batch_image_files = batch_modifier(batch_image_files, p)
+            for idx, filename in enumerate(batch_image_files):
                 if any(filename.lower().endswith(ext) for ext in image_extensions):
                     img_path = os.path.join(unit.batch_image_dir, filename)
                     logger.info(f'Try to read image: {img_path}')
                     img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy()
                     mask = None
-                    if len(unit.batch_mask_dir) > 0:
-                        if len(unit.batch_mask_dir) >= len(unit.batch_image_dir):
-                            mask_path = unit.batch_mask_dir[idx]
+                    if unit.batch_mask_dir:
+                        batch_mask_files = shared.listfiles(unit.batch_mask_dir)
+                        if len(batch_mask_files) >= len(batch_image_files):
+                            mask_path = batch_mask_files[idx]
                         else:
-                            mask_path = unit.batch_mask_dir[0]
+                            mask_path = batch_mask_files[0]
                         mask_path = os.path.join(unit.batch_mask_dir, mask_path)
                         mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
                     if img is not None:
diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py
index 19234c1c..8dd29bab 100644
--- a/ldm_patched/modules/controlnet.py
+++ b/ldm_patched/modules/controlnet.py
@@ -92,9 +92,9 @@ def compute_controlnet_weighting(control, cnet):
 
             if isinstance(advanced_mask_weighting, torch.Tensor):
                 if advanced_mask_weighting.shape[0] != 1:
-                    k = int(control_signal.shape[0] // advanced_mask_weighting.shape[0])
-                    if control_signal.shape[0] == k * advanced_mask_weighting.shape[0]:
-                        advanced_mask_weighting = advanced_mask_weighting.repeat(k, 1, 1, 1)
+                    k_ = int(control_signal.shape[0] // advanced_mask_weighting.shape[0])
+                    if control_signal.shape[0] == k_ * advanced_mask_weighting.shape[0]:
+                        advanced_mask_weighting = advanced_mask_weighting.repeat(k_, 1, 1, 1)
                 control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear')
 
             control[k][i] = control_signal * final_weight[:, None, None, None]
@@ -244,7 +244,7 @@ class ControlNet(ControlBase):
             if self.cond_hint is not None:
                 del self.cond_hint
                 self.cond_hint = None
-            self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+            self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
 
@@ -264,7 +264,7 @@ class ControlNet(ControlBase):
             wrapper_args['inner_model'] = self.control_model
             control = controlnet_model_function_wrapper(**wrapper_args)
         else:
-            control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
+            control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint.to(self.device), timesteps=timestep.float(), context=context.to(dtype), y=y)
         return self.control_merge(None, control, control_prev, output_dtype)
 
     def copy(self):
@@ -549,7 +549,7 @@ class T2IAdapter(ControlBase):
             self.control_input = None
             self.cond_hint = None
             width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
-            self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
+            self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float()
             if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                 self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
@@ -567,7 +567,7 @@ class T2IAdapter(ControlBase):
             wrapper_args['inner_t2i_model'] = self.t2i_model
             self.control_input = controlnet_model_function_wrapper(**wrapper_args)
         else:
-            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
+            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy))
         self.t2i_model.cpu()