mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 17:09:49 +00:00
tiled diffusion
This commit is contained in:
@@ -7,61 +7,91 @@
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import ldm_patched.modules.model_management
|
from backend import memory_management
|
||||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
from backend.misc.image_resize import adaptive_resize
|
||||||
import ldm_patched.modules.model_patcher
|
from backend.patcher.base import ModelPatcher
|
||||||
from ldm_patched.modules.model_base import BaseModel
|
|
||||||
from typing import List, Union, Tuple, Dict
|
from typing import List, Union, Tuple, Dict
|
||||||
from ldm_patched.contrib.external import ImageScale
|
|
||||||
import ldm_patched.modules.utils
|
|
||||||
from backend.patcher.controlnet import ControlNet, T2IAdapter
|
from backend.patcher.controlnet import ControlNet, T2IAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class ImageScale:
    """Resize an image batch through ``adaptive_resize``."""

    def upscale(self, image, upscale_method, width, height, crop):
        """Resize ``image`` to ``width`` x ``height``.

        A ``width`` or ``height`` of 0 means "derive this side from the other
        one, keeping the aspect ratio"; when both are 0 the image is returned
        untouched.

        Args:
            image: tensor whose last dimension is moved to index 1 before
                resizing and moved back afterwards
                (assumes a channels-last [B, H, W, C] layout -- TODO confirm).
            upscale_method: forwarded unchanged to ``adaptive_resize``.
            width: target width, or 0 to derive it from ``height``.
            height: target height, or 0 to derive it from ``width``.
            crop: forwarded unchanged to ``adaptive_resize``.

        Returns:
            A one-element tuple holding the resulting tensor.
        """
        if width == 0 and height == 0:
            # Nothing requested: hand the input back as-is.
            return (image,)

        samples = image.movedim(-1, 1)

        # After the movedim, shape[2] is used as the height and shape[3] as
        # the width; the missing side is scaled proportionally and never
        # allowed to drop below 1.
        if width == 0:
            width = max(1, round(samples.shape[3] * height / samples.shape[2]))
        elif height == 0:
            height = max(1, round(samples.shape[2] * width / samples.shape[3]))

        s = adaptive_resize(samples, width, height, upscale_method, crop)
        return (s.movedim(1, -1),)
