Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
move to new backend - part 1
backend/misc/tomesd.py (new file, 162 lines added)
@@ -0,0 +1,162 @@
import torch
import math

from typing import Tuple, Callable


def do_nothing(x: torch.Tensor, mode: str = None):
    return x


def mps_gather_workaround(input, dim, index):
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    else:
        return torch.gather(input, dim, index)


def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst, must divide w
     - sy: stride in the y dimension for dst, must divide h
     - r: number of tokens to remove (by merging)
     - no_rand: if true, disable randomness (use top left corner only)
    """
    B, N, _ = metric.shape

    if r <= 0 or w == 1 or h == 1:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():

        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=metric.device)

        # The image might not divide evenly by sx and sy, so work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy, so we need to move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def get_functions(x, ratio, original_shape):
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = int(math.ceil(original_w / downsample))
        h = int(math.ceil(original_h / downsample))
        r = int(x.shape[1] * ratio)
        no_rand = False
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    nothing = lambda y: y
    return nothing, nothing


class TomePatcher:
    def __init__(self):
        self.u = None

    def patch(self, model, ratio):
        def tomesd_m(q, k, v, extra_options):
            m, self.u = get_functions(q, ratio, extra_options["original_shape"])
            return m(q), k, v

        def tomesd_u(n, extra_options):
            return self.u(n)

        m = model.clone()
        m.set_model_attn1_patch(tomesd_m)
        m.set_model_attn1_output_patch(tomesd_u)
        return m
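backend/misc/tomesd.py is a port of the ToMe ("token merging") helper that previously lived in ldm_patched/contrib/external_tomesd.py (see the apply_token_merging hunk further down): merge() folds the r most similar src tokens into their nearest dst token before self-attention, and unmerge() scatters the merged values back so the original sequence length is restored. As orientation only, a minimal sketch of how the returned pair behaves on a dummy token tensor; the shapes and parameter values below are illustrative, not from the commit:

import torch
from backend.misc.tomesd import bipartite_soft_matching_random2d

x = torch.randn(1, 64 * 64, 320)  # [B, N, C] tokens for a 64x64 latent grid
merge, unmerge = bipartite_soft_matching_random2d(x, w=64, h=64, sx=2, sy=2, r=1024)

merged = merge(x)           # [1, 3072, 320] - 1024 src tokens averaged into their dst matches
restored = unmerge(merged)  # [1, 4096, 320] - merged values copied back to the original positions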
@@ -3,9 +3,7 @@ import json
 import os
 from modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file  # noqa: F401
 from pathlib import Path
-from ldm_patched.modules import args_parser
+from backend.args import parser
 
-parser = args_parser.parser
-
 parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS)  # allows running as root; implemented outside of webui
 parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
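The shared command-line parser is now imported directly from backend.args instead of being pulled out of ldm_patched's args_parser wrapper. A hedged sketch (not in the commit) of how a flag would be registered against it; the flag name is made up, and the parser is assumed to be a standard argparse parser, judging from the add_argument calls above:

from backend.args import parser  # shared parser now owned by the backend package

parser.add_argument("--example-flag", action='store_true', help="hypothetical flag, for illustration only")
args, _ = parser.parse_known_args([])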
@@ -5,8 +5,8 @@ import torch
 import numpy as np
 
 from modules import modelloader, paths, deepbooru_model, images, shared
-from ldm_patched.modules import model_management
-from ldm_patched.modules.model_patcher import ModelPatcher
+from backend import memory_management
+from backend.patcher.base import ModelPatcher
 
 
 re_special = re.compile(r'([\\()])')
@@ -15,11 +15,11 @@ re_special = re.compile(r'([\\()])')
 class DeepDanbooru:
     def __init__(self):
         self.model = None
-        self.load_device = model_management.text_encoder_device()
-        self.offload_device = model_management.text_encoder_offload_device()
+        self.load_device = memory_management.text_encoder_device()
+        self.offload_device = memory_management.text_encoder_offload_device()
         self.dtype = torch.float32
 
-        if model_management.should_use_fp16(device=self.load_device):
+        if memory_management.should_use_fp16(device=self.load_device):
             self.dtype = torch.float16
 
         self.patcher = None
@@ -45,7 +45,7 @@ class DeepDanbooru:
 
     def start(self):
         self.load()
-        model_management.load_models_gpu([self.patcher])
+        memory_management.load_models_gpu([self.patcher])
 
     def stop(self):
         pass
@@ -1,14 +1,14 @@
 import contextlib
 import torch
-import ldm_patched.modules.model_management as model_management
+from backend import memory_management
 
 
 def has_xpu() -> bool:
-    return model_management.xpu_available
+    return memory_management.xpu_available
 
 
 def has_mps() -> bool:
-    return model_management.mps_mode()
+    return memory_management.mps_mode()
 
 
 def cuda_no_autocast(device_id=None) -> bool:
@@ -16,27 +16,27 @@ def cuda_no_autocast(device_id=None) -> bool:
 
 
 def get_cuda_device_id():
-    return model_management.get_torch_device().index
+    return memory_management.get_torch_device().index
 
 
 def get_cuda_device_string():
-    return str(model_management.get_torch_device())
+    return str(memory_management.get_torch_device())
 
 
 def get_optimal_device_name():
-    return model_management.get_torch_device().type
+    return memory_management.get_torch_device().type
 
 
 def get_optimal_device():
-    return model_management.get_torch_device()
+    return memory_management.get_torch_device()
 
 
 def get_device_for(task):
-    return model_management.get_torch_device()
+    return memory_management.get_torch_device()
 
 
 def torch_gc():
-    model_management.soft_empty_cache()
+    memory_management.soft_empty_cache()
 
 
 def torch_npu_set_device():
@@ -49,15 +49,15 @@ def enable_tf32():
 
 cpu: torch.device = torch.device("cpu")
 fp8: bool = False
-device: torch.device = model_management.get_torch_device()
-device_interrogate: torch.device = model_management.text_encoder_device()  # for backward compatibility, not used now
-device_gfpgan: torch.device = model_management.get_torch_device()  # will be managed by memory management system
-device_esrgan: torch.device = model_management.get_torch_device()  # will be managed by memory management system
-device_codeformer: torch.device = model_management.get_torch_device()  # will be managed by memory management system
-dtype: torch.dtype = model_management.unet_dtype()
-dtype_vae: torch.dtype = model_management.vae_dtype()
-dtype_unet: torch.dtype = model_management.unet_dtype()
-dtype_inference: torch.dtype = model_management.unet_dtype()
+device: torch.device = memory_management.get_torch_device()
+device_interrogate: torch.device = memory_management.text_encoder_device()  # for backward compatibility, not used now
+device_gfpgan: torch.device = memory_management.get_torch_device()  # will be managed by memory management system
+device_esrgan: torch.device = memory_management.get_torch_device()  # will be managed by memory management system
+device_codeformer: torch.device = memory_management.get_torch_device()  # will be managed by memory management system
+dtype: torch.dtype = memory_management.unet_dtype()
+dtype_vae: torch.dtype = memory_management.vae_dtype()
+dtype_unet: torch.dtype = memory_management.unet_dtype()
+dtype_inference: torch.dtype = memory_management.unet_dtype()
 unet_needs_upcast = False
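The three hunks above rewire the webui device helpers to backend.memory_management while keeping their public names, so downstream callers are unchanged. A hedged sketch (not in the commit) of how such a caller keeps working; only calls that appear in the diff are used:

import torch
from modules import devices

device = devices.get_optimal_device()  # now resolved via memory_management.get_torch_device()
dtype = devices.dtype                  # now memory_management.unet_dtype()
x = torch.zeros(1, 4, 64, 64, device=device, dtype=dtype)
devices.torch_gc()                     # now calls memory_management.soft_empty_cache()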
@@ -11,8 +11,8 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 from modules import devices, paths, shared, modelloader, errors
-from ldm_patched.modules import model_management
-from ldm_patched.modules.model_patcher import ModelPatcher
+from backend import memory_management
+from backend.patcher.base import ModelPatcher
 
 
 blip_image_eval_size = 384
@@ -57,11 +57,11 @@ class InterrogateModels:
         self.skip_categories = []
         self.content_dir = content_dir
 
-        self.load_device = model_management.text_encoder_device()
-        self.offload_device = model_management.text_encoder_offload_device()
+        self.load_device = memory_management.text_encoder_device()
+        self.offload_device = memory_management.text_encoder_offload_device()
         self.dtype = torch.float32
 
-        if model_management.should_use_fp16(device=self.load_device):
+        if memory_management.should_use_fp16(device=self.load_device):
             self.dtype = torch.float16
 
         self.blip_patcher = None
@@ -137,7 +137,7 @@ class InterrogateModels:
         self.clip_model = self.clip_model.to(device=self.offload_device, dtype=self.dtype)
         self.clip_patcher = ModelPatcher(self.clip_model, load_device=self.load_device, offload_device=self.offload_device)
 
-        model_management.load_models_gpu([self.blip_patcher, self.clip_patcher])
+        memory_management.load_models_gpu([self.blip_patcher, self.clip_patcher])
         return
 
     def send_clip_to_ram(self):
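The DeepDanbooru and interrogate hunks repeat the same recipe, which is the core of this migration: ask backend.memory_management for load/offload devices and a dtype, wrap the raw torch module in a backend.patcher.base.ModelPatcher, and let load_models_gpu move it onto the GPU when it is actually needed. A sketch of that recipe with a stand-in module (the Linear layer is only a placeholder for BLIP/CLIP/DeepDanbooru):

import torch
from backend import memory_management
from backend.patcher.base import ModelPatcher

load_device = memory_management.text_encoder_device()
offload_device = memory_management.text_encoder_offload_device()
dtype = torch.float16 if memory_management.should_use_fp16(device=load_device) else torch.float32

model = torch.nn.Linear(8, 8).to(device=offload_device, dtype=dtype)  # placeholder model
patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
memory_management.load_models_gpu([patcher])  # memory manager moves it to load_device on demand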
@@ -63,13 +63,3 @@ for d, must_exist, what, options in path_dirs:
     else:
         sys.path.append(d)
         paths[what] = d
-
-
-import ldm_patched.utils.path_utils as ldm_patched_path_utils
-
-ldm_patched_path_utils.base_path = data_path
-ldm_patched_path_utils.models_dir = models_path
-ldm_patched_path_utils.output_directory = os.path.join(data_path, "output")
-ldm_patched_path_utils.temp_directory = os.path.join(data_path, "temp")
-ldm_patched_path_utils.input_directory = os.path.join(data_path, "input")
-ldm_patched_path_utils.user_directory = os.path.join(data_path, "user")
@@ -14,13 +14,11 @@ import ldm.modules.midas as midas
 import gc
 
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
+from modules.shared import opts
 from modules.timer import Timer
 import numpy as np
 from modules_forge import forge_loader
-import modules_forge.ops as forge_ops
-from ldm_patched.modules.ops import manual_cast
-from ldm_patched.modules import model_management as model_management
-import ldm_patched.modules.model_patcher
+from backend import memory_management
 
 
 model_dir = "Stable-diffusion"
@@ -650,8 +648,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     model_data.sd_model = None
     model_data.loaded_sd_models = []
-    model_management.unload_all_models()
-    model_management.soft_empty_cache()
+    memory_management.unload_all_models()
+    memory_management.soft_empty_cache()
     gc.collect()
 
     timer.record("unload existing model")
@@ -724,7 +722,7 @@ def apply_token_merging(sd_model, token_merging_ratio):
 
         print(f'token_merging_ratio = {token_merging_ratio}')
 
-        from ldm_patched.contrib.external_tomesd import TomePatcher
+        from backend.misc.tomesd import TomePatcher
 
         sd_model.forge_objects.unet = TomePatcher().patch(
             model=sd_model.forge_objects.unet,
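The hunk above is truncated after the model= argument. Going by the patch(self, model, ratio) signature defined in backend/misc/tomesd.py earlier in this commit, the call presumably continues with the merge ratio; a hedged sketch of that continuation, with the ratio value made up:

from backend.misc.tomesd import TomePatcher

# sd_model.forge_objects.unet is the UNet ModelPatcher; 0.5 is an illustrative ratio, not from the commit
sd_model.forge_objects.unet = TomePatcher().patch(
    model=sd_model.forge_objects.unet,
    ratio=0.5,
)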
@@ -8,7 +8,7 @@ import sgm.modules.diffusionmodules.discretizer
 from modules import devices, shared, prompt_parser
 from modules import torch_utils
 
-import ldm_patched.modules.model_management as model_management
+from backend import memory_management
 from modules_forge.forge_clip import move_clip_to_gpu
@@ -23,7 +23,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
     is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
     aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
 
-    devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=model_management.text_encoder_dtype())
+    devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
 
     sdxl_conds = {
         "txt": batch,