diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index cd68727c..9164de9f 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -7,7 +7,6 @@ import torch import modules.scripts as scripts from modules import shared, script_callbacks, masking, images from modules.ui_components import InputAccordion -from modules.api.api import decode_base64_to_image import gradio as gr from lib_controlnet import global_state, external_code diff --git a/modules/api/api.py b/modules/api/api.py index 78d10969..25ce7ca0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -24,7 +24,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin -from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import Any @@ -725,7 +724,7 @@ class Api: def get_sd_models(self): import modules.sd_models as sd_models - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()] def get_sd_vaes(self): import modules.sd_vae as sd_vae diff --git a/modules/initialize.py b/modules/initialize.py index ec4d58a4..dd55d6c3 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -10,16 +10,6 @@ from threading import Thread from modules.timer import startup_timer -class HiddenPrints: - def __enter__(self): - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, 'w') - - def __exit__(self, exc_type, exc_val, exc_tb): - sys.stdout.close() - sys.stdout = self._original_stdout - - def imports(): logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... 
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -35,16 +25,8 @@ def imports(): import gradio # noqa: F401 startup_timer.record("import gradio") - with HiddenPrints(): - from modules import paths, timer, import_hook, errors # noqa: F401 - startup_timer.record("setup paths") - - import ldm.modules.encoders.modules # noqa: F401 - import ldm.modules.diffusionmodules.model - startup_timer.record("import ldm") - - import sgm.modules.encoders.modules # noqa: F401 - startup_timer.record("import sgm") + from modules import paths, timer, import_hook, errors # noqa: F401 + startup_timer.record("setup paths") from modules import shared_init shared_init.initialize() @@ -141,11 +123,6 @@ def initialize_rest(*, reload_script_modules=False): textual_inversion.textual_inversion.list_textual_inversion_templates() startup_timer.record("refresh textual inversion templates") - from modules import script_callbacks, sd_hijack_optimizations, sd_hijack - script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers) - sd_hijack.list_optimizers() - startup_timer.record("scripts list_optimizers") - from modules import sd_unet sd_unet.list_unets() startup_timer.record("scripts list_unets") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index f933de64..8c1823a1 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -391,15 +391,15 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") - stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") + # stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") + # stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") - stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") - stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") + # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") + # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -456,8 +456,8 @@ def prepare_environment(): 
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) - git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) - git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) + # git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + # git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index 501ff658..18494b6c 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -36,8 +36,8 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl mute_sdxl_imports() path_dirs = [ - (sd_path, 'ldm', 'Stable Diffusion', []), - (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), + # (sd_path, 'ldm', 'Stable Diffusion', []), + # (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), (os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []), @@ -53,13 +53,13 @@ for d, must_exist, what, options in path_dirs: d = os.path.abspath(d) if "atstart" in options: sys.path.insert(0, d) - elif "sgm" in options: - # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we - # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir. - - sys.path.insert(0, d) - import sgm # noqa: F401 - sys.path.pop(0) + # elif "sgm" in options: + # # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we + # # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir. 
+ # + # sys.path.insert(0, d) + # import sgm # noqa: F401 + # sys.path.pop(0) else: sys.path.append(d) paths[what] = d diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f292073b..eb06a849 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -1,124 +1,3 @@ -import torch -from torch.nn.functional import silu -from types import MethodType - -from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches -from modules.hypernetworks import hypernetwork -from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 - -import ldm.modules.attention -import ldm.modules.diffusionmodules.model -import ldm.modules.diffusionmodules.openaimodel -import ldm.models.diffusion.ddpm -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms -import ldm.modules.encoders.modules - -import sgm.modules.attention -import sgm.modules.diffusionmodules.model -import sgm.modules.diffusionmodules.openaimodel -import sgm.modules.encoders.modules - -attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward -diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity -diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward - -# new memory efficient cross attention blocks do not support hypernets and we already -# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention -ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention -ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention - -# silence new console spam from SD2 -ldm.modules.attention.print = shared.ldm_print -ldm.modules.diffusionmodules.model.print = shared.ldm_print -ldm.util.print = shared.ldm_print -ldm.models.diffusion.ddpm.print = shared.ldm_print - -optimizers = [] -current_optimizer: sd_hijack_optimizations.SdOptimization = None - -ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) -ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) - -sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) -sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) - - -def list_optimizers(): - new_optimizers = script_callbacks.list_optimizers_callback() - - new_optimizers = [x for x in new_optimizers if x.is_available()] - - new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) - - optimizers.clear() - optimizers.extend(new_optimizers) - - -def apply_optimizations(option=None): - return - - -def undo_optimizations(): - return - - -def fix_checkpoint(): - """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want - checkpoints to be added when not training (there's a warning)""" - - pass - - -def weighted_loss(sd_model, pred, target, mean=True): - #Calculate the weight normally, but ignore the mean - loss = sd_model._old_get_loss(pred, target, mean=False) - - #Check if we have weights available - weight = getattr(sd_model, '_custom_loss_weight', None) - if weight is not None: - loss *= weight - - #Return the loss, as mean if specified - return loss.mean() if mean else 
loss - -def weighted_forward(sd_model, x, c, w, *args, **kwargs): - try: - #Temporarily append weights to a place accessible during loss calc - sd_model._custom_loss_weight = w - - #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely - #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set - if not hasattr(sd_model, '_old_get_loss'): - sd_model._old_get_loss = sd_model.get_loss - sd_model.get_loss = MethodType(weighted_loss, sd_model) - - #Run the standard forward function, but with the patched 'get_loss' - return sd_model.forward(x, c, *args, **kwargs) - finally: - try: - #Delete temporary weights if appended - del sd_model._custom_loss_weight - except AttributeError: - pass - - #If we have an old loss function, reset the loss function to the original one - if hasattr(sd_model, '_old_get_loss'): - sd_model.get_loss = sd_model._old_get_loss - del sd_model._old_get_loss - -def apply_weighted_forward(sd_model): - #Add new function 'weighted_forward' that can be called to calc weighted loss - sd_model.weighted_forward = MethodType(weighted_forward, sd_model) - -def undo_weighted_forward(sd_model): - try: - del sd_model.weighted_forward - except AttributeError: - pass - - class StableDiffusionModelHijack: fixes = None layers = None @@ -156,74 +35,201 @@ class StableDiffusionModelHijack: pass -class EmbeddingsWithFixes(torch.nn.Module): - def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): - super().__init__() - self.wrapped = wrapped - self.embeddings = embeddings - self.textual_inversion_key = textual_inversion_key - self.weight = self.wrapped.weight - - def forward(self, input_ids): - batch_fixes = self.embeddings.fixes - self.embeddings.fixes = None - - inputs_embeds = self.wrapped(input_ids) - - if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: - return inputs_embeds - - vecs = [] - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, embedding in fixes: - vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec - emb = devices.cond_cast_unet(vec) - emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) - - vecs.append(tensor) - - return torch.stack(vecs) - - -class TextualInversionEmbeddings(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): - super().__init__(num_embeddings, embedding_dim, **kwargs) - - self.embeddings = model_hijack - self.textual_inversion_key = textual_inversion_key - - @property - def wrapped(self): - return super().forward - - def forward(self, input_ids): - return EmbeddingsWithFixes.forward(self, input_ids) - - -def add_circular_option_to_conv_2d(): - conv2d_constructor = torch.nn.Conv2d.__init__ - - def conv2d_constructor_circular(self, *args, **kwargs): - return conv2d_constructor(self, *args, padding_mode='circular', **kwargs) - - torch.nn.Conv2d.__init__ = conv2d_constructor_circular - - model_hijack = StableDiffusionModelHijack() -def register_buffer(self, name, attr): - """ - Fix register buffer bug for Mac OS. 
- """ - - if type(attr) == torch.Tensor: - if attr.device != devices.device: - attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) - - setattr(self, name, attr) - - -ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer -ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer +# import torch +# from torch.nn.functional import silu +# from types import MethodType +# +# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches +# from modules.hypernetworks import hypernetwork +# from modules.shared import cmd_opts +# from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 +# +# import ldm.modules.attention +# import ldm.modules.diffusionmodules.model +# import ldm.modules.diffusionmodules.openaimodel +# import ldm.models.diffusion.ddpm +# import ldm.models.diffusion.ddim +# import ldm.models.diffusion.plms +# import ldm.modules.encoders.modules +# +# import sgm.modules.attention +# import sgm.modules.diffusionmodules.model +# import sgm.modules.diffusionmodules.openaimodel +# import sgm.modules.encoders.modules +# +# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +# +# # new memory efficient cross attention blocks do not support hypernets and we already +# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention +# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention +# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention +# +# # silence new console spam from SD2 +# ldm.modules.attention.print = shared.ldm_print +# ldm.modules.diffusionmodules.model.print = shared.ldm_print +# ldm.util.print = shared.ldm_print +# ldm.models.diffusion.ddpm.print = shared.ldm_print +# +# optimizers = [] +# current_optimizer: sd_hijack_optimizations.SdOptimization = None +# +# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) +# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) +# +# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) +# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) +# +# +# def list_optimizers(): +# new_optimizers = script_callbacks.list_optimizers_callback() +# +# new_optimizers = [x for x in new_optimizers if x.is_available()] +# +# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) +# +# optimizers.clear() +# optimizers.extend(new_optimizers) +# +# +# def apply_optimizations(option=None): +# return +# +# +# def undo_optimizations(): +# return +# +# +# def fix_checkpoint(): +# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want +# checkpoints to be added when not training (there's a warning)""" +# +# pass +# +# +# def weighted_loss(sd_model, pred, target, mean=True): +# #Calculate the weight normally, but ignore the mean +# loss = sd_model._old_get_loss(pred, target, mean=False) +# +# #Check if we 
have weights available +# weight = getattr(sd_model, '_custom_loss_weight', None) +# if weight is not None: +# loss *= weight +# +# #Return the loss, as mean if specified +# return loss.mean() if mean else loss +# +# def weighted_forward(sd_model, x, c, w, *args, **kwargs): +# try: +# #Temporarily append weights to a place accessible during loss calc +# sd_model._custom_loss_weight = w +# +# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely +# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set +# if not hasattr(sd_model, '_old_get_loss'): +# sd_model._old_get_loss = sd_model.get_loss +# sd_model.get_loss = MethodType(weighted_loss, sd_model) +# +# #Run the standard forward function, but with the patched 'get_loss' +# return sd_model.forward(x, c, *args, **kwargs) +# finally: +# try: +# #Delete temporary weights if appended +# del sd_model._custom_loss_weight +# except AttributeError: +# pass +# +# #If we have an old loss function, reset the loss function to the original one +# if hasattr(sd_model, '_old_get_loss'): +# sd_model.get_loss = sd_model._old_get_loss +# del sd_model._old_get_loss +# +# def apply_weighted_forward(sd_model): +# #Add new function 'weighted_forward' that can be called to calc weighted loss +# sd_model.weighted_forward = MethodType(weighted_forward, sd_model) +# +# def undo_weighted_forward(sd_model): +# try: +# del sd_model.weighted_forward +# except AttributeError: +# pass +# +# +# +# +# +# class EmbeddingsWithFixes(torch.nn.Module): +# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): +# super().__init__() +# self.wrapped = wrapped +# self.embeddings = embeddings +# self.textual_inversion_key = textual_inversion_key +# self.weight = self.wrapped.weight +# +# def forward(self, input_ids): +# batch_fixes = self.embeddings.fixes +# self.embeddings.fixes = None +# +# inputs_embeds = self.wrapped(input_ids) +# +# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: +# return inputs_embeds +# +# vecs = [] +# for fixes, tensor in zip(batch_fixes, inputs_embeds): +# for offset, embedding in fixes: +# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec +# emb = devices.cond_cast_unet(vec) +# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) +# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) +# +# vecs.append(tensor) +# +# return torch.stack(vecs) +# +# +# class TextualInversionEmbeddings(torch.nn.Embedding): +# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): +# super().__init__(num_embeddings, embedding_dim, **kwargs) +# +# self.embeddings = model_hijack +# self.textual_inversion_key = textual_inversion_key +# +# @property +# def wrapped(self): +# return super().forward +# +# def forward(self, input_ids): +# return EmbeddingsWithFixes.forward(self, input_ids) +# +# +# def add_circular_option_to_conv_2d(): +# conv2d_constructor = torch.nn.Conv2d.__init__ +# +# def conv2d_constructor_circular(self, *args, **kwargs): +# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs) +# +# torch.nn.Conv2d.__init__ = conv2d_constructor_circular +# +# +# model_hijack = StableDiffusionModelHijack() +# +# +# def register_buffer(self, name, attr): +# """ +# Fix register buffer bug for Mac OS. 
+# """ +# +# if type(attr) == torch.Tensor: +# if attr.device != devices.device: +# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) +# +# setattr(self, name, attr) +# +# +# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer +# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 0269f1f5..696835ad 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,677 +1,677 @@ -from __future__ import annotations -import math -import psutil -import platform - -import torch -from torch import einsum - -from ldm.util import default -from einops import rearrange - -from modules import shared, errors, devices, sub_quadratic_attention -from modules.hypernetworks import hypernetwork - -import ldm.modules.attention -import ldm.modules.diffusionmodules.model - -import sgm.modules.attention -import sgm.modules.diffusionmodules.model - -diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward - - -class SdOptimization: - name: str = None - label: str | None = None - cmd_opt: str | None = None - priority: int = 0 - - def title(self): - if self.label is None: - return self.name - - return f"{self.name} - {self.label}" - - def is_available(self): - return True - - def apply(self): - pass - - def undo(self): - ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward - - sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward - - -class SdOptimizationXformers(SdOptimization): - name = "xformers" - cmd_opt = "xformers" - priority = 100 - - def is_available(self): - return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) - - def apply(self): - ldm.modules.attention.CrossAttention.forward = xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward - sgm.modules.attention.CrossAttention.forward = xformers_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward - - -class SdOptimizationSdpNoMem(SdOptimization): - name = "sdp-no-mem" - label = "scaled dot product without memory efficient attention" - cmd_opt = "opt_sdp_no_mem_attention" - priority = 80 - - def is_available(self): - return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) - - def apply(self): - ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward - sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward - - -class SdOptimizationSdp(SdOptimizationSdpNoMem): - name = "sdp" - label = "scaled dot product" - cmd_opt = "opt_sdp_attention" - priority = 70 - - def apply(self): - 
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward - sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward - - -class SdOptimizationSubQuad(SdOptimization): - name = "sub-quadratic" - cmd_opt = "opt_sub_quad_attention" - - @property - def priority(self): - return 1000 if shared.device.type == 'mps' else 10 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward - sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward - - -class SdOptimizationV1(SdOptimization): - name = "V1" - label = "original v1" - cmd_opt = "opt_split_attention_v1" - priority = 10 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - - -class SdOptimizationInvokeAI(SdOptimization): - name = "InvokeAI" - cmd_opt = "opt_split_attention_invokeai" - - @property - def priority(self): - return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI - sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI - - -class SdOptimizationDoggettx(SdOptimization): - name = "Doggettx" - cmd_opt = "opt_split_attention" - priority = 90 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward - sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward - - -def list_optimizers(res): - res.extend([ - SdOptimizationXformers(), - SdOptimizationSdpNoMem(), - SdOptimizationSdp(), - SdOptimizationSubQuad(), - SdOptimizationV1(), - SdOptimizationInvokeAI(), - SdOptimizationDoggettx(), - ]) - - -if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: - try: - import xformers.ops - shared.xformers_available = True - except Exception: - errors.report("Cannot import xformers", exc_info=True) - - -def get_available_vram(): - if shared.device.type == 'cuda': - stats = torch.cuda.memory_stats(shared.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - return mem_free_total - else: - return psutil.virtual_memory().available - - -# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - del context, context_k, context_v, x - - q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) - del q_in, 
k_in, v_in - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v.float() - - with devices.without_autocast(disable=not shared.opts.upcast_attn): - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale - - s2 = s1.softmax(dim=-1) - del s1 - - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - del q, k, v - - r1 = r1.to(dtype) - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - - dtype = q_in.dtype - if shared.opts.upcast_attn: - q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() - - with devices.without_autocast(disable=not shared.opts.upcast_attn): - k_in = k_in * self.scale - - del context, x - - q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = get_available_vram() - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps - for i in range(0, q.shape[1], slice_size): - end = min(i + slice_size, q.shape[1]) - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v - - r1 = r1.to(dtype) - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -mem_total_gb = psutil.virtual_memory().total // (1 << 30) - - -def einsum_op_compvis(q, k, v): - s = einsum('b i d, b j d -> b i j', q, k) - s = s.softmax(dim=-1, dtype=s.dtype) - return einsum('b i j, b j d -> b i d', s, v) - - -def einsum_op_slice_0(q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[0], slice_size): - end = i + slice_size - r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) - return r - - -def einsum_op_slice_1(q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) - return r - - -def einsum_op_mps_v1(q, k, v): - if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 - return einsum_op_compvis(q, k, v) - else: - slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - if slice_size % 4096 == 0: - slice_size -= 1 - return einsum_op_slice_1(q, k, v, slice_size) - - -def einsum_op_mps_v2(q, k, v): - if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: - return einsum_op_compvis(q, k, v) - else: - return einsum_op_slice_0(q, k, v, 1) - - -def einsum_op_tensor_mem(q, k, v, max_tensor_mb): - size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) - if size_mb <= max_tensor_mb: - return einsum_op_compvis(q, k, v) - div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() - if div <= q.shape[0]: - return einsum_op_slice_0(q, k, v, q.shape[0] // div) - return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) - - -def einsum_op_cuda(q, k, v): - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - # Divide factor of safety as there's copying and fragmentation - return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - - -def einsum_op(q, k, v): - if q.device.type == 'cuda': - return einsum_op_cuda(q, k, v) - - if q.device.type == 'mps': - if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18: - return einsum_op_mps_v1(q, k, v) - return einsum_op_mps_v2(q, k, v) - - # Smaller slices are faster due to L2/L3/SLC caches. - # Tested on i7 with 8MB L3 cache. 
- return einsum_op_tensor_mem(q, k, v, 32) - - -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, context_k, context_v, x - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() - - with devices.without_autocast(disable=not shared.opts.upcast_attn): - k = k * self.scale - - q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) - r = einsum_op(q, k, v) - r = r.to(dtype) - return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) - -# -- End of code from https://github.com/invoke-ai/InvokeAI -- - - -# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 -# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): - assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." - - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, context_k, context_v, x - - q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - - if q.device.type == 'mps': - q, k, v = q.contiguous(), k.contiguous(), v.contiguous() - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k = q.float(), k.float() - - x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) - - x = x.to(dtype) - - x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) - - out_proj, dropout = self.to_out - x = out_proj(x) - x = dropout(x) - - return x - - -def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): - bytes_per_token = torch.finfo(q.dtype).bits//8 - batch_x_heads, q_tokens, _ = q.shape - _, k_tokens, _ = k.shape - qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - - if chunk_threshold is None: - if q.device.type == 'mps': - chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) - else: - chunk_threshold_bytes = int(get_available_vram() * 0.7) - elif chunk_threshold == 0: - chunk_threshold_bytes = None - else: - chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) - - if kv_chunk_size_min is None and chunk_threshold_bytes is not None: - kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) - elif kv_chunk_size_min == 0: - kv_chunk_size_min = None - - if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: - # the big matmul fits into our memory limit; do everything in 1 chunk, - # i.e. 
send it down the unchunked fast-path - kv_chunk_size = k_tokens - - with devices.without_autocast(disable=q.dtype == v.dtype): - return sub_quadratic_attention.efficient_dot_product_attention( - q, - k, - v, - query_chunk_size=q_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min = kv_chunk_size_min, - use_checkpoint=use_checkpoint, - ) - - -def get_xformers_flash_attention_op(q, k, v): - if not shared.cmd_opts.xformers_flash_attention: - return None - - try: - flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp - fw, bw = flash_attention_op - if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): - return flash_attention_op - except Exception as e: - errors.display_once(e, "enabling flash attention") - - return None - - -def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): - h = self.heads - q_in = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - - q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in)) - - del q_in, k_in, v_in - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v.float() - - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) - - out = out.to(dtype) - - b, n, h, d = out.shape - out = out.reshape(b, n, h * d) - return self.to_out(out) - - -# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py -# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): - batch_size, sequence_length, inner_dim = x.shape - - if mask is not None: - mask = self.prepare_attention_mask(mask, sequence_length, batch_size) - mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) - - h = self.heads - q_in = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - - head_dim = inner_dim // h - q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) - k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) - v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) - - del q_in, k_in, v_in - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v.float() - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - hidden_states = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) - hidden_states = hidden_states.to(dtype) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - return hidden_states - - -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): - return scaled_dot_product_attention_forward(self, x, context, mask) - - -def 
cross_attention_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h*w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - mem_free_total = get_available_vram() - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 - - -def xformers_attnblock_forward(self, x): - try: - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - b, c, h, w = q.shape - q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) - dtype = q.dtype - if shared.opts.upcast_attn: - q, k = q.float(), k.float() - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) - out = out.to(dtype) - out = rearrange(out, 'b (h w) c -> b c h w', h=h) - out = self.proj_out(out) - return x + out - except NotImplementedError: - return cross_attention_attnblock_forward(self, x) - - -def sdp_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - b, c, h, w = q.shape - q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v.float() - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) - out = out.to(dtype) - out = rearrange(out, 'b (h w) c -> b c h w', h=h) - out = self.proj_out(out) - return x + out - - -def sdp_no_mem_attnblock_forward(self, x): - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): - return sdp_attnblock_forward(self, x) - - -def sub_quad_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - b, c, h, w = q.shape - q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) - out = rearrange(out, 'b (h w) c -> b c h w', h=h) - out = self.proj_out(out) - return x + out +# from __future__ import annotations +# import math +# import psutil +# import platform +# +# import torch +# from torch import einsum +# +# from ldm.util import default +# from einops import 
rearrange +# +# from modules import shared, errors, devices, sub_quadratic_attention +# from modules.hypernetworks import hypernetwork +# +# import ldm.modules.attention +# import ldm.modules.diffusionmodules.model +# +# import sgm.modules.attention +# import sgm.modules.diffusionmodules.model +# +# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +# sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward +# +# +# class SdOptimization: +# name: str = None +# label: str | None = None +# cmd_opt: str | None = None +# priority: int = 0 +# +# def title(self): +# if self.label is None: +# return self.name +# +# return f"{self.name} - {self.label}" +# +# def is_available(self): +# return True +# +# def apply(self): +# pass +# +# def undo(self): +# ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward +# +# sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward +# +# +# class SdOptimizationXformers(SdOptimization): +# name = "xformers" +# cmd_opt = "xformers" +# priority = 100 +# +# def is_available(self): +# return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = xformers_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = xformers_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward +# +# +# class SdOptimizationSdpNoMem(SdOptimization): +# name = "sdp-no-mem" +# label = "scaled dot product without memory efficient attention" +# cmd_opt = "opt_sdp_no_mem_attention" +# priority = 80 +# +# def is_available(self): +# return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward +# +# +# class SdOptimizationSdp(SdOptimizationSdpNoMem): +# name = "sdp" +# label = "scaled dot product" +# cmd_opt = "opt_sdp_attention" +# priority = 70 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward +# +# +# class SdOptimizationSubQuad(SdOptimization): +# name = "sub-quadratic" +# cmd_opt = "opt_sub_quad_attention" +# +# @property +# def priority(self): +# return 1000 if shared.device.type == 'mps' else 10 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = 
sub_quad_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward +# +# +# class SdOptimizationV1(SdOptimization): +# name = "V1" +# label = "original v1" +# cmd_opt = "opt_split_attention_v1" +# priority = 10 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 +# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 +# +# +# class SdOptimizationInvokeAI(SdOptimization): +# name = "InvokeAI" +# cmd_opt = "opt_split_attention_invokeai" +# +# @property +# def priority(self): +# return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI +# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI +# +# +# class SdOptimizationDoggettx(SdOptimization): +# name = "Doggettx" +# cmd_opt = "opt_split_attention" +# priority = 90 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward +# +# +# def list_optimizers(res): +# res.extend([ +# SdOptimizationXformers(), +# SdOptimizationSdpNoMem(), +# SdOptimizationSdp(), +# SdOptimizationSubQuad(), +# SdOptimizationV1(), +# SdOptimizationInvokeAI(), +# SdOptimizationDoggettx(), +# ]) +# +# +# if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: +# try: +# import xformers.ops +# shared.xformers_available = True +# except Exception: +# errors.report("Cannot import xformers", exc_info=True) +# +# +# def get_available_vram(): +# if shared.device.type == 'cuda': +# stats = torch.cuda.memory_stats(shared.device) +# mem_active = stats['active_bytes.all.current'] +# mem_reserved = stats['reserved_bytes.all.current'] +# mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) +# mem_free_torch = mem_reserved - mem_active +# mem_free_total = mem_free_cuda + mem_free_torch +# return mem_free_total +# else: +# return psutil.virtual_memory().available +# +# +# # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +# def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): +# h = self.heads +# +# q_in = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) +# del context, context_k, context_v, x +# +# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) +# del q_in, k_in, v_in +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v.float() +# +# with devices.without_autocast(disable=not shared.opts.upcast_attn): +# r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) +# for i in range(0, q.shape[0], 2): +# end = i + 2 +# s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) +# s1 *= self.scale +# +# s2 = s1.softmax(dim=-1) +# del s1 +# +# r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) +# del s2 +# del q, k, v +# +# r1 = r1.to(dtype) +# +# r2 = rearrange(r1, '(b h) 
n d -> b n (h d)', h=h) +# del r1 +# +# return self.to_out(r2) +# +# +# # taken from https://github.com/Doggettx/stable-diffusion and modified +# def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): +# h = self.heads +# +# q_in = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) +# +# dtype = q_in.dtype +# if shared.opts.upcast_attn: +# q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() +# +# with devices.without_autocast(disable=not shared.opts.upcast_attn): +# k_in = k_in * self.scale +# +# del context, x +# +# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) +# del q_in, k_in, v_in +# +# r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) +# +# mem_free_total = get_available_vram() +# +# gb = 1024 ** 3 +# tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() +# modifier = 3 if q.element_size() == 2 else 2.5 +# mem_required = tensor_size * modifier +# steps = 1 +# +# if mem_required > mem_free_total: +# steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) +# # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " +# # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") +# +# if steps > 64: +# max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 +# raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' +# f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') +# +# slice_size = q.shape[1] // steps +# for i in range(0, q.shape[1], slice_size): +# end = min(i + slice_size, q.shape[1]) +# s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) +# +# s2 = s1.softmax(dim=-1, dtype=q.dtype) +# del s1 +# +# r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) +# del s2 +# +# del q, k, v +# +# r1 = r1.to(dtype) +# +# r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) +# del r1 +# +# return self.to_out(r2) +# +# +# # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- +# mem_total_gb = psutil.virtual_memory().total // (1 << 30) +# +# +# def einsum_op_compvis(q, k, v): +# s = einsum('b i d, b j d -> b i j', q, k) +# s = s.softmax(dim=-1, dtype=s.dtype) +# return einsum('b i j, b j d -> b i d', s, v) +# +# +# def einsum_op_slice_0(q, k, v, slice_size): +# r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) +# for i in range(0, q.shape[0], slice_size): +# end = i + slice_size +# r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) +# return r +# +# +# def einsum_op_slice_1(q, k, v, slice_size): +# r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) +# for i in range(0, q.shape[1], slice_size): +# end = i + slice_size +# r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) +# return r +# +# +# def einsum_op_mps_v1(q, k, v): +# if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 +# return einsum_op_compvis(q, k, v) +# else: +# slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) +# if slice_size % 4096 == 0: +# slice_size -= 1 +# return einsum_op_slice_1(q, k, v, slice_size) +# +# +# def einsum_op_mps_v2(q, k, v): +# if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: +# return einsum_op_compvis(q, k, v) 
+# else: +# return einsum_op_slice_0(q, k, v, 1) +# +# +# def einsum_op_tensor_mem(q, k, v, max_tensor_mb): +# size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) +# if size_mb <= max_tensor_mb: +# return einsum_op_compvis(q, k, v) +# div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() +# if div <= q.shape[0]: +# return einsum_op_slice_0(q, k, v, q.shape[0] // div) +# return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) +# +# +# def einsum_op_cuda(q, k, v): +# stats = torch.cuda.memory_stats(q.device) +# mem_active = stats['active_bytes.all.current'] +# mem_reserved = stats['reserved_bytes.all.current'] +# mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) +# mem_free_torch = mem_reserved - mem_active +# mem_free_total = mem_free_cuda + mem_free_torch +# # Divide factor of safety as there's copying and fragmentation +# return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) +# +# +# def einsum_op(q, k, v): +# if q.device.type == 'cuda': +# return einsum_op_cuda(q, k, v) +# +# if q.device.type == 'mps': +# if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18: +# return einsum_op_mps_v1(q, k, v) +# return einsum_op_mps_v2(q, k, v) +# +# # Smaller slices are faster due to L2/L3/SLC caches. +# # Tested on i7 with 8MB L3 cache. +# return einsum_op_tensor_mem(q, k, v, 32) +# +# +# def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): +# h = self.heads +# +# q = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k = self.to_k(context_k) +# v = self.to_v(context_v) +# del context, context_k, context_v, x +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() +# +# with devices.without_autocast(disable=not shared.opts.upcast_attn): +# k = k * self.scale +# +# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) +# r = einsum_op(q, k, v) +# r = r.to(dtype) +# return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) +# +# # -- End of code from https://github.com/invoke-ai/InvokeAI -- +# +# +# # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface +# def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): +# assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
+# +# h = self.heads +# +# q = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k = self.to_k(context_k) +# v = self.to_v(context_v) +# del context, context_k, context_v, x +# +# q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) +# k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) +# v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) +# +# if q.device.type == 'mps': +# q, k, v = q.contiguous(), k.contiguous(), v.contiguous() +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k = q.float(), k.float() +# +# x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) +# +# x = x.to(dtype) +# +# x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) +# +# out_proj, dropout = self.to_out +# x = out_proj(x) +# x = dropout(x) +# +# return x +# +# +# def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): +# bytes_per_token = torch.finfo(q.dtype).bits//8 +# batch_x_heads, q_tokens, _ = q.shape +# _, k_tokens, _ = k.shape +# qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens +# +# if chunk_threshold is None: +# if q.device.type == 'mps': +# chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) +# else: +# chunk_threshold_bytes = int(get_available_vram() * 0.7) +# elif chunk_threshold == 0: +# chunk_threshold_bytes = None +# else: +# chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) +# +# if kv_chunk_size_min is None and chunk_threshold_bytes is not None: +# kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) +# elif kv_chunk_size_min == 0: +# kv_chunk_size_min = None +# +# if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: +# # the big matmul fits into our memory limit; do everything in 1 chunk, +# # i.e. 
send it down the unchunked fast-path +# kv_chunk_size = k_tokens +# +# with devices.without_autocast(disable=q.dtype == v.dtype): +# return sub_quadratic_attention.efficient_dot_product_attention( +# q, +# k, +# v, +# query_chunk_size=q_chunk_size, +# kv_chunk_size=kv_chunk_size, +# kv_chunk_size_min = kv_chunk_size_min, +# use_checkpoint=use_checkpoint, +# ) +# +# +# def get_xformers_flash_attention_op(q, k, v): +# if not shared.cmd_opts.xformers_flash_attention: +# return None +# +# try: +# flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp +# fw, bw = flash_attention_op +# if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): +# return flash_attention_op +# except Exception as e: +# errors.display_once(e, "enabling flash attention") +# +# return None +# +# +# def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): +# h = self.heads +# q_in = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) +# +# q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in)) +# +# del q_in, k_in, v_in +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v.float() +# +# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) +# +# out = out.to(dtype) +# +# b, n, h, d = out.shape +# out = out.reshape(b, n, h * d) +# return self.to_out(out) +# +# +# # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py +# # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface +# def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): +# batch_size, sequence_length, inner_dim = x.shape +# +# if mask is not None: +# mask = self.prepare_attention_mask(mask, sequence_length, batch_size) +# mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) +# +# h = self.heads +# q_in = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) +# +# head_dim = inner_dim // h +# q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) +# k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) +# v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) +# +# del q_in, k_in, v_in +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v.float() +# +# # the output of sdp = (batch, num_heads, seq_len, head_dim) +# hidden_states = torch.nn.functional.scaled_dot_product_attention( +# q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False +# ) +# +# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) +# hidden_states = hidden_states.to(dtype) +# +# # linear proj +# hidden_states = self.to_out[0](hidden_states) +# # dropout +# hidden_states = self.to_out[1](hidden_states) +# return hidden_states +# +# +# def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): +# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, 
enable_mem_efficient=False): +# return scaled_dot_product_attention_forward(self, x, context, mask) +# +# +# def cross_attention_attnblock_forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q1 = self.q(h_) +# k1 = self.k(h_) +# v = self.v(h_) +# +# # compute attention +# b, c, h, w = q1.shape +# +# q2 = q1.reshape(b, c, h*w) +# del q1 +# +# q = q2.permute(0, 2, 1) # b,hw,c +# del q2 +# +# k = k1.reshape(b, c, h*w) # b,c,hw +# del k1 +# +# h_ = torch.zeros_like(k, device=q.device) +# +# mem_free_total = get_available_vram() +# +# tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() +# mem_required = tensor_size * 2.5 +# steps = 1 +# +# if mem_required > mem_free_total: +# steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) +# +# slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] +# for i in range(0, q.shape[1], slice_size): +# end = i + slice_size +# +# w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] +# w2 = w1 * (int(c)**(-0.5)) +# del w1 +# w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) +# del w2 +# +# # attend to values +# v1 = v.reshape(b, c, h*w) +# w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) +# del w3 +# +# h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] +# del v1, w4 +# +# h2 = h_.reshape(b, c, h, w) +# del h_ +# +# h3 = self.proj_out(h2) +# del h2 +# +# h3 += x +# +# return h3 +# +# +# def xformers_attnblock_forward(self, x): +# try: +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) +# b, c, h, w = q.shape +# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k = q.float(), k.float() +# q = q.contiguous() +# k = k.contiguous() +# v = v.contiguous() +# out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) +# out = out.to(dtype) +# out = rearrange(out, 'b (h w) c -> b c h w', h=h) +# out = self.proj_out(out) +# return x + out +# except NotImplementedError: +# return cross_attention_attnblock_forward(self, x) +# +# +# def sdp_attnblock_forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) +# b, c, h, w = q.shape +# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v.float() +# q = q.contiguous() +# k = k.contiguous() +# v = v.contiguous() +# out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) +# out = out.to(dtype) +# out = rearrange(out, 'b (h w) c -> b c h w', h=h) +# out = self.proj_out(out) +# return x + out +# +# +# def sdp_no_mem_attnblock_forward(self, x): +# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): +# return sdp_attnblock_forward(self, x) +# +# +# def sub_quad_attnblock_forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) +# b, c, h, w = q.shape +# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) +# q = q.contiguous() +# k = k.contiguous() +# v = v.contiguous() +# out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) +# out = rearrange(out, 'b (h w) c -> b c h w', h=h) +# 
out = self.proj_out(out) +# return x + out diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index b4f03b13..eb4a0af4 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,154 +1,154 @@ -import torch -from packaging import version -from einops import repeat -import math - -from modules import devices -from modules.sd_hijack_utils import CondFunc - - -class TorchHijackForUnet: - """ - This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; - this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 - """ - - def __getattr__(self, item): - if item == 'cat': - return self.cat - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def cat(self, tensors, *args, **kwargs): - if len(tensors) == 2: - a, b = tensors - if a.shape[-2:] != b.shape[-2:]: - a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") - - tensors = (a, b) - - return torch.cat(tensors, *args, **kwargs) - - -th = TorchHijackForUnet() - - -# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling -def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - """Always make sure inputs to unet are in correct dtype.""" - if isinstance(cond, dict): - for y in cond.keys(): - if isinstance(cond[y], list): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] - else: - cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] - - with devices.autocast(): - result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) - if devices.unet_needs_upcast: - return result.float() - else: - return result - - -# Monkey patch to create timestep embed tensor on device, avoiding a block. -def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - else: - embedding = repeat(timesteps, 'b -> b d', d=dim) - return embedding - - -# Monkey patch to SpatialTransformer removing unnecessary contiguous calls. 
-# Prevents a lot of unnecessary aten::copy_ calls -def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): - # note: if no context is given, cross-attention defaults to self-attention - if not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i]) - if self.use_linear: - x = self.proj_out(x) - x = x.view(b, h, w, c).permute(0, 3, 1, 2) - if not self.use_linear: - x = self.proj_out(x) - return x + x_in - - -class GELUHijack(torch.nn.GELU, torch.nn.Module): - def __init__(self, *args, **kwargs): - torch.nn.GELU.__init__(self, *args, **kwargs) - def forward(self, x): - if devices.unet_needs_upcast: - return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) - else: - return torch.nn.GELU.forward(self, x) - - -ddpm_edit_hijack = None -def hijack_ddpm_edit(): - global ddpm_edit_hijack - if not ddpm_edit_hijack: - CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) - CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) - ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) - - -unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) -CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) - -if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): - CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) - CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) - CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) - -first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 -first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) - -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) 
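Aside: the CondFunc registrations in this file implement conditional monkey-patching: a callable addressed by its dotted import path is swapped for a wrapper that delegates to a substitute only while a predicate (e.g. devices.unet_needs_upcast) holds, and otherwise calls the original untouched. A minimal self-contained sketch of that pattern, not the actual modules.sd_hijack_utils.CondFunc implementation, and resolving only a plain attribute on an object rather than a dotted path:

import types

def cond_patch(owner, name, sub_func, cond_func=lambda orig, *a, **kw: True):
    # Hypothetical helper: replace owner.<name> with a wrapper that runs
    # sub_func(orig, ...) when cond_func(orig, ...) is true, otherwise the
    # original callable. Mirrors the pattern of the CondFunc calls above.
    orig = getattr(owner, name)

    def wrapper(*args, **kwargs):
        if cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(owner, name, wrapper)
    return wrapper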
-CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) - - -def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): - if devices.unet_needs_upcast and timesteps.dtype == torch.int64: - dtype = torch.float32 - else: - dtype = devices.dtype_unet - return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) - - -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) -CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +# import torch +# from packaging import version +# from einops import repeat +# import math +# +# from modules import devices +# from modules.sd_hijack_utils import CondFunc +# +# +# class TorchHijackForUnet: +# """ +# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; +# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 +# """ +# +# def __getattr__(self, item): +# if item == 'cat': +# return self.cat +# +# if hasattr(torch, item): +# return getattr(torch, item) +# +# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") +# +# def cat(self, tensors, *args, **kwargs): +# if len(tensors) == 2: +# a, b = tensors +# if a.shape[-2:] != b.shape[-2:]: +# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") +# +# tensors = (a, b) +# +# return torch.cat(tensors, *args, **kwargs) +# +# +# th = TorchHijackForUnet() +# +# +# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): +# """Always make sure inputs to unet are in correct dtype.""" +# if isinstance(cond, dict): +# for y in cond.keys(): +# if isinstance(cond[y], list): +# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] +# else: +# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] +# +# with devices.autocast(): +# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) +# if devices.unet_needs_upcast: +# return result.float() +# else: +# return result +# +# +# # Monkey patch to create timestep embed tensor on device, avoiding a block. +# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): +# """ +# Create sinusoidal timestep embeddings. +# :param timesteps: a 1-D Tensor of N indices, one per batch element. +# These may be fractional. +# :param dim: the dimension of the output. +# :param max_period: controls the minimum frequency of the embeddings. +# :return: an [N x dim] Tensor of positional embeddings. +# """ +# if not repeat_only: +# half = dim // 2 +# freqs = torch.exp( +# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half +# ) +# args = timesteps[:, None].float() * freqs[None] +# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) +# if dim % 2: +# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) +# else: +# embedding = repeat(timesteps, 'b -> b d', d=dim) +# return embedding +# +# +# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls. 
+# # Prevents a lot of unnecessary aten::copy_ calls +# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): +# # note: if no context is given, cross-attention defaults to self-attention +# if not isinstance(context, list): +# context = [context] +# b, c, h, w = x.shape +# x_in = x +# x = self.norm(x) +# if not self.use_linear: +# x = self.proj_in(x) +# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) +# if self.use_linear: +# x = self.proj_in(x) +# for i, block in enumerate(self.transformer_blocks): +# x = block(x, context=context[i]) +# if self.use_linear: +# x = self.proj_out(x) +# x = x.view(b, h, w, c).permute(0, 3, 1, 2) +# if not self.use_linear: +# x = self.proj_out(x) +# return x + x_in +# +# +# class GELUHijack(torch.nn.GELU, torch.nn.Module): +# def __init__(self, *args, **kwargs): +# torch.nn.GELU.__init__(self, *args, **kwargs) +# def forward(self, x): +# if devices.unet_needs_upcast: +# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) +# else: +# return torch.nn.GELU.forward(self, x) +# +# +# ddpm_edit_hijack = None +# def hijack_ddpm_edit(): +# global ddpm_edit_hijack +# if not ddpm_edit_hijack: +# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) +# +# +# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) +# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) +# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +# +# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): +# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) +# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) +# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) +# +# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 +# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) +# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) +# +# 
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) +# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) +# +# +# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): +# if devices.unet_needs_upcast and timesteps.dtype == torch.int64: +# dtype = torch.float32 +# else: +# dtype = devices.dtype_unet +# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) +# +# +# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 41e5087d..faddd9a2 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -1,137 +1,137 @@ -import os - -import torch - -from modules import shared, paths, sd_disable_initialization, devices - -sd_configs_path = shared.sd_configs_path -sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") -sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") - - -config_default = shared.sd_default_config -# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") -config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") -config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") -config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") -config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") -config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") -config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") -config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") -config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") -config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") -config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") -config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") -config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") -config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") - - -def is_using_v_parameterization_for_sd2(state_dict): - """ - Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. - """ - - import ldm.modules.diffusionmodules.openaimodel - - device = devices.device - - with sd_disable_initialization.DisableInitialization(): - unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( - use_checkpoint=False, - use_fp16=False, - image_size=32, - in_channels=4, - out_channels=4, - model_channels=320, - attention_resolutions=[4, 2, 1], - num_res_blocks=2, - channel_mult=[1, 2, 4, 4], - num_head_channels=64, - use_spatial_transformer=True, - use_linear_in_transformer=True, - transformer_depth=1, - context_dim=1024, - legacy=False - ) - unet.eval() - - with torch.no_grad(): - unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} - unet.load_state_dict(unet_sd, strict=True) - unet.to(device=device, dtype=devices.dtype_unet) - - test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 - x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 - - with devices.autocast(): - out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item() - - return out < -1 - - -def guess_model_config_from_state_dict(sd, filename): - sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) - diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) - sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) - - if "model.diffusion_model.x_embedder.proj.weight" in sd: - return config_sd3 - - if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: - if diffusion_model_input.shape[1] == 9: - return config_sdxl_inpainting - else: - return config_sdxl - - if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: - return config_sdxl_refiner - elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: - return config_depth_model - elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: - return config_unclip - elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024: - return config_unopenclip - - if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: - if diffusion_model_input.shape[1] == 9: - return config_sd2_inpainting - # elif is_using_v_parameterization_for_sd2(sd): - # return config_sd2v - else: - return config_sd2v - - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - return config_inpainting - if diffusion_model_input.shape[1] == 8: - return config_instruct_pix2pix - - if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: - if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: - return config_alt_diffusion_m18 - return config_alt_diffusion - - return config_default - - -def find_checkpoint_config(state_dict, info): - if info is None: - return guess_model_config_from_state_dict(state_dict, "") - - config = find_checkpoint_config_near_filename(info) - if config is not None: - return config - - return guess_model_config_from_state_dict(state_dict, info.filename) - - -def find_checkpoint_config_near_filename(info): - if info is None: - return None - - config = f"{os.path.splitext(info.filename)[0]}.yaml" - if os.path.exists(config): - return config - - return None - +# import os +# +# import torch +# +# from modules import shared, paths, sd_disable_initialization, devices +# +# sd_configs_path = shared.sd_configs_path +# sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") +# sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") +# +# +# config_default = shared.sd_default_config +# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") +# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") +# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") +# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") +# config_depth_model = 
os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") +# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") +# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") +# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") +# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") +# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") +# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") +# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") +# +# +# def is_using_v_parameterization_for_sd2(state_dict): +# """ +# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. +# """ +# +# import ldm.modules.diffusionmodules.openaimodel +# +# device = devices.device +# +# with sd_disable_initialization.DisableInitialization(): +# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( +# use_checkpoint=False, +# use_fp16=False, +# image_size=32, +# in_channels=4, +# out_channels=4, +# model_channels=320, +# attention_resolutions=[4, 2, 1], +# num_res_blocks=2, +# channel_mult=[1, 2, 4, 4], +# num_head_channels=64, +# use_spatial_transformer=True, +# use_linear_in_transformer=True, +# transformer_depth=1, +# context_dim=1024, +# legacy=False +# ) +# unet.eval() +# +# with torch.no_grad(): +# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} +# unet.load_state_dict(unet_sd, strict=True) +# unet.to(device=device, dtype=devices.dtype_unet) +# +# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 +# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 +# +# with devices.autocast(): +# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item() +# +# return out < -1 +# +# +# def guess_model_config_from_state_dict(sd, filename): +# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) +# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) +# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) +# +# if "model.diffusion_model.x_embedder.proj.weight" in sd: +# return config_sd3 +# +# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: +# if diffusion_model_input.shape[1] == 9: +# return config_sdxl_inpainting +# else: +# return config_sdxl +# +# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: +# return config_sdxl_refiner +# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: +# return config_depth_model +# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: +# return config_unclip +# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024: +# return config_unopenclip +# +# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: +# if diffusion_model_input.shape[1] == 9: +# return config_sd2_inpainting +# # elif is_using_v_parameterization_for_sd2(sd): +# # return config_sd2v +# else: +# return config_sd2v +# +# if diffusion_model_input is not None: +# if diffusion_model_input.shape[1] == 9: +# return config_inpainting +# if diffusion_model_input.shape[1] == 8: +# return config_instruct_pix2pix +# +# if 
sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: +# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: +# return config_alt_diffusion_m18 +# return config_alt_diffusion +# +# return config_default +# +# +# def find_checkpoint_config(state_dict, info): +# if info is None: +# return guess_model_config_from_state_dict(state_dict, "") +# +# config = find_checkpoint_config_near_filename(info) +# if config is not None: +# return config +# +# return guess_model_config_from_state_dict(state_dict, info.filename) +# +# +# def find_checkpoint_config_near_filename(info): +# if info is None: +# return None +# +# config = f"{os.path.splitext(info.filename)[0]}.yaml" +# if os.path.exists(config): +# return config +# +# return None +# diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 3f1bab96..0b84f2fc 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,115 +1,115 @@ -from __future__ import annotations - -import torch - -import sgm.models.diffusion -import sgm.modules.diffusionmodules.denoiser_scaling -import sgm.modules.diffusionmodules.discretizer -from modules import devices, shared, prompt_parser -from modules import torch_utils - -from backend import memory_management - - -def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): - - for embedder in self.conditioner.embedders: - embedder.ucg_rate = 0.0 - - width = getattr(batch, 'width', 1024) or 1024 - height = getattr(batch, 'height', 1024) or 1024 - is_negative_prompt = getattr(batch, 'is_negative_prompt', False) - aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score - - devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype()) - - sdxl_conds = { - "txt": batch, - "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), - "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), - "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), - "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), - } - - force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) - c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) - - return c - - -def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs): - if self.model.diffusion_model.in_channels == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) - - return self.model(x, t, cond, *args, **kwargs) - - -def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility - return x - - -sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -sgm.models.diffusion.DiffusionEngine.apply_model = apply_model -sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding - - -def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): - res = [] - - for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: - encoded = embedder.encode_embedding_init_text(init_text, nvpt) - res.append(encoded) - - return torch.cat(res, dim=1) - 
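For context on the get_learned_conditioning removal above: SDXL conditions on the prompt text plus a set of per-sample micro-conditioning tensors (original size, crop offsets, target size and, for the refiner, an aesthetic score). A rough standalone sketch of how that conditioning dict is assembled, with the device/dtype handling and the GeneralConditioner call omitted; the default values here are illustrative only:

import torch

def build_sdxl_conds(prompts, width=1024, height=1024, crop_top=0, crop_left=0, aesthetic_score=6.0):
    # One row per prompt; each entry is embedded alongside the "txt" branch
    # by the conditioner in the removed code above.
    n = len(prompts)
    return {
        "txt": prompts,
        "original_size_as_tuple": torch.tensor([height, width]).repeat(n, 1),
        "crop_coords_top_left": torch.tensor([crop_top, crop_left]).repeat(n, 1),
        "target_size_as_tuple": torch.tensor([height, width]).repeat(n, 1),
        "aesthetic_score": torch.tensor([aesthetic_score]).repeat(n, 1),
    }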
- -def tokenize(self: sgm.modules.GeneralConditioner, texts): - for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: - return embedder.tokenize(texts) - - raise AssertionError('no tokenizer available') - - - -def process_texts(self, texts): - for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: - return embedder.process_texts(texts) - - -def get_target_prompt_token_count(self, token_count): - for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: - return embedder.get_target_prompt_token_count(token_count) - - -# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist -sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text -sgm.modules.GeneralConditioner.tokenize = tokenize -sgm.modules.GeneralConditioner.process_texts = process_texts -sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count - - -def extend_sdxl(model): - """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - - dtype = torch_utils.get_param(model.model.diffusion_model).dtype - model.model.diffusion_model.dtype = dtype - model.model.conditioning_key = 'crossattn' - model.cond_stage_key = 'txt' - # model.cond_stage_model will be set in sd_hijack - - model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" - - discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() - model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) - - model.conditioner.wrapped = torch.nn.Module() - - -sgm.modules.attention.print = shared.ldm_print -sgm.modules.diffusionmodules.model.print = shared.ldm_print -sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print -sgm.modules.encoders.modules.print = shared.ldm_print - -# this gets the code to load the vanilla attention that we override -sgm.modules.attention.SDP_IS_AVAILABLE = True -sgm.modules.attention.XFORMERS_IS_AVAILABLE = False +# from __future__ import annotations +# +# import torch +# +# import sgm.models.diffusion +# import sgm.modules.diffusionmodules.denoiser_scaling +# import sgm.modules.diffusionmodules.discretizer +# from modules import devices, shared, prompt_parser +# from modules import torch_utils +# +# from backend import memory_management +# +# +# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): +# +# for embedder in self.conditioner.embedders: +# embedder.ucg_rate = 0.0 +# +# width = getattr(batch, 'width', 1024) or 1024 +# height = getattr(batch, 'height', 1024) or 1024 +# is_negative_prompt = getattr(batch, 'is_negative_prompt', False) +# aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score +# +# devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype()) +# +# sdxl_conds = { +# "txt": batch, +# "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), +# "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), +# "target_size_as_tuple": torch.tensor([height, width], 
**devices_args).repeat(len(batch), 1), +# "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), +# } +# +# force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) +# c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) +# +# return c +# +# +# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs): +# if self.model.diffusion_model.in_channels == 9: +# x = torch.cat([x] + cond['c_concat'], dim=1) +# +# return self.model(x, t, cond, *args, **kwargs) +# +# +# def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility +# return x +# +# +# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding +# +# +# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): +# res = [] +# +# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: +# encoded = embedder.encode_embedding_init_text(init_text, nvpt) +# res.append(encoded) +# +# return torch.cat(res, dim=1) +# +# +# def tokenize(self: sgm.modules.GeneralConditioner, texts): +# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: +# return embedder.tokenize(texts) +# +# raise AssertionError('no tokenizer available') +# +# +# +# def process_texts(self, texts): +# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: +# return embedder.process_texts(texts) +# +# +# def get_target_prompt_token_count(self, token_count): +# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: +# return embedder.get_target_prompt_token_count(token_count) +# +# +# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist +# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +# sgm.modules.GeneralConditioner.tokenize = tokenize +# sgm.modules.GeneralConditioner.process_texts = process_texts +# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count +# +# +# def extend_sdxl(model): +# """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" +# +# dtype = torch_utils.get_param(model.model.diffusion_model).dtype +# model.model.diffusion_model.dtype = dtype +# model.model.conditioning_key = 'crossattn' +# model.cond_stage_key = 'txt' +# # model.cond_stage_model will be set in sd_hijack +# +# model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" +# +# discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() +# model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) +# +# model.conditioner.wrapped = torch.nn.Module() +# +# +# sgm.modules.attention.print = shared.ldm_print +# sgm.modules.diffusionmodules.model.print = shared.ldm_print +# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print +# sgm.modules.encoders.modules.print = shared.ldm_print +# +# # this gets the code to load the vanilla 
attention that we override +# sgm.modules.attention.SDP_IS_AVAILABLE = True +# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False diff --git a/modules/shared_items.py b/modules/shared_items.py index 11f10b3f..1568ba36 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -35,9 +35,7 @@ def refresh_vae_list(): def cross_attention_optimizations(): - import modules.sd_hijack - - return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] + return ["Automatic"] def sd_unet_items():
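Note on the shared_items change above: with the hand-rolled cross-attention variants removed (Doggettx slicing, the InvokeAI einsum dispatch, sub-quadratic attention, xformers), only "Automatic" remains, and the attention math those forwards implemented is typically covered by PyTorch's fused torch.nn.functional.scaled_dot_product_attention or equivalent backend kernels. The head reshaping they performed reduces to roughly the following sketch, which is illustrative and not the backend's actual code path:

import torch
import torch.nn.functional as F

def attention_sdpa(q, k, v, heads):
    # q, k, v: (batch, tokens, heads * head_dim), as produced by to_q/to_k/to_v.
    b, _, inner_dim = q.shape
    head_dim = inner_dim // heads
    q, k, v = (t.reshape(b, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))
    # Shape is now (batch, heads, tokens, head_dim); the fused kernel selects
    # the flash, memory-efficient or math backend automatically.
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * head_dim)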