tiled diffusion

layerdiffusion
2024-08-03 14:32:47 -07:00
parent e6bd652a4a
commit 4e4296b6fa


@@ -7,61 +7,91 @@
from __future__ import division
import torch
from torch import Tensor
from backend import memory_management
from backend.misc.image_resize import adaptive_resize
from backend.patcher.base import ModelPatcher
from typing import List, Union, Tuple, Dict
from backend.patcher.controlnet import ControlNet, T2IAdapter


class ImageScale:
    def upscale(self, image, upscale_method, width, height, crop):
        if width == 0 and height == 0:
            s = image
        else:
            samples = image.movedim(-1, 1)
            if width == 0:
                width = max(1, round(samples.shape[3] * height / samples.shape[2]))
            elif height == 0:
                height = max(1, round(samples.shape[2] * width / samples.shape[3]))
            s = adaptive_resize(samples, width, height, upscale_method, crop)
            s = s.movedim(1, -1)
        return (s,)


opt_C = 4
opt_f = 8


def ceildiv(big, small):
    # Correct ceiling division that avoids floating-point errors and importing math.ceil.
    return -(big // -small)
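# Illustrative sketch (annotation, not part of the file): ceiling division by negation
# relies on Python's floor division rounding toward negative infinity, e.g.
#   ceildiv(7, 3) -> -(7 // -3) -> -(-3) -> 3
#   ceildiv(6, 3) -> -(6 // -3) -> -(-2) -> 2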
from enum import Enum

class BlendMode(Enum):  # i.e. LayerType
    FOREGROUND = 'Foreground'
    BACKGROUND = 'Background'


class Processing: ...
class Device: ...

devices = Device()
devices.device = memory_management.get_torch_device()


def null_decorator(fn):
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


keep_signature = null_decorator
controlnet = null_decorator
stablesr = null_decorator
grid_bbox = null_decorator
custom_bbox = null_decorator
noise_inverse = null_decorator


class BBox:
    ''' grid bbox '''

    def __init__(self, x: int, y: int, w: int, h: int):
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        self.box = [x, y, x + w, y + h]
        self.slicer = slice(None), slice(None), slice(y, y + h), slice(x, x + w)

    def __getitem__(self, idx: int) -> int:
        return self.box[idx]
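# Illustrative usage sketch (annotation, not in the file): `slicer` indexes a
# [B, C, H, W] latent so a tile can be read or accumulated in place, e.g.
#   bbox = BBox(x=16, y=0, w=32, h=32)
#   tile = latent[bbox.slicer]       # -> latent[:, :, 0:32, 16:48]
#   buffer[bbox.slicer] += tile      # write the tile back into a buffer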

def split_bboxes(w: int, h: int, tile_w: int, tile_h: int, overlap: int = 16, init_weight: Union[Tensor, float] = 1.0) -> Tuple[List[BBox], Tensor]:
    cols = ceildiv((w - overlap), (tile_w - overlap))
    rows = ceildiv((h - overlap), (tile_h - overlap))
    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
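    # Worked example (annotation, not in the file): for a 96x96 latent with 32x32
    # tiles and overlap 16, cols = rows = ceildiv(80, 16) = 5, and tile origins are
    # spaced dx = dy = (96 - 32) / 4 = 16 latent pixels apart, so each tile overlaps
    # its neighbour by half a tile.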
@@ -78,16 +108,17 @@ def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weig
    return bbox_list, weight


class CustomBBox(BBox):
    ''' region control bbox '''
    pass


class AbstractDiffusion:
    def __init__(self):
        self.method = self.__class__.__name__
        self.pbar = None

        self.w: int = 0
        self.h: int = 0
        self.tile_width: int = None
@@ -107,8 +138,8 @@ class AbstractDiffusion:
        self._init_done = None

        # count the step correctly
        self.step_count = 0
        self.inner_loop_count = 0
        self.kdiff_step = -1

        # ext. Grid tiling painting (grid bbox)
@@ -138,7 +169,7 @@ class AbstractDiffusion:
        self.control_tensor_cpu: bool = None
        self.control_tensor_custom: List[List[Tensor]] = []

        self.draw_background: bool = True  # by default we draw major prompts in grid tiles
        self.control_tensor_cpu = False
        self.weights = None
        self.imagescale = ImageScale()
@@ -154,19 +185,20 @@ class AbstractDiffusion:
        self.tile_overlap = tile_overlap
        self.tile_batch_size = tile_batch_size

    def repeat_tensor(self, x: Tensor, n: int, concat=False, concat_to=0) -> Tensor:
        ''' repeat the tensor on its first dim '''
        if n == 1: return x
        B = x.shape[0]
        r_dims = len(x.shape) - 1
        if B == 1:  # batch_size = 1 (not `tile_batch_size`)
            shape = [n] + [-1] * r_dims  # [N, -1, ...]
            return x.expand(shape)       # `expand` is much lighter than `tile`
        else:
            if concat:
                return torch.cat([x for _ in range(n)], dim=0)[:concat_to]
            shape = [n] + [1] * r_dims   # [N, 1, ...]
            return x.repeat(shape)
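    # Illustrative sketch (annotation, not in the file): with a single-sample batch
    # the repeat is a zero-copy broadcast, otherwise the data is tiled:
    #   t = torch.randn(1, 4, 32, 32)
    #   self.repeat_tensor(t, 3).shape   # -> (3, 4, 32, 32), via expand (view only)
    #   t2 = torch.randn(2, 4, 32, 32)
    #   self.repeat_tensor(t2, 3).shape  # -> (6, 4, 32, 32), via repeat (copies data)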

    def update_pbar(self):
        if self.pbar.n >= self.pbar.total:
            self.pbar.close()
@@ -180,7 +212,8 @@ class AbstractDiffusion:
        else:
            self.step_count = sampling_step
            self.inner_loop_count = 0

    def reset_buffer(self, x_in: Tensor):
        # Check whether the shape of x_in matches the shape of x_buffer
        if self.x_buffer is None or self.x_buffer.shape != x_in.shape:
            self.x_buffer = torch.zeros_like(x_in, device=x_in.device, dtype=x_in.dtype)
@@ -188,7 +221,7 @@ class AbstractDiffusion:
            self.x_buffer.zero_()

    @grid_bbox
    def init_grid_bbox(self, tile_w: int, tile_h: int, overlap: int, tile_bs: int):
        # if self._init_grid_bbox is not None: return
        # self._init_grid_bbox = True
        self.weights = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)
@@ -202,16 +235,16 @@ class AbstractDiffusion:
        bboxes, weights = split_bboxes(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights())
        self.weights += weights
        self.num_tiles = len(bboxes)
        self.num_batches = ceildiv(self.num_tiles, tile_bs)
        self.tile_bs = ceildiv(len(bboxes), self.num_batches)  # optimal_batch_size
        self.batched_bboxes = [bboxes[i * self.tile_bs:(i + 1) * self.tile_bs] for i in range(self.num_batches)]
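        # Worked example (annotation, not in the file): 10 tiles with a requested
        # tile_bs of 4 give num_batches = ceildiv(10, 4) = 3; the batch size is then
        # rebalanced to ceildiv(10, 3) = 4, so the tiles are grouped as 4 + 4 + 2.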

    @grid_bbox
    def get_tile_weights(self) -> Union[Tensor, float]:
        return 1.0

    @noise_inverse
    def init_noise_inverse(self, steps: int, retouch: float, get_cache_callback, set_cache_callback, renoise_strength: float, renoise_kernel: int):
        self.noise_inverse_enabled = True
        self.noise_inverse_steps = steps
        self.noise_inverse_retouch = float(retouch)
@@ -239,7 +272,7 @@ class AbstractDiffusion:
        # self.pbar = tqdm(total=(self.total_bboxes) * sampling_steps, desc=f"{self.method} Sampling: ")

    @controlnet
    def prepare_controlnet_tensors(self, refresh: bool = False, tensor=None):
        ''' Crop the control tensor into tiles and cache them '''
        if not refresh:
            if self.control_tensor_batch is not None or self.control_params is not None: return
@@ -254,7 +287,7 @@ class AbstractDiffusion:
                for bbox in bboxes:
                    if len(control_tensor.shape) == 3:
                        control_tensor.unsqueeze_(0)
                    control_tile = control_tensor[:, :, bbox[1] * opt_f:bbox[3] * opt_f, bbox[0] * opt_f:bbox[2] * opt_f]
                    single_batch_tensors.append(control_tile)
                control_tile = torch.cat(single_batch_tensors, dim=0)
                if self.control_tensor_cpu:
@@ -267,14 +300,14 @@ class AbstractDiffusion:
                for bbox in self.custom_bboxes:
                    if len(control_tensor.shape) == 3:
                        control_tensor.unsqueeze_(0)
                    control_tile = control_tensor[:, :, bbox[1] * opt_f:bbox[3] * opt_f, bbox[0] * opt_f:bbox[2] * opt_f]
                    if self.control_tensor_cpu:
                        control_tile = control_tile.cpu()
                    custom_control_tile_list.append(control_tile)
                self.control_tensor_custom.append(custom_control_tile_list)

    @controlnet
    def switch_controlnet_tensors(self, batch_id: int, x_batch_size: int, tile_batch_size: int, is_denoise=False):
        # if not self.enable_controlnet: return
        if self.control_tensor_batch is None: return
        # self.control_params = [0]
@@ -284,12 +317,12 @@ class AbstractDiffusion:
            # tensor that was concatenated in `prepare_controlnet_tensors`
            control_tile = self.control_tensor_batch[param_id][batch_id]
            # broadcast to latent batch size
            if x_batch_size > 1:  # self.is_kdiff:
                all_control_tile = []
                for i in range(tile_batch_size):
                    this_control_tile = [control_tile[i].unsqueeze(0)] * x_batch_size
                    all_control_tile.append(torch.cat(this_control_tile, dim=0))
                control_tile = torch.cat(all_control_tile, dim=0)  # [:x_tile.shape[0]]
                self.control_tensor_batch[param_id][batch_id] = control_tile
            # else:
            #     control_tile = control_tile.repeat([x_batch_size if is_denoise else x_batch_size * 2, 1, 1, 1])
@@ -297,17 +330,17 @@ class AbstractDiffusion:
    def process_controlnet(self, x_shape, x_dtype, c_in: dict, cond_or_uncond: List, bboxes, batch_size: int, batch_id: int):
        control: ControlNet = c_in['control_model']
        param_id = -1  # current controlnet & previous_controlnets
        tuple_key = tuple(cond_or_uncond) + tuple(x_shape)
        while control is not None:
            param_id += 1

            PH, PW = self.h * 8, self.w * 8

            if self.control_params.get(tuple_key, None) is None:
                self.control_params[tuple_key] = [[None]]
            val = self.control_params[tuple_key]
            if param_id + 1 >= len(val):
                val.extend([[None] for _ in range(param_id + 1)])
            if len(self.batched_bboxes) >= len(val[param_id]):
                val[param_id].extend([[None] for _ in range(len(self.batched_bboxes))])
@@ -319,59 +352,64 @@ class AbstractDiffusion:
                if dtype is None: dtype = x_dtype
                if isinstance(control, T2IAdapter):
                    width, height = control.scale_image_to(PW, PH)
                    control.cond_hint = adaptive_resize(control.cond_hint_original, width, height, 'nearest-exact', "center").float().to(control.device)
                    if control.channels_in == 1 and control.cond_hint.shape[1] > 1:
                        control.cond_hint = torch.mean(control.cond_hint, 1, keepdim=True)
                elif control.__class__.__name__ == 'ControlLLLiteAdvanced':
                    if control.sub_idxs is not None and control.cond_hint_original.shape[0] >= control.full_latent_length:
                        control.cond_hint = adaptive_resize(control.cond_hint_original[control.sub_idxs], PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
                    else:
                        if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
                            control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
                        else:
                            control.cond_hint = adaptive_resize(control.cond_hint_original, PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
                else:
                    if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
                        control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
                    else:
                        control.cond_hint = adaptive_resize(control.cond_hint_original, PW, PH, 'nearest-exact', 'center').to(dtype=dtype, device=control.device)

                # Broadcast then tile
                #
                # Below can be in the parent's if clause because self.refresh will trigger on resolution change, e.g. because of ConditioningSetArea,
                # so that particular case isn't cached atm.
                cond_hint_pre_tile = control.cond_hint
                if control.cond_hint.shape[0] < batch_size:
                    cond_hint_pre_tile = self.repeat_tensor(control.cond_hint, ceildiv(batch_size, control.cond_hint.shape[0]))[:batch_size]
                cns = [cond_hint_pre_tile[:, :, bbox[1] * opt_f:bbox[3] * opt_f, bbox[0] * opt_f:bbox[2] * opt_f] for bbox in bboxes]
                control.cond_hint = torch.cat(cns, dim=0)
                self.control_params[tuple_key][param_id][batch_id] = control.cond_hint
            else:
                control.cond_hint = self.control_params[tuple_key][param_id][batch_id]
            control = control.previous_controlnet


import numpy as np
from numpy import pi, exp, sqrt


def gaussian_weights(tile_w: int, tile_h: int) -> Tensor:
    '''
    Copied from the original implementation of Mixture of Diffusers
    https://github.com/albarji/mixture-of-diffusers/blob/master/mixdiff/tiling.py
    This generates Gaussian weights to smooth the noise of each tile.
    This is critical for this method to work.
    '''
    f = lambda x, midpoint, var=0.01: exp(-(x - midpoint) * (x - midpoint) / (tile_w * tile_w) / (2 * var)) / sqrt(2 * pi * var)
    x_probs = [f(x, (tile_w - 1) / 2) for x in range(tile_w)]  # -1 because index goes from 0 to latent_width - 1
    y_probs = [f(y, tile_h / 2) for y in range(tile_h)]

    w = np.outer(y_probs, x_probs)
    return torch.from_numpy(w).to(devices.device, dtype=torch.float32)
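# Illustrative sketch (annotation, not in the file): the weight map peaks near the
# tile centre and decays towards the borders, which is what smooths the seams between
# overlapping tiles once the accumulated buffer is normalized:
#   w = gaussian_weights(64, 64)   # (64, 64) float32 tensor on devices.device
#   w[32, 32] > w[0, 0]            # centre weight dominates the corner weight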

class CondDict: ...


class MultiDiffusion(AbstractDiffusion):

    @torch.no_grad()
    def __call__(self, model_function, args: dict):
        x_in: Tensor = args["input"]
        t_in: Tensor = args["timestep"]
        c_in: dict = args["c"]
@@ -395,12 +433,12 @@ class MultiDiffusion(AbstractDiffusion):
        # Background sampling (grid bbox)
        if self.draw_background:
            for batch_id, bboxes in enumerate(self.batched_bboxes):
                if memory_management.processing_interrupted():
                    # self.pbar.close()
                    return x_in

                # batching & compute tiles
                x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0)  # [TB, C, TH, TW]
                n_rep = len(bboxes)
                ts_tile = self.repeat_tensor(t_in, n_rep)
                cond_tile = self.repeat_tensor(c_crossattn, n_rep)
@@ -428,7 +466,7 @@ class MultiDiffusion(AbstractDiffusion):
                x_tile_out = model_function(x_tile, ts_tile, **c_tile)

                for i, bbox in enumerate(bboxes):
                    self.x_buffer[bbox.slicer] += x_tile_out[i * N:(i + 1) * N, :, :, :]
                del x_tile_out, x_tile, ts_tile, c_tile

                # update progress bar
@@ -439,6 +477,7 @@ class MultiDiffusion(AbstractDiffusion):
        return x_out


class MixtureOfDiffusers(AbstractDiffusion):
    """
    Mixture-of-Diffusers Implementation
@@ -470,11 +509,11 @@ class MixtureOfDiffusers(AbstractDiffusion):
        return self.tile_weights

    @torch.no_grad()
    def __call__(self, model_function, args: dict):
        x_in: Tensor = args["input"]
        t_in: Tensor = args["timestep"]
        c_in: dict = args["c"]
        cond_or_uncond: List = args["cond_or_uncond"]
        c_crossattn: Tensor = c_in['c_crossattn']
        N, C, H, W = x_in.shape
@@ -496,14 +535,14 @@ class MixtureOfDiffusers(AbstractDiffusion):
        # Global sampling
        if self.draw_background:
            for batch_id, bboxes in enumerate(self.batched_bboxes):  # batch_id is the `Latent tile batch size`
                if memory_management.processing_interrupted():
                    # self.pbar.close()
                    return x_in

                # batching
                x_tile_list = []
                t_tile_list = []
                icond_map = {}
                # tcond_tile_list = []
                # icond_tile_list = []
@@ -519,7 +558,7 @@ class MixtureOfDiffusers(AbstractDiffusion):
                    # present in sdxl
                    for key in ['y', 'c_concat']:
                        if key in c_in:
                            icond = c_in[key]  # self.get_icond(c_in)
                            if icond.shape[2:] == (self.h, self.w):
                                icond = icond[bbox.slicer]
                            if icond_map.get(key, None) is None:
@@ -531,13 +570,13 @@ class MixtureOfDiffusers(AbstractDiffusion):
                        else:
                            print('>> [WARN] not supported, make an issue on github!!')

                n_rep = len(bboxes)
                x_tile = torch.cat(x_tile_list, dim=0)               # differs each
                t_tile = self.repeat_tensor(t_in, n_rep)             # just repeat
                tcond_tile = self.repeat_tensor(c_crossattn, n_rep)  # just repeat
                c_tile = c_in.copy()
                c_tile['c_crossattn'] = tcond_tile
                if 'time_context' in c_in:
                    c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep)  # just repeat
                for key in c_tile:
                    if key in ['y', 'c_concat']:
                        icond_tile = torch.cat(icond_map[key], dim=0)  # differs each
@@ -547,10 +586,10 @@ class MixtureOfDiffusers(AbstractDiffusion):
                # controlnet
                # self.switch_controlnet_tensors(batch_id, N, len(bboxes), is_denoise=True)
                if 'control' in c_in:
                    control = c_in['control']
                    self.process_controlnet(x_tile.shape, x_tile.dtype, c_in, cond_or_uncond, bboxes, N, batch_id)
                    c_tile['control'] = control.get_control(x_tile, t_tile, c_tile, len(cond_or_uncond))

                # stablesr
                # self.switch_stablesr_tensors(batch_id)
@@ -562,7 +601,7 @@ class MixtureOfDiffusers(AbstractDiffusion):
                    # These weights can be calculated in advance, but that would cost a lot of VRAM
                    # when you have many tiles. So we calculate them here.
                    w = self.tile_weights * self.rescale_factor[bbox.slicer]
                    self.x_buffer[bbox.slicer] += x_tile_out[i * N:(i + 1) * N, :, :, :] * w
                del x_tile_out, x_tile, t_tile, c_tile

                # self.update_pbar()
@@ -573,19 +612,22 @@ class MixtureOfDiffusers(AbstractDiffusion):
        return x_out


MAX_RESOLUTION = 8192


class TiledDiffusion():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "method": (["MultiDiffusion", "Mixture of Diffusers"], {"default": "Mixture of Diffusers"}),
                             # "tile_width": ("INT", {"default": 96, "min": 16, "max": 256, "step": 16}),
                             "tile_width": ("INT", {"default": 96 * opt_f, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
                             # "tile_height": ("INT", {"default": 96, "min": 16, "max": 256, "step": 16}),
                             "tile_height": ("INT", {"default": 96 * opt_f, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
                             "tile_overlap": ("INT", {"default": 8 * opt_f, "min": 0, "max": 256 * opt_f, "step": 4 * opt_f}),
                             "tile_batch_size": ("INT", {"default": 4, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply"
    CATEGORY = "_for_testing"
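    # Illustrative note (annotation, not in the file): these inputs appear to be in
    # pixel space; with opt_f = 8 the defaults are 96 * 8 = 768 px tiles with a
    # 64 px overlap, i.e. 96x96 latent tiles overlapping by 8 latent pixels.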
@@ -595,7 +637,7 @@ class TiledDiffusion():
            implement = MixtureOfDiffusers()
        else:
            implement = MultiDiffusion()

        # if noise_inversion:
        #     get_cache_callback = self.noise_inverse_get_cache
        #     set_cache_callback = None  # lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, steps, retouch)