From 01dfe3ac490b86c242843bcc109e4682aa499b5f Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 30 Jan 2024 13:15:57 -0800 Subject: [PATCH] inpaint ini --- .../controlnet_ui/controlnet_ui_group.py | 6 +++--- .../lib_controlnet/utils.py | 6 ++++++ .../sd_forge_controlnet/scripts/controlnet.py | 19 +++++++++---------- modules_forge/supported_preprocessor.py | 1 - 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index b19b8710..40f41074 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -6,7 +6,7 @@ from typing import List, Optional, Union, Callable, Dict, Tuple, Literal from dataclasses import dataclass import numpy as np -from lib_controlnet.utils import svg_preprocess, read_image +from lib_controlnet.utils import svg_preprocess, read_image, judge_image_type from lib_controlnet import ( global_state, external_code, @@ -935,9 +935,9 @@ class ControlNetUiGroup(object): else None, ) - is_image = isinstance(result, np.ndarray) and result.ndim == 3 and result.shape[2] < 5 + is_hwc, is_png = judge_image_type(result) - if not is_image: + if not is_hwc: result = img result = external_code.visualize_inpaint_mask(result) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py index cea7ce83..3fd69414 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py @@ -408,3 +408,9 @@ def crop_and_resize_image(detected_map, resize_mode, h, w): detected_map = detected_map[pad_h:pad_h+h, pad_w:pad_w+w] detected_map = safe_numpy(detected_map) return detected_map + + +def judge_image_type(img): + is_image_hw3or4 = isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4] + is_png = is_image_hw3or4 and int(img.shape[2]) == 4 + return is_image_hw3or4, is_png diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index f418e5c6..97188981 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -11,7 +11,7 @@ import gradio as gr from lib_controlnet import global_state, external_code from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \ - prepare_mask + prepare_mask, judge_image_type from lib_controlnet.enums import StableDiffusionVersion from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit from lib_controlnet.controlnet_ui.photopea import Photopea @@ -398,7 +398,6 @@ class ControlNetForForgeOfficial(scripts.Script): return h, w, hr_y, hr_x @torch.no_grad() - @torch.inference_mode() def process_unit_after_click_generate(self, p: StableDiffusionProcessing, unit: external_code.ControlNetUnit, @@ -440,22 +439,24 @@ class ControlNetForForgeOfficial(scripts.Script): slider_2=unit.threshold_b, ) - preprocessor_output_is_image = \ - isinstance(preprocessor_output, np.ndarray) \ - and preprocessor_output.ndim == 3 \ - and preprocessor_output.shape[2] < 5 + preprocessor_output_is_image, need_inpaint_fix = judge_image_type(preprocessor_output) if preprocessor_output_is_image: params.control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w) - p.extra_result_images.append(params.control_cond) + p.extra_result_images.append(external_code.visualize_inpaint_mask(params.control_cond)) params.control_cond = numpy_to_pytorch(params.control_cond).movedim(-1, 1) if has_high_res_fix: params.control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x) - p.extra_result_images.append(params.control_cond_for_hr_fix) + p.extra_result_images.append(external_code.visualize_inpaint_mask(params.control_cond_for_hr_fix)) params.control_cond_for_hr_fix = numpy_to_pytorch(params.control_cond_for_hr_fix).movedim(-1, 1) else: params.control_cond_for_hr_fix = params.control_cond + + if need_inpaint_fix: + fixer = lambda x: x[:, :3] * (1.0 - x[:, 3:]) - x[:, 3:] + params.control_cond = fixer(params.control_cond) + params.control_cond_for_hr_fix = fixer(params.control_cond_for_hr_fix) else: params.control_cond = preprocessor_output params.control_cond_for_hr_fix = preprocessor_output @@ -478,7 +479,6 @@ class ControlNetForForgeOfficial(scripts.Script): return @torch.no_grad() - @torch.inference_mode() def process_unit_before_every_sampling(self, p: StableDiffusionProcessing, unit: external_code.ControlNetUnit, @@ -536,7 +536,6 @@ class ControlNetForForgeOfficial(scripts.Script): return @torch.no_grad() - @torch.inference_mode() def process_unit_after_every_sampling(self, p: StableDiffusionProcessing, unit: external_code.ControlNetUnit, diff --git a/modules_forge/supported_preprocessor.py b/modules_forge/supported_preprocessor.py index 8a688290..7de23d6a 100644 --- a/modules_forge/supported_preprocessor.py +++ b/modules_forge/supported_preprocessor.py @@ -129,7 +129,6 @@ class PreprocessorClipVision(Preprocessor): return self.clipvision @torch.no_grad() - @torch.inference_mode() def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs): clipvision = self.load_clipvision() return clipvision.encode_image(numpy_to_pytorch(input_image))