|
||||||
|
|
||||||
|
|
||||||
# NOTE(review): presumably the latent channel count; it is not referenced in
# the visible part of this file -- confirm before relying on it.
opt_C = 4
# Scale factor applied to bbox coordinates when slicing control tensors
# (`bbox[i] * opt_f`) and to the default tile sizes of the node inputs.
opt_f = 8
|
||||||
|
|
||||||
|
|
||||||
def ceildiv(big, small):
    """Return ``big / small`` rounded up, using only floor division.

    Negating the divisor makes ``//`` round the (negated) quotient down, and
    negating the result again turns that into rounding up.

    Raises:
        ZeroDivisionError: if ``small`` is 0.
    """
    # Correct ceiling division that avoids floating-point errors and importing math.ceil.
    return -(big // -small)
|
||||||
|
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class BlendMode(Enum):  # i.e. LayerType
    """Layer type of a region: foreground or background."""

    FOREGROUND = 'Foreground'
    BACKGROUND = 'Background'
|
||||||
|
|
||||||
|
|
||||||
class Processing:
    """Empty placeholder class; it defines no attributes or methods."""
|
||||||
|
|
||||||
|
|
||||||
class Device:
    """Empty attribute holder; the module stores the torch device on an instance of it."""
|
||||||
|
|
||||||
|
|
||||||
# Module-wide holder for the torch device reported by the backend's memory
# manager; read as `devices.device` wherever tensors are created in this file.
devices = Device()
devices.device = memory_management.get_torch_device()
|
||||||
|
|
||||||
|
|
||||||
def null_decorator(fn):
    """Return a pass-through wrapper around ``fn``.

    The wrapper forwards every positional and keyword argument to ``fn`` and
    returns its result unchanged. ``functools.wraps`` copies the name,
    docstring and other metadata of ``fn`` onto the wrapper, so decorated
    callables still identify as the original function.
    """
    import functools  # local import so the block needs no file-level change

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
# Marker decorators used to tag methods by feature area. Every one of them is
# an alias of `null_decorator`, so tagging a method does not change what it does.
keep_signature = null_decorator
controlnet = null_decorator
stablesr = null_decorator
grid_bbox = null_decorator
custom_bbox = null_decorator
noise_inverse = null_decorator
|
||||||
|
|
||||||
|
|
||||||
class BBox:
    ''' grid bbox '''

    def __init__(self, x: int, y: int, w: int, h: int):
        """Store the rectangle and precompute its corner list and slicer.

        Args:
            x: left edge.
            y: top edge.
            w: width.
            h: height.
        """
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        # Corner form: [x1, y1, x2, y2], with x2/y2 exclusive.
        self.box = [x, y, x + w, y + h]
        # Index tuple for a 4-D tensor: keeps the first two dims whole and
        # crops rows y:y+h and columns x:x+w in the last two.
        self.slicer = slice(None), slice(None), slice(y, y + h), slice(x, x + w)

    def __getitem__(self, idx: int) -> int:
        """Return element ``idx`` of ``[x1, y1, x2, y2]``."""
        return self.box[idx]
|
||||||
|
|
||||||
def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
|
|
||||||
cols = ceildiv((w - overlap) , (tile_w - overlap))
|
def split_bboxes(w: int, h: int, tile_w: int, tile_h: int, overlap: int = 16, init_weight: Union[Tensor, float] = 1.0) -> Tuple[List[BBox], Tensor]:
|
||||||
rows = ceildiv((h - overlap) , (tile_h - overlap))
|
cols = ceildiv((w - overlap), (tile_w - overlap))
|
||||||
|
rows = ceildiv((h - overlap), (tile_h - overlap))
|
||||||
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
||||||
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
||||||
|
|
||||||
@@ -78,16 +108,17 @@ def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weig
|
|||||||
|
|
||||||
return bbox_list, weight
|
return bbox_list, weight
|
||||||
|
|
||||||
|
|
||||||
class CustomBBox(BBox):
    ''' region control bbox '''
    # Adds nothing to BBox; the subclass only distinguishes the bbox kind.
    pass
|
||||||
|
|
||||||
|
|
||||||
class AbstractDiffusion:
|
class AbstractDiffusion:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.method = self.__class__.__name__
|
self.method = self.__class__.__name__
|
||||||
self.pbar = None
|
self.pbar = None
|
||||||
|
|
||||||
|
|
||||||
self.w: int = 0
|
self.w: int = 0
|
||||||
self.h: int = 0
|
self.h: int = 0
|
||||||
self.tile_width: int = None
|
self.tile_width: int = None
|
||||||
@@ -107,8 +138,8 @@ class AbstractDiffusion:
|
|||||||
self._init_done = None
|
self._init_done = None
|
||||||
|
|
||||||
# count the step correctly
|
# count the step correctly
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
self.inner_loop_count = 0
|
self.inner_loop_count = 0
|
||||||
self.kdiff_step = -1
|
self.kdiff_step = -1
|
||||||
|
|
||||||
# ext. Grid tiling painting (grid bbox)
|
# ext. Grid tiling painting (grid bbox)
|
||||||
@@ -138,7 +169,7 @@ class AbstractDiffusion:
|
|||||||
self.control_tensor_cpu: bool = None
|
self.control_tensor_cpu: bool = None
|
||||||
self.control_tensor_custom: List[List[Tensor]] = []
|
self.control_tensor_custom: List[List[Tensor]] = []
|
||||||
|
|
||||||
self.draw_background: bool = True # by default we draw major prompts in grid tiles
|
self.draw_background: bool = True # by default we draw major prompts in grid tiles
|
||||||
self.control_tensor_cpu = False
|
self.control_tensor_cpu = False
|
||||||
self.weights = None
|
self.weights = None
|
||||||
self.imagescale = ImageScale()
|
self.imagescale = ImageScale()
|
||||||
@@ -154,19 +185,20 @@ class AbstractDiffusion:
|
|||||||
self.tile_overlap = tile_overlap
|
self.tile_overlap = tile_overlap
|
||||||
self.tile_batch_size = tile_batch_size
|
self.tile_batch_size = tile_batch_size
|
||||||
|
|
||||||
def repeat_tensor(self, x: Tensor, n: int, concat=False, concat_to=0) -> Tensor:
    ''' repeat the tensor on its first dim

    Args:
        x: tensor to repeat.
        n: number of repetitions; 1 returns ``x`` itself.
        concat: when the first dim of ``x`` is larger than 1, concatenate
            ``n`` copies and truncate instead of calling ``repeat``.
        concat_to: number of leading rows kept in ``concat`` mode.
            NOTE: with the default of 0 the slice ``[:0]`` yields an empty
            tensor, so callers using ``concat=True`` must pass this.

    Returns:
        ``x`` when ``n == 1``; an expanded view when the first dim is 1;
        otherwise a new tensor with the first dim repeated.
    '''
    if n == 1:
        return x

    B = x.shape[0]
    r_dims = len(x.shape) - 1

    if B == 1:  # batch_size = 1 (not `tile_batch_size`)
        shape = [n] + [-1] * r_dims  # [N, -1, ...]
        return x.expand(shape)  # `expand` is much lighter than `tile`

    if concat:
        return torch.cat([x] * n, dim=0)[:concat_to]

    shape = [n] + [1] * r_dims  # [N, 1, ...]
    return x.repeat(shape)
|
||||||
|
|
||||||
def update_pbar(self):
|
def update_pbar(self):
|
||||||
if self.pbar.n >= self.pbar.total:
|
if self.pbar.n >= self.pbar.total:
|
||||||
self.pbar.close()
|
self.pbar.close()
|
||||||
@@ -180,7 +212,8 @@ class AbstractDiffusion:
|
|||||||
else:
|
else:
|
||||||
self.step_count = sampling_step
|
self.step_count = sampling_step
|
||||||
self.inner_loop_count = 0
|
self.inner_loop_count = 0
|
||||||
def reset_buffer(self, x_in:Tensor):
|
|
||||||
|
def reset_buffer(self, x_in: Tensor):
|
||||||
# Judge if the shape of x_in is the same as the shape of x_buffer
|
# Judge if the shape of x_in is the same as the shape of x_buffer
|
||||||
if self.x_buffer is None or self.x_buffer.shape != x_in.shape:
|
if self.x_buffer is None or self.x_buffer.shape != x_in.shape:
|
||||||
self.x_buffer = torch.zeros_like(x_in, device=x_in.device, dtype=x_in.dtype)
|
self.x_buffer = torch.zeros_like(x_in, device=x_in.device, dtype=x_in.dtype)
|
||||||
@@ -188,7 +221,7 @@ class AbstractDiffusion:
|
|||||||
self.x_buffer.zero_()
|
self.x_buffer.zero_()
|
||||||
|
|
||||||
@grid_bbox
|
@grid_bbox
|
||||||
def init_grid_bbox(self, tile_w:int, tile_h:int, overlap:int, tile_bs:int):
|
def init_grid_bbox(self, tile_w: int, tile_h: int, overlap: int, tile_bs: int):
|
||||||
# if self._init_grid_bbox is not None: return
|
# if self._init_grid_bbox is not None: return
|
||||||
# self._init_grid_bbox = True
|
# self._init_grid_bbox = True
|
||||||
self.weights = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)
|
self.weights = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)
|
||||||
@@ -202,16 +235,16 @@ class AbstractDiffusion:
|
|||||||
bboxes, weights = split_bboxes(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights())
|
bboxes, weights = split_bboxes(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights())
|
||||||
self.weights += weights
|
self.weights += weights
|
||||||
self.num_tiles = len(bboxes)
|
self.num_tiles = len(bboxes)
|
||||||
self.num_batches = ceildiv(self.num_tiles , tile_bs)
|
self.num_batches = ceildiv(self.num_tiles, tile_bs)
|
||||||
self.tile_bs = ceildiv(len(bboxes) , self.num_batches) # optimal_batch_size
|
self.tile_bs = ceildiv(len(bboxes), self.num_batches) # optimal_batch_size
|
||||||
self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)]
|
self.batched_bboxes = [bboxes[i * self.tile_bs:(i + 1) * self.tile_bs] for i in range(self.num_batches)]
|
||||||
|
|
||||||
@grid_bbox
def get_tile_weights(self) -> Union[Tensor, float]:
    """Return the weight given to each grid tile.

    This implementation returns the constant 1.0, i.e. every tile counts
    equally. `init_grid_bbox` passes the value to `split_bboxes` as
    `init_weight`.
    """
    return 1.0
