fill time fix

This commit is contained in:
Антон Антонов
2023-03-18 13:42:15 +03:00
parent 0a3d03a41a
commit 3761cb9605

View File

@@ -1,11 +1,165 @@
import math
import copy
import time
import gradio as gr
from PIL import Image, ImageDraw, ImageOps
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFilter
from modules import processing, shared, images, devices, scripts
from modules.processing import StableDiffusionProcessing
from modules.processing import Processed
from modules.shared import opts, state
from enum import Enum
from hashlib import md5
from collections import namedtuple
class USDUGrid():
def __init__(self, image, padding, tile_width, tile_height, mask_blur):
self.image = image
self.padding = padding
self.tile_width = tile_width
self.tile_height = tile_height
self.tiles = []
self.rows_c = math.ceil(self.image.height / self.tile_height)
self.cols_c = math.ceil(self.image.width / self.tile_width)
self.mask_blur = mask_blur
def add_row(self, row):
self.tiles.append(row)
def calc_crop_region(self, xi, yi):
x1 = xi * self.tile_width - self.padding
y1 = yi * self.tile_height - self.padding
x2 = (xi + 1) * self.tile_width + self.padding
y2 = (yi + 1) * self.tile_height + self.padding
if x1 < 0:
x1 = 0
if y1 < 0:
y1 = 0
if x2 > self.image.width:
x2 = self.image.width
if y2 > self.image.height:
y2 = self.image.height
return x1, y1, x2, y2
def split_grid(self):
for yi in range(self.rows_c):
row = USDUGridRow()
for xi in range(self.cols_c):
crop_region = self.calc_crop_region(xi, yi)
x1, y1, x2, y2 = crop_region
tile = self.image.crop(crop_region)
col = USDUGridCol()
col.add_tile(tile, (xi, yi), (x1, y1, x2-x1, y2-y1))
row.add_col(col)
self.add_row(row)
def combine_grid(self):
start_at = time.time()
image = self.image
for row in self.tiles:
for col in row.cols:
if (col.mask is None):
continue
xi, yi = col.pos
x, y, w, h = col.paste_to
m_image = Image.new('RGB', (image.width, image.height))
s_at = time.time()
m_image.paste(col.mask.filter(ImageFilter.GaussianBlur(self.mask_blur)), (x,y))
e_at = time.time()
print(f"Gauss: {e_at - s_at}")
np_mask = np.array(m_image)
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
mask_for_overlay = Image.fromarray(np_mask)
image_masked = Image.new('RGBa', (m_image.width, m_image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask_for_overlay.convert('L')))
image_masked = image_masked.convert('RGBA')
# image_masked.save(f"F:/tt2/o{col.pos}.png")
if col.paste_to is not None:
x, y, w, h = col.paste_to
base_image = Image.new('RGBA', (image_masked.width, image_masked.height))
image = images.resize_image(1, col.tile, w, h)
# image.save(f"F:/tt/o{col.pos}.png")
base_image.paste(image, (x, y))
image = base_image
# image.save(f"F:/tt1/o{col.pos}.png")
image = image.convert('RGBA')
image.alpha_composite(image_masked)
image = image.convert('RGB')
# base_image = copy.deepcopy(image)
# mask_image = Image.new('L', (image.width, image.height))
# x, y, w, h = col.paste_to
# base_image.paste(col.tile, (x,y))
# print(f"fuck you bitch: {type(col.mask)}")
# mask_image.paste(col.mask, (x,y))
# mask_image = ImageOps.invert(mask_image)
# mask_image = mask_image.filter(ImageFilter.GaussianBlur(self.mask_blur))
# # if col.paste_to is not None:
# # x, y, w, h = col.paste_to
# # base_image = Image.new('RGBA', (col.tile.width, col.tile.height))
# # image = images.resize_image(1, image, w, h)
# # base_image.paste(image, (x, y))
# # image = base_image
# image = image.convert('RGBA')
# image = Image.composite(image, base_image, mask_image)
# image = image.convert('RGB')
# # xi, yi = col.pos
# # x1 = self.padding if xi > 0 else 0
# # y1 = self.padding if yi > 0 else 0
# # x2 = col.tile.width - self.padding if xi < len(row.cols) else col.tile.width
# # y2 = col.tile.height - self.padding if yi < len(self.tiles) else col.tile.height
# # tile = col.tile.crop((x1, y1, x2, y2))
# # x, y, w, h = col.paste_to
# # x = x + self.padding if xi > 0 else 0
# # y = y + self.padding if yi > 0 else 0
# # self.image.paste(tile, (x,y))
end_at = time.time()
print(f"Combine time: {end_at - start_at}")
return image
class USDUGridRow():
def __init__(self):
self.cols = []
def add_col(self, col):
self.cols.append(col)
class USDUGridCol():
def __init__(self):
self.tile = None
self.pos = None
self.paste_to = None
self.tile_width = None
self.tile_height = None
self.mask = None
def add_tile(self, tile, pos, paste_to):
self.tile = tile
self.pos = pos
self.paste_to = paste_to
self.tile_width = self.tile.width
self.tile_height = self.tile.height
def apply_overlay(self, image):
if self.tile is None:
return image
# x, y, w, h = self.paste_to
# image.paste(self.tile, (x,y))
# if self.paste_to is not None:
# x, y, w, h = self.paste_to
# base_image = Image.new('RGBA', (self.tile.width, self.tile.height))
# image = images.resize_image(1, image, w, h)
# base_image.paste(image, (x, y))
# image = base_image
# image = image.convert('RGBA')
# image.alpha_composite(self.tile)
# image = image.convert('RGB')
return image
class USDUMode(Enum):
LINEAR = 0
@@ -20,7 +174,7 @@ class USDUSFMode(Enum):
class USDUpscaler():
def __init__(self, p, image, upscaler_index:int, save_redraw, save_seams_fix, tile_width, tile_height) -> None:
def __init__(self, p, image, upscaler_index:int, save_redraw, save_seams_fix, tile_width, tile_height, padding) -> None:
self.p:StableDiffusionProcessing = p
self.image:Image = image
self.scale_factor = math.ceil(max(p.width, p.height) / max(image.width, image.height))
@@ -29,13 +183,14 @@ class USDUpscaler():
self.redraw.save = save_redraw
self.redraw.tile_width = tile_width if tile_width > 0 else tile_height
self.redraw.tile_height = tile_height if tile_height > 0 else tile_width
self.redraw.padding = padding
self.seams_fix = USDUSeamsFix()
self.seams_fix.save = save_seams_fix
self.seams_fix.tile_width = tile_width if tile_width > 0 else tile_height
self.seams_fix.tile_height = tile_height if tile_height > 0 else tile_width
self.initial_info = None
self.rows = math.ceil(self.p.height / self.redraw.tile_height)
self.cols = math.ceil(self.p.width / self.redraw.tile_width)
self.rows = math.ceil(p.height / self.redraw.tile_height)
self.cols = math.ceil(p.width / self.redraw.tile_width)
def get_factor(self, num):
# Its just return, don't need elif
@@ -82,11 +237,11 @@ class USDUpscaler():
# Resize image to set values
self.image = self.image.resize((self.p.width, self.p.height), resample=Image.LANCZOS)
def setup_redraw(self, redraw_mode, padding, mask_blur):
def setup_redraw(self, redraw_mode, mask_blur):
self.redraw.mode = USDUMode(redraw_mode)
self.redraw.enabled = self.redraw.mode != USDUMode.NONE
self.redraw.padding = padding
self.p.mask_blur = mask_blur
self.redraw.max_batch_size = self.p.batch_size
def setup_seams_fix(self, padding, denoise, mask_blur, width, mode):
self.seams_fix.padding = padding
@@ -149,17 +304,23 @@ class USDURedraw():
def init_draw(self, p, width, height):
p.inpaint_full_res = True
p.inpaint_full_res_padding = self.padding
p.width = math.ceil((self.tile_width+self.padding) / 64) * 64
p.height = math.ceil((self.tile_height+self.padding) / 64) * 64
p.width = width
p.height = height
mask = Image.new("L", (width, height), "black")
draw = ImageDraw.Draw(mask)
return mask, draw
def calc_rectangle(self, xi, yi):
x1 = xi * self.tile_width
y1 = yi * self.tile_height
x2 = xi * self.tile_width + self.tile_width
y2 = yi * self.tile_height + self.tile_height
def calc_rectangle(self, xi, yi, padding, cols, rows, tile_width, tile_height, mask_blur):
# x1 = 0
# y1 = 0
# x2 = self.tile_width
# y2 = self.tile_height
x1 = math.ceil(padding / 2) if xi > 0 else 0
y1 = math.ceil(padding / 2) if yi > 0 else 0
x2 = tile_width - math.ceil(padding / 2) if xi < (cols - 1) else tile_width
y2 = tile_height - math.ceil(padding / 2) if yi < (rows - 1) else tile_height
return x1, y1, x2, y2
@@ -183,8 +344,58 @@ class USDURedraw():
return image
def chess_process_processing(self, p, image, rows, cols, polar, tiles):
grid = USDUGrid(image, self.padding, self.tile_width, self.tile_height, p.mask_blur)
grid.split_grid()
print(len(grid.tiles))
if len(grid.tiles) == 0:
return image
tiles_processing_data = {}
for row in grid.tiles:
for col in row.cols:
xi, yi = col.pos
if (tiles[yi][xi] == polar):
coords = self.calc_rectangle(xi, yi, self.padding, cols, rows, col.tile.width, col.tile.height, p.mask_blur)
idx = ''.join([str(value) for value in (col.tile.width, col.tile.height)]).join([str(value) for value in coords])
if tiles_processing_data.get(idx) == None:
tiles_processing_data[idx] = [coords, [], [], []]
tiles_processing_data[idx][1].append(col.tile)
tiles_processing_data[idx][2].append(xi)
tiles_processing_data[idx][3].append(yi)
max_batch_size = self.max_batch_size
v = 0
for idxf, tile_data in tiles_processing_data.items():
kk = 0
for tile in tile_data[1]:
kk += 1
v += 1
for idxf, tile_data in tiles_processing_data.items():
mask, draw = self.init_draw(p, tile_data[1][0].width, tile_data[1][0].height)
draw.rectangle(tile_data[0], fill="white")
p.image_mask = mask
batch_count = math.ceil(len(tile_data[1]) / max_batch_size)
for i in range(batch_count):
p.batch_size = max_batch_size if len(tile_data[1]) > max_batch_size * (i + 1) else len(tile_data[1]) - max_batch_size * i
work_images = []
begin_index = 0 if i == 0 else max_batch_size * (i)
end_index = i * max_batch_size + p.batch_size
for j in range(begin_index, end_index):
work_images.append(tile_data[1][j])
p.init_images = work_images
processed = processing.process_images(p)
k = 0
for j in range(begin_index, end_index):
row_index = tile_data[3][j]
col_index = tile_data[2][j]
grid.tiles[row_index].cols[col_index].tile = processed.images[k]
grid.tiles[row_index].cols[col_index].mask = copy.deepcopy(mask)
k += 1
return grid.combine_grid()
def chess_process(self, p, image, rows, cols):
mask, draw = self.init_draw(p, image.width, image.height)
tiles = []
# calc tiles colors
for yi in range(rows):
@@ -197,36 +408,42 @@ class USDURedraw():
if yi > 0 and yi % 2 != 0:
color = not color
tiles[yi].append(color)
for yi in range(len(tiles)):
for xi in range(len(tiles[yi])):
if state.interrupted:
break
if not tiles[yi][xi]:
tiles[yi][xi] = not tiles[yi][xi]
continue
tiles[yi][xi] = not tiles[yi][xi]
draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
p.init_images = [image]
p.image_mask = mask
processed = processing.process_images(p)
draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
if (len(processed.images) > 0):
image = processed.images[0]
for yi in range(len(tiles)):
for xi in range(len(tiles[yi])):
if state.interrupted:
break
if not tiles[yi][xi]:
continue
draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
p.init_images = [image]
p.image_mask = mask
processed = processing.process_images(p)
draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
if (len(processed.images) > 0):
image = processed.images[0]
image = self.chess_process_processing(p, image, rows, cols, True, tiles)
image = self.chess_process_processing(p, image, rows, cols, False, tiles)
# image = self.chess_process_processing(p, image, rows, cols, True, tiles)
# image = self.chess_process_processing(p, image, rows, cols, False, tiles)
return image
# image = images.combine_grid(grid)
# return image
# for yi in range(len(tiles)):
# for xi in range(len(tiles[yi])):
# if state.interrupted:
# break
# if not tiles[yi][xi]:
# tiles[yi][xi] = not tiles[yi][xi]
# continue
# tiles[yi][xi] = not tiles[yi][xi]
# draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
# p.init_images = [image]
# p.image_mask = mask
# processed = processing.process_images(p)
# draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
# if (len(processed.images) > 0):
# image = processed.images[0]
# for yi in range(len(tiles)):
# for xi in range(len(tiles[yi])):
# if state.interrupted:
# break
# if not tiles[yi][xi]:
# continue
# draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
# p.init_images = [image]
# p.image_mask = mask
# processed = processing.process_images(p)
# draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
# if (len(processed.images) > 0):
# image = processed.images[0]
p.width = image.width
p.height = image.height
@@ -518,7 +735,7 @@ class Script(scripts.Script):
p.do_not_save_samples = True
p.inpaint_full_res = False
p.inpainting_fill = 1
# p.inpainting_fill = 1
seed = p.seed
@@ -537,11 +754,11 @@ class Script(scripts.Script):
p.height = math.ceil((init_img.height * custom_scale) / 64) * 64
# Upscaling
upscaler = USDUpscaler(p, init_img, upscaler_index, save_upscaled_image, save_seams_fix_image, tile_width, tile_height)
upscaler = USDUpscaler(p, init_img, upscaler_index, save_upscaled_image, save_seams_fix_image, tile_width, tile_height, padding)
upscaler.upscale()
# Drawing
upscaler.setup_redraw(redraw_mode, padding, mask_blur)
upscaler.setup_redraw(redraw_mode, mask_blur)
upscaler.setup_seams_fix(seams_fix_padding, seams_fix_denoise, seams_fix_mask_blur, seams_fix_width, seams_fix_type)
upscaler.print_info()
upscaler.add_extra_info()