mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
Update preprocessor_inpaint.py
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
|
from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
|
||||||
from modules_forge.shared import add_supported_preprocessor
|
from modules_forge.shared import add_supported_preprocessor
|
||||||
from modules_forge.forge_util import numpy_to_pytorch
|
from modules_forge.forge_util import numpy_to_pytorch
|
||||||
@@ -16,8 +18,26 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = 'inpaint_only'
|
self.name = 'inpaint_only'
|
||||||
|
self.image = None
|
||||||
|
self.mask = None
|
||||||
|
self.latent_image = None
|
||||||
|
self.latent_mask = None
|
||||||
|
|
||||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||||
|
self.image = kwargs['cond_before_inpaint_fix'][:, 0:3]
|
||||||
|
self.mask = kwargs['cond_before_inpaint_fix'][:, 3:]
|
||||||
|
|
||||||
|
vae = process.sd_model.forge_objects.vae
|
||||||
|
# This is a powerful VAE with integrated memory management, bf16, and tiled fallback.
|
||||||
|
|
||||||
|
self.latent_image = vae.encode(self.image.movedim(1, -1))
|
||||||
|
B, C, H, W = self.latent_image.shape
|
||||||
|
|
||||||
|
latent_mask = self.mask
|
||||||
|
latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round()
|
||||||
|
latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(self.latent_image)
|
||||||
|
self.latent_mask = latent_mask
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def process_after_every_sampling(self, process, params, *args, **kwargs):
|
def process_after_every_sampling(self, process, params, *args, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user