diff --git a/backend/misc/tomesd.py b/backend/misc/tomesd.py new file mode 100644 index 00000000..ff8c6a5f --- /dev/null +++ b/backend/misc/tomesd.py @@ -0,0 +1,162 @@ +import torch +import math + +from typing import Tuple, Callable + + +def do_nothing(x: torch.Tensor, mode: str = None): + return x + + +def mps_gather_workaround(input, dim, index): + if input.shape[-1] == 1: + return torch.gather( + input.unsqueeze(-1), + dim - 1 if dim < 0 else dim, + index.unsqueeze(-1) + ).squeeze(-1) + else: + return torch.gather(input, dim, index) + + +def bipartite_soft_matching_random2d(metric: torch.Tensor, + w: int, h: int, sx: int, sy: int, r: int, + no_rand: bool = False) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + """ + B, N, _ = metric.shape + + if r <= 0 or w == 1 or h == 1: + return do_nothing, do_nothing + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather + + with torch.no_grad(): + + hsy, wsx = h // sy, w // sx + + # For each sy by sx kernel, randomly assign one token to be dst and the rest src + if no_rand: + rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64) + else: + rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=metric.device) + + # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead + idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64) + idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)) + idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx) + + # Image is not divisible by sx or sy so we need to move it into a new buffer + if (hsy * sy) < h or (wsx * sx) < w: + idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64) + idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view + else: + idx_buffer = idx_buffer_view + + # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices + rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1) + + # We're finished with these + del idx_buffer, idx_buffer_view + + # rand_idx is currently dst|src, so split them + num_dst = hsy * wsx + a_idx = rand_idx[:, num_dst:, :] # src + b_idx = rand_idx[:, :num_dst, :] # dst + + def split(x): + C = x.shape[-1] + src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C)) + dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) + return src, dst + + # Cosine similarity between A and B + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], r) + + # Find the most similar greedily + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: + src, dst = split(x) + n, t1, c = src.shape + + unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = gather(src, dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + _, _, c = unm.shape + + src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c)) + + # Combine back to the original shape + out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src) + + return out + + return merge, unmerge + + +def get_functions(x, ratio, original_shape): + b, c, original_h, original_w = original_shape + original_tokens = original_h * original_w + downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) + stride_x = 2 + stride_y = 2 + max_downsample = 1 + + if downsample <= max_downsample: + w = int(math.ceil(original_w / downsample)) + h = int(math.ceil(original_h / downsample)) + r = int(x.shape[1] * ratio) + no_rand = False + m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) + return m, u + + nothing = lambda y: y + return nothing, nothing + + +class TomePatcher: + def __init__(self): + self.u = None + + def patch(self, model, ratio): + def tomesd_m(q, k, v, extra_options): + m, self.u = get_functions(q, ratio, extra_options["original_shape"]) + return m(q), k, v + + def tomesd_u(n, extra_options): + return self.u(n) + + m = model.clone() + m.set_model_attn1_patch(tomesd_m) + m.set_model_attn1_output_patch(tomesd_u) + return m diff --git a/modules/cmd_args.py b/modules/cmd_args.py index fdc30cf1..04569ff4 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -3,9 +3,7 @@ import json import os from modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401 from pathlib import Path -from ldm_patched.modules import args_parser - -parser = args_parser.parser +from backend.args import parser parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program") diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 246c9b25..3d8a94a9 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -5,8 +5,8 @@ import torch import numpy as np from modules import modelloader, paths, deepbooru_model, images, shared -from ldm_patched.modules import model_management -from ldm_patched.modules.model_patcher import ModelPatcher +from backend import memory_management +from backend.patcher.base import ModelPatcher re_special = re.compile(r'([\\()])') @@ -15,11 +15,11 @@ re_special = re.compile(r'([\\()])') class DeepDanbooru: def __init__(self): self.model = None - self.load_device = model_management.text_encoder_device() - self.offload_device = model_management.text_encoder_offload_device() + self.load_device = memory_management.text_encoder_device() + self.offload_device = memory_management.text_encoder_offload_device() self.dtype = torch.float32 - if model_management.should_use_fp16(device=self.load_device): + if memory_management.should_use_fp16(device=self.load_device): self.dtype = torch.float16 self.patcher = None @@ -45,7 +45,7 @@ class DeepDanbooru: def start(self): self.load() - model_management.load_models_gpu([self.patcher]) + memory_management.load_models_gpu([self.patcher]) def stop(self): pass diff --git a/modules/devices.py b/modules/devices.py index 08d0d706..0bda9325 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,14 +1,14 @@ import contextlib import torch -import ldm_patched.modules.model_management as model_management +from backend import memory_management def has_xpu() -> bool: - return model_management.xpu_available + return memory_management.xpu_available def has_mps() -> bool: - return model_management.mps_mode() + return memory_management.mps_mode() def cuda_no_autocast(device_id=None) -> bool: @@ -16,27 +16,27 @@ def cuda_no_autocast(device_id=None) -> bool: def get_cuda_device_id(): - return model_management.get_torch_device().index + return memory_management.get_torch_device().index def get_cuda_device_string(): - return str(model_management.get_torch_device()) + return str(memory_management.get_torch_device()) def get_optimal_device_name(): - return model_management.get_torch_device().type + return memory_management.get_torch_device().type def get_optimal_device(): - return model_management.get_torch_device() + return memory_management.get_torch_device() def get_device_for(task): - return model_management.get_torch_device() + return memory_management.get_torch_device() def torch_gc(): - model_management.soft_empty_cache() + memory_management.soft_empty_cache() def torch_npu_set_device(): @@ -49,15 +49,15 @@ def enable_tf32(): cpu: torch.device = torch.device("cpu") fp8: bool = False -device: torch.device = model_management.get_torch_device() -device_interrogate: torch.device = model_management.text_encoder_device() # for backward compatibility, not used now -device_gfpgan: torch.device = model_management.get_torch_device() # will be managed by memory management system -device_esrgan: torch.device = model_management.get_torch_device() # will be managed by memory management system -device_codeformer: torch.device = model_management.get_torch_device() # will be managed by memory management system -dtype: torch.dtype = model_management.unet_dtype() -dtype_vae: torch.dtype = model_management.vae_dtype() -dtype_unet: torch.dtype = model_management.unet_dtype() -dtype_inference: torch.dtype = model_management.unet_dtype() +device: torch.device = memory_management.get_torch_device() +device_interrogate: torch.device = memory_management.text_encoder_device() # for backward compatibility, not used now +device_gfpgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system +device_esrgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system +device_codeformer: torch.device = memory_management.get_torch_device() # will be managed by memory management system +dtype: torch.dtype = memory_management.unet_dtype() +dtype_vae: torch.dtype = memory_management.vae_dtype() +dtype_unet: torch.dtype = memory_management.unet_dtype() +dtype_inference: torch.dtype = memory_management.unet_dtype() unet_needs_upcast = False diff --git a/modules/interrogate.py b/modules/interrogate.py index 270cc0aa..ae413b17 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,8 +11,8 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, modelloader, errors -from ldm_patched.modules import model_management -from ldm_patched.modules.model_patcher import ModelPatcher +from backend import memory_management +from backend.patcher.base import ModelPatcher blip_image_eval_size = 384 @@ -57,11 +57,11 @@ class InterrogateModels: self.skip_categories = [] self.content_dir = content_dir - self.load_device = model_management.text_encoder_device() - self.offload_device = model_management.text_encoder_offload_device() + self.load_device = memory_management.text_encoder_device() + self.offload_device = memory_management.text_encoder_offload_device() self.dtype = torch.float32 - if model_management.should_use_fp16(device=self.load_device): + if memory_management.should_use_fp16(device=self.load_device): self.dtype = torch.float16 self.blip_patcher = None @@ -137,7 +137,7 @@ class InterrogateModels: self.clip_model = self.clip_model.to(device=self.offload_device, dtype=self.dtype) self.clip_patcher = ModelPatcher(self.clip_model, load_device=self.load_device, offload_device=self.offload_device) - model_management.load_models_gpu([self.blip_patcher, self.clip_patcher]) + memory_management.load_models_gpu([self.blip_patcher, self.clip_patcher]) return def send_clip_to_ram(self): diff --git a/modules/paths.py b/modules/paths.py index 66e13f10..501ff658 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -63,13 +63,3 @@ for d, must_exist, what, options in path_dirs: else: sys.path.append(d) paths[what] = d - - -import ldm_patched.utils.path_utils as ldm_patched_path_utils - -ldm_patched_path_utils.base_path = data_path -ldm_patched_path_utils.models_dir = models_path -ldm_patched_path_utils.output_directory = os.path.join(data_path, "output") -ldm_patched_path_utils.temp_directory = os.path.join(data_path, "temp") -ldm_patched_path_utils.input_directory = os.path.join(data_path, "input") -ldm_patched_path_utils.user_directory = os.path.join(data_path, "user") diff --git a/modules/sd_models.py b/modules/sd_models.py index 538f5577..2c11e81d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,13 +14,11 @@ import ldm.modules.midas as midas import gc from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches +from modules.shared import opts from modules.timer import Timer import numpy as np from modules_forge import forge_loader -import modules_forge.ops as forge_ops -from ldm_patched.modules.ops import manual_cast -from ldm_patched.modules import model_management as model_management -import ldm_patched.modules.model_patcher +from backend import memory_management model_dir = "Stable-diffusion" @@ -650,8 +648,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): model_data.sd_model = None model_data.loaded_sd_models = [] - model_management.unload_all_models() - model_management.soft_empty_cache() + memory_management.unload_all_models() + memory_management.soft_empty_cache() gc.collect() timer.record("unload existing model") @@ -724,7 +722,7 @@ def apply_token_merging(sd_model, token_merging_ratio): print(f'token_merging_ratio = {token_merging_ratio}') - from ldm_patched.contrib.external_tomesd import TomePatcher + from backend.misc.tomesd import TomePatcher sd_model.forge_objects.unet = TomePatcher().patch( model=sd_model.forge_objects.unet, diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 4f8c7ee1..1b820592 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -8,7 +8,7 @@ import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser from modules import torch_utils -import ldm_patched.modules.model_management as model_management +from backend import memory_management from modules_forge.forge_clip import move_clip_to_gpu @@ -23,7 +23,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: is_negative_prompt = getattr(batch, 'is_negative_prompt', False) aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score - devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=model_management.text_encoder_dtype()) + devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype()) sdxl_conds = { "txt": batch,