Update preprocessor_inpaint.py

This commit is contained in:
lllyasviel
2024-01-30 15:00:22 -08:00
parent 981cee1378
commit c708d8a239

View File

@@ -93,6 +93,10 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
return
def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
H, W, C = input_image.shape
raw_color = input_image[:, :, 0:3].copy()
raw_mask = input_image[:, :, 3:4].copy()
input_image, remove_pad = resize_image_with_pad(input_image, 256)
self.load_model()
@@ -108,13 +112,20 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
color = color * (1 - mask)
image_feed = torch.cat([color, mask], dim=2)
image_feed = einops.rearrange(image_feed, 'h w c -> 1 c h w')
result = self.model_patcher.model(image_feed)[0]
result = einops.rearrange(result, 'c h w -> h w c')
result = result * mask + color * (1 - mask)
result *= 255.0
result = result.detach().cpu().numpy().clip(0, 255).astype(np.uint8)
prd_color = self.model_patcher.model(image_feed)[0]
prd_color = einops.rearrange(prd_color, 'c h w -> h w c')
prd_color = prd_color * mask + color * (1 - mask)
prd_color *= 255.0
prd_color = prd_color.detach().cpu().numpy().clip(0, 255).astype(np.uint8)
result = remove_pad(result)
prd_color = remove_pad(prd_color)
prd_color = cv2.resize(prd_color, (W, H))
alpha = raw_mask.astype(np.float32) / 255.0
fin_color = prd_color.astype(np.float32) * alpha + raw_color.astype(np.float32) * (1 - alpha)
fin_color = fin_color.clip(0, 255).astype(np.uint8)
result = np.concatenate([fin_color, raw_mask], axis=2)
return result