|
||||||
|
|
||||||
@noise_inverse
|
@noise_inverse
|
||||||
def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int):
|
def init_noise_inverse(self, steps: int, retouch: float, get_cache_callback, set_cache_callback, renoise_strength: float, renoise_kernel: int):
|
||||||
self.noise_inverse_enabled = True
|
self.noise_inverse_enabled = True
|
||||||
self.noise_inverse_steps = steps
|
self.noise_inverse_steps = steps
|
||||||
self.noise_inverse_retouch = float(retouch)
|
self.noise_inverse_retouch = float(retouch)
|
||||||
@@ -239,7 +272,7 @@ class AbstractDiffusion:
|
|||||||
# self.pbar = tqdm(total=(self.total_bboxes) * sampling_steps, desc=f"{self.method} Sampling: ")
|
# self.pbar = tqdm(total=(self.total_bboxes) * sampling_steps, desc=f"{self.method} Sampling: ")
|
||||||
|
|
||||||
@controlnet
|
@controlnet
|
||||||
def prepare_controlnet_tensors(self, refresh:bool=False, tensor=None):
|
def prepare_controlnet_tensors(self, refresh: bool = False, tensor=None):
|
||||||
''' Crop the control tensor into tiles and cache them '''
|
''' Crop the control tensor into tiles and cache them '''
|
||||||
if not refresh:
|
if not refresh:
|
||||||
if self.control_tensor_batch is not None or self.control_params is not None: return
|
if self.control_tensor_batch is not None or self.control_params is not None: return
|
||||||
@@ -254,7 +287,7 @@ class AbstractDiffusion:
|
|||||||
for bbox in bboxes:
|
for bbox in bboxes:
|
||||||
if len(control_tensor.shape) == 3:
|
if len(control_tensor.shape) == 3:
|
||||||
control_tensor.unsqueeze_(0)
|
control_tensor.unsqueeze_(0)
|
||||||
control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f]
|
control_tile = control_tensor[:, :, bbox[1] * opt_f:bbox[3] * opt_f, bbox[0] * opt_f:bbox[2] * opt_f]
|
||||||
single_batch_tensors.append(control_tile)
|
single_batch_tensors.append(control_tile)
|
||||||
control_tile = torch.cat(single_batch_tensors, dim=0)
|
control_tile = torch.cat(single_batch_tensors, dim=0)
|
||||||
if self.control_tensor_cpu:
|
if self.control_tensor_cpu:
|
||||||
@@ -267,14 +300,14 @@ class AbstractDiffusion:
|
|||||||
for bbox in self.custom_bboxes:
|
for bbox in self.custom_bboxes:
|
||||||
if len(control_tensor.shape) == 3:
|
if len(control_tensor.shape) == 3:
|
||||||
control_tensor.unsqueeze_(0)
|
control_tensor.unsqueeze_(0)
|
||||||
control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f]
|
control_tile = control_tensor[:, :, bbox[1] * opt_f:bbox[3] * opt_f, bbox[0] * opt_f:bbox[2] * opt_f]
|
||||||
if self.control_tensor_cpu:
|
if self.control_tensor_cpu:
|
||||||
control_tile = control_tile.cpu()
|
control_tile = control_tile.cpu()
|
||||||
custom_control_tile_list.append(control_tile)
|
custom_control_tile_list.append(control_tile)
|
||||||
self.control_tensor_custom.append(custom_control_tile_list)
|
self.control_tensor_custom.append(custom_control_tile_list)
|
||||||
|
|
||||||
@controlnet
|
@controlnet
|
||||||
def switch_controlnet_tensors(self, batch_id:int, x_batch_size:int, tile_batch_size:int, is_denoise=False):
|
def switch_controlnet_tensors(self, batch_id: int, x_batch_size: int, tile_batch_size: int, is_denoise=False):
|
||||||
# if not self.enable_controlnet: return
|
# if not self.enable_controlnet: return
|
||||||
if self.control_tensor_batch is None: return
|
if self.control_tensor_batch is None: return
|
||||||
# self.control_params = [0]
|
# self.control_params = [0]
|
||||||
@@ -284,12 +317,12 @@ class AbstractDiffusion:
|
|||||||
# tensor that was concatenated in `prepare_controlnet_tensors`
|
# tensor that was concatenated in `prepare_controlnet_tensors`
|
||||||
control_tile = self.control_tensor_batch[param_id][batch_id]
|
control_tile = self.control_tensor_batch[param_id][batch_id]
|
||||||
# broadcast to latent batch size
|
# broadcast to latent batch size
|
||||||
if x_batch_size > 1: # self.is_kdiff:
|
if x_batch_size > 1: # self.is_kdiff:
|
||||||
all_control_tile = []
|
all_control_tile = []
|
||||||
for i in range(tile_batch_size):
|
for i in range(tile_batch_size):
|
||||||
this_control_tile = [control_tile[i].unsqueeze(0)] * x_batch_size
|
this_control_tile = [control_tile[i].unsqueeze(0)] * x_batch_size
|
||||||
all_control_tile.append(torch.cat(this_control_tile, dim=0))
|
all_control_tile.append(torch.cat(this_control_tile, dim=0))
|
||||||
control_tile = torch.cat(all_control_tile, dim=0) # [:x_tile.shape[0]]
|
control_tile = torch.cat(all_control_tile, dim=0) # [:x_tile.shape[0]]
|
||||||
self.control_tensor_batch[param_id][batch_id] = control_tile
|
self.control_tensor_batch[param_id][batch_id] = control_tile
|
||||||
# else:
|
# else:
|
||||||
# control_tile = control_tile.repeat([x_batch_size if is_denoise else x_batch_size * 2, 1, 1, 1])
|
# control_tile = control_tile.repeat([x_batch_size if is_denoise else x_batch_size * 2, 1, 1, 1])
|
||||||
@@ -297,17 +330,17 @@ class AbstractDiffusion:
|
|||||||
|
|
||||||
def process_controlnet(self, x_shape, x_dtype, c_in: dict, cond_or_uncond: List, bboxes, batch_size: int, batch_id: int):
|
def process_controlnet(self, x_shape, x_dtype, c_in: dict, cond_or_uncond: List, bboxes, batch_size: int, batch_id: int):
|
||||||
control: ControlNet = c_in['control_model']
|
control: ControlNet = c_in['control_model']
|
||||||
param_id = -1 # current controlnet & previous_controlnets
|
param_id = -1 # current controlnet & previous_controlnets
|
||||||
tuple_key = tuple(cond_or_uncond) + tuple(x_shape)
|
tuple_key = tuple(cond_or_uncond) + tuple(x_shape)
|
||||||
while control is not None:
|
while control is not None:
|
||||||
param_id += 1
|
param_id += 1
|
||||||
PH, PW = self.h*8, self.w*8
|
PH, PW = self.h * 8, self.w * 8
|
||||||
|
|
||||||
if self.control_params.get(tuple_key, None) is None:
|
if self.control_params.get(tuple_key, None) is None:
|
||||||
self.control_params[tuple_key] = [[None]]
|
self.control_params[tuple_key] = [[None]]
|
||||||
val = self.control_params[tuple_key]
|
val = self.control_params[tuple_key]
|
||||||
if param_id+1 >= len(val):
|
if param_id + 1 >= len(val):
|
||||||
val.extend([[None] for _ in range(param_id+1)])
|
val.extend([[None] for _ in range(param_id + 1)])
|
||||||
if len(self.batched_bboxes) >= len(val[param_id]):
|
if len(self.batched_bboxes) >= len(val[param_id]):
|
||||||
val[param_id].extend([[None] for _ in range(len(self.batched_bboxes))])
|
val[param_id].extend([[None] for _ in range(len(self.batched_bboxes))])
|
||||||
|
|
||||||
@@ -319,59 +352,64 @@ class AbstractDiffusion:
|
|||||||
if dtype is None: dtype = x_dtype
|
if dtype is None: dtype = x_dtype
|
||||||
if isinstance(control, T2IAdapter):
|
if isinstance(control, T2IAdapter):
|
||||||
width, height = control.scale_image_to(PW, PH)
|
width, height = control.scale_image_to(PW, PH)
|
||||||
control.cond_hint = ldm_patched.modules.utils.common_upscale(control.cond_hint_original, width, height, 'nearest-exact', "center").float().to(control.device)
|
control.cond_hint = adaptive_resize(control.cond_hint_original, width, height, 'nearest-exact', "center").float().to(control.device)
|
||||||
if control.channels_in == 1 and control.cond_hint.shape[1] > 1:
|
if control.channels_in == 1 and control.cond_hint.shape[1] > 1:
|
||||||
control.cond_hint = torch.mean(control.cond_hint, 1, keepdim=True)
|
control.cond_hint = torch.mean(control.cond_hint, 1, keepdim=True)
|
||||||
elif control.__class__.__name__ == 'ControlLLLiteAdvanced':
|
elif control.__class__.__name__ == 'ControlLLLiteAdvanced':
|
||||||
if control.sub_idxs is not None and control.cond_hint_original.shape[0] >= control.full_latent_length:
|
if control.sub_idxs is not None and control.cond_hint_original.shape[0] >= control.full_latent_length:
|
||||||
control.cond_hint = ldm_patched.modules.utils.common_upscale(control.cond_hint_original[control.sub_idxs], PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
control.cond_hint = adaptive_resize(control.cond_hint_original[control.sub_idxs], PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
||||||
else:
|
else:
|
||||||
if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
|
if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
|
||||||
control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
|
control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
|
||||||
else:
|
else:
|
||||||
control.cond_hint = ldm_patched.modules.utils.common_upscale(control.cond_hint_original, PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
control.cond_hint = adaptive_resize(control.cond_hint_original, PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
||||||
else:
|
else:
|
||||||
if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
|
if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
|
||||||
control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
|
control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
|
||||||
else:
|
else:
|
||||||
control.cond_hint = ldm_patched.modules.utils.common_upscale(control.cond_hint_original, PW, PH, 'nearest-exact', 'center').to(dtype=dtype, device=control.device)
|
control.cond_hint = adaptive_resize(control.cond_hint_original, PW, PH, 'nearest-exact', 'center').to(dtype=dtype, device=control.device)
|
||||||
|
|
||||||
# Broadcast then tile
|
# Broadcast then tile
|
||||||
#
|
#
|
||||||
# Below can be in the parent's if clause because self.refresh will trigger on resolution change, e.g. cause of ConditioningSetArea
|
# Below can be in the parent's if clause because self.refresh will trigger on resolution change, e.g. cause of ConditioningSetArea
|
||||||
# so that particular case isn't cached atm.
|
# so that particular case isn't cached atm.
|
||||||
cond_hint_pre_tile = control.cond_hint
|
cond_hint_pre_tile = control.cond_hint
|
||||||
if control.cond_hint.shape[0] < batch_size :
|
if control.cond_hint.shape[0] < batch_size:
|
||||||
cond_hint_pre_tile = self.repeat_tensor(control.cond_hint, ceildiv(batch_size, control.cond_hint.shape[0]))[:batch_size]
|
cond_hint_pre_tile = self.repeat_tensor(control.cond_hint, ceildiv(batch_size, control.cond_hint.shape[0]))[:batch_size]
|
||||||
cns = [cond_hint_pre_tile[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f] for bbox in bboxes]
|
cns = [cond_hint_pre_tile[:, :, bbox[1] * opt_f:bbox[3] * opt_f, bbox[0] * opt_f:bbox[2] * opt_f] for bbox in bboxes]
|
||||||
control.cond_hint = torch.cat(cns, dim=0)
|
control.cond_hint = torch.cat(cns, dim=0)
|
||||||
self.control_params[tuple_key][param_id][batch_id]=control.cond_hint
|
self.control_params[tuple_key][param_id][batch_id] = control.cond_hint
|
||||||
else:
|
else:
|
||||||
control.cond_hint = self.control_params[tuple_key][param_id][batch_id]
|
control.cond_hint = self.control_params[tuple_key][param_id][batch_id]
|
||||||
control = control.previous_controlnet
|
control = control.previous_controlnet
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import pi, exp, sqrt
|
from numpy import pi, exp, sqrt
|
||||||
def gaussian_weights(tile_w: int, tile_h: int) -> Tensor:
    '''
    Copy from the original implementation of Mixture of Diffusers
    https://github.com/albarji/mixture-of-diffusers/blob/master/mixdiff/tiling.py
    This generates gaussian weights to smooth the noise of each tile.
    This is critical for this method to work.

    Returns a float32 tensor of shape (tile_h, tile_w) on `devices.device`,
    built as the outer product of a per-row and a per-column gaussian profile.
    '''
    def f(x, midpoint, var=0.01):
        # Gaussian density of the offset from `midpoint`, with the offset
        # normalised by tile_w.
        return exp(-(x - midpoint) * (x - midpoint) / (tile_w * tile_w) / (2 * var)) / sqrt(2 * pi * var)

    x_probs = [f(x, (tile_w - 1) / 2) for x in range(tile_w)]  # -1 because index goes from 0 to latent_width - 1
    # NOTE(review): the vertical profile is centred on tile_h / 2 (no "- 1")
    # and is also normalised by tile_w, not tile_h. Kept as-is to preserve the
    # existing weights -- confirm whether this asymmetry is intended.
    y_probs = [f(y, tile_h / 2) for y in range(tile_h)]

    w = np.outer(y_probs, x_probs)
    return torch.from_numpy(w).to(devices.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
class CondDict:
    """Empty placeholder class; it defines no attributes or methods."""
|
||||||
|
|
||||||
|
|
||||||
class MultiDiffusion(AbstractDiffusion):
|
class MultiDiffusion(AbstractDiffusion):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, model_function: BaseModel.apply_model, args: dict):
|
def __call__(self, model_function, args: dict):
|
||||||
x_in: Tensor = args["input"]
|
x_in: Tensor = args["input"]
|
||||||
t_in: Tensor = args["timestep"]
|
t_in: Tensor = args["timestep"]
|
||||||
c_in: dict = args["c"]
|
c_in: dict = args["c"]
|
||||||
@@ -395,12 +433,12 @@ class MultiDiffusion(AbstractDiffusion):
|
|||||||
# Background sampling (grid bbox)
|
# Background sampling (grid bbox)
|
||||||
if self.draw_background:
|
if self.draw_background:
|
||||||
for batch_id, bboxes in enumerate(self.batched_bboxes):
|
for batch_id, bboxes in enumerate(self.batched_bboxes):
|
||||||
if ldm_patched.modules.model_management.processing_interrupted():
|
if memory_management.processing_interrupted():
|
||||||
# self.pbar.close()
|
# self.pbar.close()
|
||||||
return x_in
|
return x_in
|
||||||
|
|
||||||
# batching & compute tiles
|
# batching & compute tiles
|
||||||
x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) # [TB, C, TH, TW]
|
x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) # [TB, C, TH, TW]
|
||||||
n_rep = len(bboxes)
|
n_rep = len(bboxes)
|
||||||
ts_tile = self.repeat_tensor(t_in, n_rep)
|
ts_tile = self.repeat_tensor(t_in, n_rep)
|
||||||
cond_tile = self.repeat_tensor(c_crossattn, n_rep)
|
cond_tile = self.repeat_tensor(c_crossattn, n_rep)
|
||||||
@@ -428,7 +466,7 @@ class MultiDiffusion(AbstractDiffusion):
|
|||||||
x_tile_out = model_function(x_tile, ts_tile, **c_tile)
|
x_tile_out = model_function(x_tile, ts_tile, **c_tile)
|
||||||
|
|
||||||
for i, bbox in enumerate(bboxes):
|
for i, bbox in enumerate(bboxes):
|
||||||
self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
|
self.x_buffer[bbox.slicer] += x_tile_out[i * N:(i + 1) * N, :, :, :]
|
||||||
del x_tile_out, x_tile, ts_tile, c_tile
|
del x_tile_out, x_tile, ts_tile, c_tile
|
||||||
|
|
||||||
# update progress bar
|
# update progress bar
|
||||||
@@ -439,6 +477,7 @@ class MultiDiffusion(AbstractDiffusion):
|
|||||||
|
|
||||||
return x_out
|
return x_out
|
||||||
|
|
||||||
|
|
||||||
class MixtureOfDiffusers(AbstractDiffusion):
|
class MixtureOfDiffusers(AbstractDiffusion):
|
||||||
"""
|
"""
|
||||||
Mixture-of-Diffusers Implementation
|
Mixture-of-Diffusers Implementation
|
||||||
@@ -470,11 +509,11 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
|||||||
return self.tile_weights
|
return self.tile_weights
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, model_function: BaseModel.apply_model, args: dict):
|
def __call__(self, model_function, args: dict):
|
||||||
x_in: Tensor = args["input"]
|
x_in: Tensor = args["input"]
|
||||||
t_in: Tensor = args["timestep"]
|
t_in: Tensor = args["timestep"]
|
||||||
c_in: dict = args["c"]
|
c_in: dict = args["c"]
|
||||||
cond_or_uncond: List= args["cond_or_uncond"]
|
cond_or_uncond: List = args["cond_or_uncond"]
|
||||||
c_crossattn: Tensor = c_in['c_crossattn']
|
c_crossattn: Tensor = c_in['c_crossattn']
|
||||||
|
|
||||||
N, C, H, W = x_in.shape
|
N, C, H, W = x_in.shape
|
||||||
@@ -496,14 +535,14 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
|||||||
|
|
||||||
# Global sampling
|
# Global sampling
|
||||||
if self.draw_background:
|
if self.draw_background:
|
||||||
for batch_id, bboxes in enumerate(self.batched_bboxes): # batch_id is the `Latent tile batch size`
|
for batch_id, bboxes in enumerate(self.batched_bboxes): # batch_id is the `Latent tile batch size`
|
||||||
if ldm_patched.modules.model_management.processing_interrupted():
|
if memory_management.processing_interrupted():
|
||||||
# self.pbar.close()
|
# self.pbar.close()
|
||||||
return x_in
|
return x_in
|
||||||
|
|
||||||
# batching
|
# batching
|
||||||
x_tile_list = []
|
x_tile_list = []
|
||||||
t_tile_list = []
|
t_tile_list = []
|
||||||
icond_map = {}
|
icond_map = {}
|
||||||
# tcond_tile_list = []
|
# tcond_tile_list = []
|
||||||
# icond_tile_list = []
|
# icond_tile_list = []
|
||||||
@@ -519,7 +558,7 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
|||||||
# present in sdxl
|
# present in sdxl
|
||||||
for key in ['y', 'c_concat']:
|
for key in ['y', 'c_concat']:
|
||||||
if key in c_in:
|
if key in c_in:
|
||||||
icond=c_in[key] # self.get_icond(c_in)
|
icond = c_in[key] # self.get_icond(c_in)
|
||||||
if icond.shape[2:] == (self.h, self.w):
|
if icond.shape[2:] == (self.h, self.w):
|
||||||
icond = icond[bbox.slicer]
|
icond = icond[bbox.slicer]
|
||||||
if icond_map.get(key, None) is None:
|
if icond_map.get(key, None) is None:
|
||||||
@@ -531,13 +570,13 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
|||||||
else:
|
else:
|
||||||
print('>> [WARN] not supported, make an issue on github!!')
|
print('>> [WARN] not supported, make an issue on github!!')
|
||||||
n_rep = len(bboxes)
|
n_rep = len(bboxes)
|
||||||
x_tile = torch.cat(x_tile_list, dim=0) # differs each
|
x_tile = torch.cat(x_tile_list, dim=0) # differs each
|
||||||
t_tile = self.repeat_tensor(t_in, n_rep) # just repeat
|
t_tile = self.repeat_tensor(t_in, n_rep) # just repeat
|
||||||
tcond_tile = self.repeat_tensor(c_crossattn, n_rep) # just repeat
|
tcond_tile = self.repeat_tensor(c_crossattn, n_rep) # just repeat
|
||||||
c_tile = c_in.copy()
|
c_tile = c_in.copy()
|
||||||
c_tile['c_crossattn'] = tcond_tile
|
c_tile['c_crossattn'] = tcond_tile
|
||||||
if 'time_context' in c_in:
|
if 'time_context' in c_in:
|
||||||
c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep) # just repeat
|
c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep) # just repeat
|
||||||
for key in c_tile:
|
for key in c_tile:
|
||||||
if key in ['y', 'c_concat']:
|
if key in ['y', 'c_concat']:
|
||||||
icond_tile = torch.cat(icond_map[key], dim=0) # differs each
|
icond_tile = torch.cat(icond_map[key], dim=0) # differs each
|
||||||
@@ -547,10 +586,10 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
|||||||
# controlnet
|
# controlnet
|
||||||
# self.switch_controlnet_tensors(batch_id, N, len(bboxes), is_denoise=True)
|
# self.switch_controlnet_tensors(batch_id, N, len(bboxes), is_denoise=True)
|
||||||
if 'control' in c_in:
|
if 'control' in c_in:
|
||||||
control=c_in['control']
|
control = c_in['control']
|
||||||
self.process_controlnet(x_tile.shape, x_tile.dtype, c_in, cond_or_uncond, bboxes, N, batch_id)
|
self.process_controlnet(x_tile.shape, x_tile.dtype, c_in, cond_or_uncond, bboxes, N, batch_id)
|
||||||
c_tile['control'] = control.get_control(x_tile, t_tile, c_tile, len(cond_or_uncond))
|
c_tile['control'] = control.get_control(x_tile, t_tile, c_tile, len(cond_or_uncond))
|
||||||
|
|
||||||
# stablesr
|
# stablesr
|
||||||
# self.switch_stablesr_tensors(batch_id)
|
# self.switch_stablesr_tensors(batch_id)
|
||||||
|
|
||||||
@@ -562,7 +601,7 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
|||||||
# These weights can be calcluated in advance, but will cost a lot of vram
|
# These weights can be calcluated in advance, but will cost a lot of vram
|
||||||
# when you have many tiles. So we calculate it here.
|
# when you have many tiles. So we calculate it here.
|
||||||
w = self.tile_weights * self.rescale_factor[bbox.slicer]
|
w = self.tile_weights * self.rescale_factor[bbox.slicer]
|
||||||
self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] * w
|
self.x_buffer[bbox.slicer] += x_tile_out[i * N:(i + 1) * N, :, :, :] * w
|
||||||
del x_tile_out, x_tile, t_tile, c_tile
|
del x_tile_out, x_tile, t_tile, c_tile
|
||||||
|
|
||||||
# self.update_pbar()
|
# self.update_pbar()
|
||||||
@@ -573,19 +612,22 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
|||||||
return x_out
|
return x_out
|
||||||
|
|
||||||
|
|
||||||
# Upper bound used for the tile_width, tile_height and tile_batch_size inputs
# of the TiledDiffusion node.
MAX_RESOLUTION = 8192
|
||||||
|
|
||||||
|
|
||||||
class TiledDiffusion():
|
class TiledDiffusion():
|
||||||
@classmethod
def INPUT_TYPES(s):
    """Describe the inputs of this node.

    Returns a dict with a single "required" entry that maps each input name
    to a tuple of its type (or list of choices) and, where given, an options
    dict with default / min / max / step.
    """
    return {"required": {"model": ("MODEL",),
                         "method": (["MultiDiffusion", "Mixture of Diffusers"], {"default": "Mixture of Diffusers"}),
                         # "tile_width": ("INT", {"default": 96, "min": 16, "max": 256, "step": 16}),
                         "tile_width": ("INT", {"default": 96 * opt_f, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
                         # "tile_height": ("INT", {"default": 96, "min": 16, "max": 256, "step": 16}),
                         "tile_height": ("INT", {"default": 96 * opt_f, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
                         "tile_overlap": ("INT", {"default": 8 * opt_f, "min": 0, "max": 256 * opt_f, "step": 4 * opt_f}),
                         "tile_batch_size": ("INT", {"default": 4, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                         }}
|
||||||
|
|
||||||
# Node metadata: the single output is a MODEL, the entry point is the method
# named "apply", and the node is listed under the "_for_testing" category.
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply"
CATEGORY = "_for_testing"
|
||||||
@@ -595,7 +637,7 @@ class TiledDiffusion():
|
|||||||
implement = MixtureOfDiffusers()
|
implement = MixtureOfDiffusers()
|
||||||
else:
|
else:
|
||||||
implement = MultiDiffusion()
|
implement = MultiDiffusion()
|
||||||
|
|
||||||
# if noise_inversion:
|
# if noise_inversion:
|
||||||
# get_cache_callback = self.noise_inverse_get_cache
|
# get_cache_callback = self.noise_inverse_get_cache
|
||||||
# set_cache_callback = None # lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, steps, retouch)
|
# set_cache_callback = None # lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, steps, retouch)
|
||||||
|
|||||||
Reference in New Issue
Block a user