inpaint ini

This commit is contained in:
lllyasviel
2024-01-30 13:15:57 -08:00
parent 7b0a202f89
commit 01dfe3ac49
4 changed files with 18 additions and 14 deletions

View File

@@ -6,7 +6,7 @@ from typing import List, Optional, Union, Callable, Dict, Tuple, Literal
from dataclasses import dataclass
import numpy as np
from lib_controlnet.utils import svg_preprocess, read_image
from lib_controlnet.utils import svg_preprocess, read_image, judge_image_type
from lib_controlnet import (
global_state,
external_code,
@@ -935,9 +935,9 @@ class ControlNetUiGroup(object):
else None,
)
is_image = isinstance(result, np.ndarray) and result.ndim == 3 and result.shape[2] < 5
is_hwc, is_png = judge_image_type(result)
if not is_image:
if not is_hwc:
result = img
result = external_code.visualize_inpaint_mask(result)

View File

@@ -408,3 +408,9 @@ def crop_and_resize_image(detected_map, resize_mode, h, w):
detected_map = detected_map[pad_h:pad_h+h, pad_w:pad_w+w]
detected_map = safe_numpy(detected_map)
return detected_map
def judge_image_type(img):
is_image_hw3or4 = isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4]
is_png = is_image_hw3or4 and int(img.shape[2]) == 4
return is_image_hw3or4, is_png

View File

@@ -11,7 +11,7 @@ import gradio as gr
from lib_controlnet import global_state, external_code
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
prepare_mask
prepare_mask, judge_image_type
from lib_controlnet.enums import StableDiffusionVersion
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from lib_controlnet.controlnet_ui.photopea import Photopea
@@ -398,7 +398,6 @@ class ControlNetForForgeOfficial(scripts.Script):
return h, w, hr_y, hr_x
@torch.no_grad()
@torch.inference_mode()
def process_unit_after_click_generate(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
@@ -440,22 +439,24 @@ class ControlNetForForgeOfficial(scripts.Script):
slider_2=unit.threshold_b,
)
preprocessor_output_is_image = \
isinstance(preprocessor_output, np.ndarray) \
and preprocessor_output.ndim == 3 \
and preprocessor_output.shape[2] < 5
preprocessor_output_is_image, need_inpaint_fix = judge_image_type(preprocessor_output)
if preprocessor_output_is_image:
params.control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
p.extra_result_images.append(params.control_cond)
p.extra_result_images.append(external_code.visualize_inpaint_mask(params.control_cond))
params.control_cond = numpy_to_pytorch(params.control_cond).movedim(-1, 1)
if has_high_res_fix:
params.control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
p.extra_result_images.append(params.control_cond_for_hr_fix)
p.extra_result_images.append(external_code.visualize_inpaint_mask(params.control_cond_for_hr_fix))
params.control_cond_for_hr_fix = numpy_to_pytorch(params.control_cond_for_hr_fix).movedim(-1, 1)
else:
params.control_cond_for_hr_fix = params.control_cond
if need_inpaint_fix:
fixer = lambda x: x[:, :3] * (1.0 - x[:, 3:]) - x[:, 3:]
params.control_cond = fixer(params.control_cond)
params.control_cond_for_hr_fix = fixer(params.control_cond_for_hr_fix)
else:
params.control_cond = preprocessor_output
params.control_cond_for_hr_fix = preprocessor_output
@@ -478,7 +479,6 @@ class ControlNetForForgeOfficial(scripts.Script):
return
@torch.no_grad()
@torch.inference_mode()
def process_unit_before_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
@@ -536,7 +536,6 @@ class ControlNetForForgeOfficial(scripts.Script):
return
@torch.no_grad()
@torch.inference_mode()
def process_unit_after_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,

View File

@@ -129,7 +129,6 @@ class PreprocessorClipVision(Preprocessor):
return self.clipvision
@torch.no_grad()
@torch.inference_mode()
def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
clipvision = self.load_clipvision()
return clipvision.encode_image(numpy_to_pytorch(input_image))