mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
Free WebUI from its Prison
Congratulations WebUI. Say Hello to freedom.
This commit is contained in:
@@ -7,7 +7,6 @@ import torch
|
|||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
from modules import shared, script_callbacks, masking, images
|
from modules import shared, script_callbacks, masking, images
|
||||||
from modules.ui_components import InputAccordion
|
from modules.ui_components import InputAccordion
|
||||||
from modules.api.api import decode_base64_to_image
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from lib_controlnet import global_state, external_code
|
from lib_controlnet import global_state, external_code
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
|
|||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin
|
from PIL import PngImagePlugin
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -725,7 +724,7 @@ class Api:
|
|||||||
|
|
||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
import modules.sd_models as sd_models
|
import modules.sd_models as sd_models
|
||||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
|
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()]
|
||||||
|
|
||||||
def get_sd_vaes(self):
|
def get_sd_vaes(self):
|
||||||
import modules.sd_vae as sd_vae
|
import modules.sd_vae as sd_vae
|
||||||
|
|||||||
@@ -10,16 +10,6 @@ from threading import Thread
|
|||||||
from modules.timer import startup_timer
|
from modules.timer import startup_timer
|
||||||
|
|
||||||
|
|
||||||
class HiddenPrints:
|
|
||||||
def __enter__(self):
|
|
||||||
self._original_stdout = sys.stdout
|
|
||||||
sys.stdout = open(os.devnull, 'w')
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
sys.stdout.close()
|
|
||||||
sys.stdout = self._original_stdout
|
|
||||||
|
|
||||||
|
|
||||||
def imports():
|
def imports():
|
||||||
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
@@ -35,16 +25,8 @@ def imports():
|
|||||||
import gradio # noqa: F401
|
import gradio # noqa: F401
|
||||||
startup_timer.record("import gradio")
|
startup_timer.record("import gradio")
|
||||||
|
|
||||||
with HiddenPrints():
|
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
startup_timer.record("setup paths")
|
||||||
startup_timer.record("setup paths")
|
|
||||||
|
|
||||||
import ldm.modules.encoders.modules # noqa: F401
|
|
||||||
import ldm.modules.diffusionmodules.model
|
|
||||||
startup_timer.record("import ldm")
|
|
||||||
|
|
||||||
import sgm.modules.encoders.modules # noqa: F401
|
|
||||||
startup_timer.record("import sgm")
|
|
||||||
|
|
||||||
from modules import shared_init
|
from modules import shared_init
|
||||||
shared_init.initialize()
|
shared_init.initialize()
|
||||||
@@ -141,11 +123,6 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
textual_inversion.textual_inversion.list_textual_inversion_templates()
|
textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||||
startup_timer.record("refresh textual inversion templates")
|
startup_timer.record("refresh textual inversion templates")
|
||||||
|
|
||||||
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
|
|
||||||
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
|
|
||||||
sd_hijack.list_optimizers()
|
|
||||||
startup_timer.record("scripts list_optimizers")
|
|
||||||
|
|
||||||
from modules import sd_unet
|
from modules import sd_unet
|
||||||
sd_unet.list_unets()
|
sd_unet.list_unets()
|
||||||
startup_timer.record("scripts list_unets")
|
startup_timer.record("scripts list_unets")
|
||||||
|
|||||||
@@ -391,15 +391,15 @@ def prepare_environment():
|
|||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
# stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
# stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||||
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
|
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
|
||||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||||
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
|
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
@@ -456,8 +456,8 @@ def prepare_environment():
|
|||||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||||
|
|
||||||
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
||||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
# git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
# git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
|
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
|
||||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
|
|||||||
mute_sdxl_imports()
|
mute_sdxl_imports()
|
||||||
|
|
||||||
path_dirs = [
|
path_dirs = [
|
||||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
# (sd_path, 'ldm', 'Stable Diffusion', []),
|
||||||
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
# (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
||||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||||
(os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
|
(os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
|
||||||
@@ -53,13 +53,13 @@ for d, must_exist, what, options in path_dirs:
|
|||||||
d = os.path.abspath(d)
|
d = os.path.abspath(d)
|
||||||
if "atstart" in options:
|
if "atstart" in options:
|
||||||
sys.path.insert(0, d)
|
sys.path.insert(0, d)
|
||||||
elif "sgm" in options:
|
# elif "sgm" in options:
|
||||||
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
|
# # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
|
||||||
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
|
# # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
|
||||||
|
#
|
||||||
sys.path.insert(0, d)
|
# sys.path.insert(0, d)
|
||||||
import sgm # noqa: F401
|
# import sgm # noqa: F401
|
||||||
sys.path.pop(0)
|
# sys.path.pop(0)
|
||||||
else:
|
else:
|
||||||
sys.path.append(d)
|
sys.path.append(d)
|
||||||
paths[what] = d
|
paths[what] = d
|
||||||
|
|||||||
@@ -1,124 +1,3 @@
|
|||||||
import torch
|
|
||||||
from torch.nn.functional import silu
|
|
||||||
from types import MethodType
|
|
||||||
|
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
|
||||||
from modules.hypernetworks import hypernetwork
|
|
||||||
from modules.shared import cmd_opts
|
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
|
||||||
|
|
||||||
import ldm.modules.attention
|
|
||||||
import ldm.modules.diffusionmodules.model
|
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
|
||||||
import ldm.models.diffusion.ddpm
|
|
||||||
import ldm.models.diffusion.ddim
|
|
||||||
import ldm.models.diffusion.plms
|
|
||||||
import ldm.modules.encoders.modules
|
|
||||||
|
|
||||||
import sgm.modules.attention
|
|
||||||
import sgm.modules.diffusionmodules.model
|
|
||||||
import sgm.modules.diffusionmodules.openaimodel
|
|
||||||
import sgm.modules.encoders.modules
|
|
||||||
|
|
||||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
|
||||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
|
||||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
|
||||||
|
|
||||||
# new memory efficient cross attention blocks do not support hypernets and we already
|
|
||||||
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
|
||||||
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
|
||||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
|
||||||
|
|
||||||
# silence new console spam from SD2
|
|
||||||
ldm.modules.attention.print = shared.ldm_print
|
|
||||||
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
|
||||||
ldm.util.print = shared.ldm_print
|
|
||||||
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
|
||||||
|
|
||||||
optimizers = []
|
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
|
||||||
|
|
||||||
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
|
||||||
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
|
||||||
|
|
||||||
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
|
||||||
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
|
||||||
|
|
||||||
|
|
||||||
def list_optimizers():
|
|
||||||
new_optimizers = script_callbacks.list_optimizers_callback()
|
|
||||||
|
|
||||||
new_optimizers = [x for x in new_optimizers if x.is_available()]
|
|
||||||
|
|
||||||
new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
|
|
||||||
|
|
||||||
optimizers.clear()
|
|
||||||
optimizers.extend(new_optimizers)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations(option=None):
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def undo_optimizations():
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def fix_checkpoint():
|
|
||||||
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
|
||||||
checkpoints to be added when not training (there's a warning)"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def weighted_loss(sd_model, pred, target, mean=True):
|
|
||||||
#Calculate the weight normally, but ignore the mean
|
|
||||||
loss = sd_model._old_get_loss(pred, target, mean=False)
|
|
||||||
|
|
||||||
#Check if we have weights available
|
|
||||||
weight = getattr(sd_model, '_custom_loss_weight', None)
|
|
||||||
if weight is not None:
|
|
||||||
loss *= weight
|
|
||||||
|
|
||||||
#Return the loss, as mean if specified
|
|
||||||
return loss.mean() if mean else loss
|
|
||||||
|
|
||||||
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
#Temporarily append weights to a place accessible during loss calc
|
|
||||||
sd_model._custom_loss_weight = w
|
|
||||||
|
|
||||||
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
|
||||||
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
|
||||||
if not hasattr(sd_model, '_old_get_loss'):
|
|
||||||
sd_model._old_get_loss = sd_model.get_loss
|
|
||||||
sd_model.get_loss = MethodType(weighted_loss, sd_model)
|
|
||||||
|
|
||||||
#Run the standard forward function, but with the patched 'get_loss'
|
|
||||||
return sd_model.forward(x, c, *args, **kwargs)
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
#Delete temporary weights if appended
|
|
||||||
del sd_model._custom_loss_weight
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
#If we have an old loss function, reset the loss function to the original one
|
|
||||||
if hasattr(sd_model, '_old_get_loss'):
|
|
||||||
sd_model.get_loss = sd_model._old_get_loss
|
|
||||||
del sd_model._old_get_loss
|
|
||||||
|
|
||||||
def apply_weighted_forward(sd_model):
|
|
||||||
#Add new function 'weighted_forward' that can be called to calc weighted loss
|
|
||||||
sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
|
|
||||||
|
|
||||||
def undo_weighted_forward(sd_model):
|
|
||||||
try:
|
|
||||||
del sd_model.weighted_forward
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
fixes = None
|
fixes = None
|
||||||
layers = None
|
layers = None
|
||||||
@@ -156,74 +35,201 @@ class StableDiffusionModelHijack:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsWithFixes(torch.nn.Module):
|
|
||||||
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
|
||||||
super().__init__()
|
|
||||||
self.wrapped = wrapped
|
|
||||||
self.embeddings = embeddings
|
|
||||||
self.textual_inversion_key = textual_inversion_key
|
|
||||||
self.weight = self.wrapped.weight
|
|
||||||
|
|
||||||
def forward(self, input_ids):
|
|
||||||
batch_fixes = self.embeddings.fixes
|
|
||||||
self.embeddings.fixes = None
|
|
||||||
|
|
||||||
inputs_embeds = self.wrapped(input_ids)
|
|
||||||
|
|
||||||
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
vecs = []
|
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
|
||||||
for offset, embedding in fixes:
|
|
||||||
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
|
||||||
emb = devices.cond_cast_unet(vec)
|
|
||||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
|
||||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
|
||||||
|
|
||||||
vecs.append(tensor)
|
|
||||||
|
|
||||||
return torch.stack(vecs)
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionEmbeddings(torch.nn.Embedding):
|
|
||||||
def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
|
|
||||||
super().__init__(num_embeddings, embedding_dim, **kwargs)
|
|
||||||
|
|
||||||
self.embeddings = model_hijack
|
|
||||||
self.textual_inversion_key = textual_inversion_key
|
|
||||||
|
|
||||||
@property
|
|
||||||
def wrapped(self):
|
|
||||||
return super().forward
|
|
||||||
|
|
||||||
def forward(self, input_ids):
|
|
||||||
return EmbeddingsWithFixes.forward(self, input_ids)
|
|
||||||
|
|
||||||
|
|
||||||
def add_circular_option_to_conv_2d():
|
|
||||||
conv2d_constructor = torch.nn.Conv2d.__init__
|
|
||||||
|
|
||||||
def conv2d_constructor_circular(self, *args, **kwargs):
|
|
||||||
return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
|
||||||
|
|
||||||
torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
|
||||||
|
|
||||||
|
|
||||||
model_hijack = StableDiffusionModelHijack()
|
model_hijack = StableDiffusionModelHijack()
|
||||||
|
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
# import torch
|
||||||
"""
|
# from torch.nn.functional import silu
|
||||||
Fix register buffer bug for Mac OS.
|
# from types import MethodType
|
||||||
"""
|
#
|
||||||
|
# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||||
if type(attr) == torch.Tensor:
|
# from modules.hypernetworks import hypernetwork
|
||||||
if attr.device != devices.device:
|
# from modules.shared import cmd_opts
|
||||||
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
# from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||||
|
#
|
||||||
setattr(self, name, attr)
|
# import ldm.modules.attention
|
||||||
|
# import ldm.modules.diffusionmodules.model
|
||||||
|
# import ldm.modules.diffusionmodules.openaimodel
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
# import ldm.models.diffusion.ddpm
|
||||||
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
# import ldm.models.diffusion.ddim
|
||||||
|
# import ldm.models.diffusion.plms
|
||||||
|
# import ldm.modules.encoders.modules
|
||||||
|
#
|
||||||
|
# import sgm.modules.attention
|
||||||
|
# import sgm.modules.diffusionmodules.model
|
||||||
|
# import sgm.modules.diffusionmodules.openaimodel
|
||||||
|
# import sgm.modules.encoders.modules
|
||||||
|
#
|
||||||
|
# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||||
|
# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
|
# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
|
#
|
||||||
|
# # new memory efficient cross attention blocks do not support hypernets and we already
|
||||||
|
# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||||
|
# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||||
|
# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||||
|
#
|
||||||
|
# # silence new console spam from SD2
|
||||||
|
# ldm.modules.attention.print = shared.ldm_print
|
||||||
|
# ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
|
# ldm.util.print = shared.ldm_print
|
||||||
|
# ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||||
|
#
|
||||||
|
# optimizers = []
|
||||||
|
# current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
|
#
|
||||||
|
# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||||
|
# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
||||||
|
#
|
||||||
|
# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||||
|
# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def list_optimizers():
|
||||||
|
# new_optimizers = script_callbacks.list_optimizers_callback()
|
||||||
|
#
|
||||||
|
# new_optimizers = [x for x in new_optimizers if x.is_available()]
|
||||||
|
#
|
||||||
|
# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
|
||||||
|
#
|
||||||
|
# optimizers.clear()
|
||||||
|
# optimizers.extend(new_optimizers)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def apply_optimizations(option=None):
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def undo_optimizations():
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def fix_checkpoint():
|
||||||
|
# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||||
|
# checkpoints to be added when not training (there's a warning)"""
|
||||||
|
#
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def weighted_loss(sd_model, pred, target, mean=True):
|
||||||
|
# #Calculate the weight normally, but ignore the mean
|
||||||
|
# loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||||
|
#
|
||||||
|
# #Check if we have weights available
|
||||||
|
# weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||||
|
# if weight is not None:
|
||||||
|
# loss *= weight
|
||||||
|
#
|
||||||
|
# #Return the loss, as mean if specified
|
||||||
|
# return loss.mean() if mean else loss
|
||||||
|
#
|
||||||
|
# def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
||||||
|
# try:
|
||||||
|
# #Temporarily append weights to a place accessible during loss calc
|
||||||
|
# sd_model._custom_loss_weight = w
|
||||||
|
#
|
||||||
|
# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||||
|
# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||||
|
# if not hasattr(sd_model, '_old_get_loss'):
|
||||||
|
# sd_model._old_get_loss = sd_model.get_loss
|
||||||
|
# sd_model.get_loss = MethodType(weighted_loss, sd_model)
|
||||||
|
#
|
||||||
|
# #Run the standard forward function, but with the patched 'get_loss'
|
||||||
|
# return sd_model.forward(x, c, *args, **kwargs)
|
||||||
|
# finally:
|
||||||
|
# try:
|
||||||
|
# #Delete temporary weights if appended
|
||||||
|
# del sd_model._custom_loss_weight
|
||||||
|
# except AttributeError:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# #If we have an old loss function, reset the loss function to the original one
|
||||||
|
# if hasattr(sd_model, '_old_get_loss'):
|
||||||
|
# sd_model.get_loss = sd_model._old_get_loss
|
||||||
|
# del sd_model._old_get_loss
|
||||||
|
#
|
||||||
|
# def apply_weighted_forward(sd_model):
|
||||||
|
# #Add new function 'weighted_forward' that can be called to calc weighted loss
|
||||||
|
# sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
|
||||||
|
#
|
||||||
|
# def undo_weighted_forward(sd_model):
|
||||||
|
# try:
|
||||||
|
# del sd_model.weighted_forward
|
||||||
|
# except AttributeError:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
|
# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||||
|
# super().__init__()
|
||||||
|
# self.wrapped = wrapped
|
||||||
|
# self.embeddings = embeddings
|
||||||
|
# self.textual_inversion_key = textual_inversion_key
|
||||||
|
# self.weight = self.wrapped.weight
|
||||||
|
#
|
||||||
|
# def forward(self, input_ids):
|
||||||
|
# batch_fixes = self.embeddings.fixes
|
||||||
|
# self.embeddings.fixes = None
|
||||||
|
#
|
||||||
|
# inputs_embeds = self.wrapped(input_ids)
|
||||||
|
#
|
||||||
|
# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||||
|
# return inputs_embeds
|
||||||
|
#
|
||||||
|
# vecs = []
|
||||||
|
# for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
|
# for offset, embedding in fixes:
|
||||||
|
# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||||
|
# emb = devices.cond_cast_unet(vec)
|
||||||
|
# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||||
|
# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
||||||
|
#
|
||||||
|
# vecs.append(tensor)
|
||||||
|
#
|
||||||
|
# return torch.stack(vecs)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# class TextualInversionEmbeddings(torch.nn.Embedding):
|
||||||
|
# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
|
||||||
|
# super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||||
|
#
|
||||||
|
# self.embeddings = model_hijack
|
||||||
|
# self.textual_inversion_key = textual_inversion_key
|
||||||
|
#
|
||||||
|
# @property
|
||||||
|
# def wrapped(self):
|
||||||
|
# return super().forward
|
||||||
|
#
|
||||||
|
# def forward(self, input_ids):
|
||||||
|
# return EmbeddingsWithFixes.forward(self, input_ids)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def add_circular_option_to_conv_2d():
|
||||||
|
# conv2d_constructor = torch.nn.Conv2d.__init__
|
||||||
|
#
|
||||||
|
# def conv2d_constructor_circular(self, *args, **kwargs):
|
||||||
|
# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
||||||
|
#
|
||||||
|
# torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# model_hijack = StableDiffusionModelHijack()
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def register_buffer(self, name, attr):
|
||||||
|
# """
|
||||||
|
# Fix register buffer bug for Mac OS.
|
||||||
|
# """
|
||||||
|
#
|
||||||
|
# if type(attr) == torch.Tensor:
|
||||||
|
# if attr.device != devices.device:
|
||||||
|
# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
||||||
|
#
|
||||||
|
# setattr(self, name, attr)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||||
|
# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,154 +1,154 @@
|
|||||||
import torch
|
# import torch
|
||||||
from packaging import version
|
# from packaging import version
|
||||||
from einops import repeat
|
# from einops import repeat
|
||||||
import math
|
# import math
|
||||||
|
#
|
||||||
from modules import devices
|
# from modules import devices
|
||||||
from modules.sd_hijack_utils import CondFunc
|
# from modules.sd_hijack_utils import CondFunc
|
||||||
|
#
|
||||||
|
#
|
||||||
class TorchHijackForUnet:
|
# class TorchHijackForUnet:
|
||||||
"""
|
# """
|
||||||
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||||
this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
def __getattr__(self, item):
|
# def __getattr__(self, item):
|
||||||
if item == 'cat':
|
# if item == 'cat':
|
||||||
return self.cat
|
# return self.cat
|
||||||
|
#
|
||||||
if hasattr(torch, item):
|
# if hasattr(torch, item):
|
||||||
return getattr(torch, item)
|
# return getattr(torch, item)
|
||||||
|
#
|
||||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||||
|
#
|
||||||
def cat(self, tensors, *args, **kwargs):
|
# def cat(self, tensors, *args, **kwargs):
|
||||||
if len(tensors) == 2:
|
# if len(tensors) == 2:
|
||||||
a, b = tensors
|
# a, b = tensors
|
||||||
if a.shape[-2:] != b.shape[-2:]:
|
# if a.shape[-2:] != b.shape[-2:]:
|
||||||
a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
||||||
|
#
|
||||||
tensors = (a, b)
|
# tensors = (a, b)
|
||||||
|
#
|
||||||
return torch.cat(tensors, *args, **kwargs)
|
# return torch.cat(tensors, *args, **kwargs)
|
||||||
|
#
|
||||||
|
#
|
||||||
th = TorchHijackForUnet()
|
# th = TorchHijackForUnet()
|
||||||
|
#
|
||||||
|
#
|
||||||
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||||
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||||
"""Always make sure inputs to unet are in correct dtype."""
|
# """Always make sure inputs to unet are in correct dtype."""
|
||||||
if isinstance(cond, dict):
|
# if isinstance(cond, dict):
|
||||||
for y in cond.keys():
|
# for y in cond.keys():
|
||||||
if isinstance(cond[y], list):
|
# if isinstance(cond[y], list):
|
||||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||||
else:
|
# else:
|
||||||
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||||
|
#
|
||||||
with devices.autocast():
|
# with devices.autocast():
|
||||||
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
||||||
if devices.unet_needs_upcast:
|
# if devices.unet_needs_upcast:
|
||||||
return result.float()
|
# return result.float()
|
||||||
else:
|
# else:
|
||||||
return result
|
# return result
|
||||||
|
#
|
||||||
|
#
|
||||||
# Monkey patch to create timestep embed tensor on device, avoiding a block.
|
# # Monkey patch to create timestep embed tensor on device, avoiding a block.
|
||||||
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
||||||
"""
|
# """
|
||||||
Create sinusoidal timestep embeddings.
|
# Create sinusoidal timestep embeddings.
|
||||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||||
These may be fractional.
|
# These may be fractional.
|
||||||
:param dim: the dimension of the output.
|
# :param dim: the dimension of the output.
|
||||||
:param max_period: controls the minimum frequency of the embeddings.
|
# :param max_period: controls the minimum frequency of the embeddings.
|
||||||
:return: an [N x dim] Tensor of positional embeddings.
|
# :return: an [N x dim] Tensor of positional embeddings.
|
||||||
"""
|
# """
|
||||||
if not repeat_only:
|
# if not repeat_only:
|
||||||
half = dim // 2
|
# half = dim // 2
|
||||||
freqs = torch.exp(
|
# freqs = torch.exp(
|
||||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
||||||
)
|
# )
|
||||||
args = timesteps[:, None].float() * freqs[None]
|
# args = timesteps[:, None].float() * freqs[None]
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
if dim % 2:
|
# if dim % 2:
|
||||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
else:
|
# else:
|
||||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
# embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||||
return embedding
|
# return embedding
|
||||||
|
#
|
||||||
|
#
|
||||||
# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
||||||
# Prevents a lot of unnecessary aten::copy_ calls
|
# # Prevents a lot of unnecessary aten::copy_ calls
|
||||||
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
||||||
# note: if no context is given, cross-attention defaults to self-attention
|
# # note: if no context is given, cross-attention defaults to self-attention
|
||||||
if not isinstance(context, list):
|
# if not isinstance(context, list):
|
||||||
context = [context]
|
# context = [context]
|
||||||
b, c, h, w = x.shape
|
# b, c, h, w = x.shape
|
||||||
x_in = x
|
# x_in = x
|
||||||
x = self.norm(x)
|
# x = self.norm(x)
|
||||||
if not self.use_linear:
|
# if not self.use_linear:
|
||||||
x = self.proj_in(x)
|
# x = self.proj_in(x)
|
||||||
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||||
if self.use_linear:
|
# if self.use_linear:
|
||||||
x = self.proj_in(x)
|
# x = self.proj_in(x)
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
# for i, block in enumerate(self.transformer_blocks):
|
||||||
x = block(x, context=context[i])
|
# x = block(x, context=context[i])
|
||||||
if self.use_linear:
|
# if self.use_linear:
|
||||||
x = self.proj_out(x)
|
# x = self.proj_out(x)
|
||||||
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
# x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
||||||
if not self.use_linear:
|
# if not self.use_linear:
|
||||||
x = self.proj_out(x)
|
# x = self.proj_out(x)
|
||||||
return x + x_in
|
# return x + x_in
|
||||||
|
#
|
||||||
|
#
|
||||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
# class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||||
def __init__(self, *args, **kwargs):
|
# def __init__(self, *args, **kwargs):
|
||||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
# torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||||
def forward(self, x):
|
# def forward(self, x):
|
||||||
if devices.unet_needs_upcast:
|
# if devices.unet_needs_upcast:
|
||||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||||
else:
|
# else:
|
||||||
return torch.nn.GELU.forward(self, x)
|
# return torch.nn.GELU.forward(self, x)
|
||||||
|
#
|
||||||
|
#
|
||||||
ddpm_edit_hijack = None
|
# ddpm_edit_hijack = None
|
||||||
def hijack_ddpm_edit():
|
# def hijack_ddpm_edit():
|
||||||
global ddpm_edit_hijack
|
# global ddpm_edit_hijack
|
||||||
if not ddpm_edit_hijack:
|
# if not ddpm_edit_hijack:
|
||||||
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
||||||
|
#
|
||||||
|
#
|
||||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
||||||
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||||
|
#
|
||||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||||
|
#
|
||||||
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||||
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||||
|
#
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
||||||
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
||||||
|
#
|
||||||
|
#
|
||||||
def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
||||||
if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
# if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
||||||
dtype = torch.float32
|
# dtype = torch.float32
|
||||||
else:
|
# else:
|
||||||
dtype = devices.dtype_unet
|
# dtype = devices.dtype_unet
|
||||||
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||||
|
#
|
||||||
|
#
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||||
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||||
|
|||||||
@@ -1,137 +1,137 @@
|
|||||||
import os
|
# import os
|
||||||
|
#
|
||||||
import torch
|
# import torch
|
||||||
|
#
|
||||||
from modules import shared, paths, sd_disable_initialization, devices
|
# from modules import shared, paths, sd_disable_initialization, devices
|
||||||
|
#
|
||||||
sd_configs_path = shared.sd_configs_path
|
# sd_configs_path = shared.sd_configs_path
|
||||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
# sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||||
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
# sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||||
|
#
|
||||||
|
#
|
||||||
config_default = shared.sd_default_config
|
# config_default = shared.sd_default_config
|
||||||
# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||||
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||||
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||||
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
# config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||||
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||||
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||||
|
#
|
||||||
|
#
|
||||||
def is_using_v_parameterization_for_sd2(state_dict):
|
# def is_using_v_parameterization_for_sd2(state_dict):
|
||||||
"""
|
# """
|
||||||
Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
# import ldm.modules.diffusionmodules.openaimodel
|
||||||
|
#
|
||||||
device = devices.device
|
# device = devices.device
|
||||||
|
#
|
||||||
with sd_disable_initialization.DisableInitialization():
|
# with sd_disable_initialization.DisableInitialization():
|
||||||
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||||
use_checkpoint=False,
|
# use_checkpoint=False,
|
||||||
use_fp16=False,
|
# use_fp16=False,
|
||||||
image_size=32,
|
# image_size=32,
|
||||||
in_channels=4,
|
# in_channels=4,
|
||||||
out_channels=4,
|
# out_channels=4,
|
||||||
model_channels=320,
|
# model_channels=320,
|
||||||
attention_resolutions=[4, 2, 1],
|
# attention_resolutions=[4, 2, 1],
|
||||||
num_res_blocks=2,
|
# num_res_blocks=2,
|
||||||
channel_mult=[1, 2, 4, 4],
|
# channel_mult=[1, 2, 4, 4],
|
||||||
num_head_channels=64,
|
# num_head_channels=64,
|
||||||
use_spatial_transformer=True,
|
# use_spatial_transformer=True,
|
||||||
use_linear_in_transformer=True,
|
# use_linear_in_transformer=True,
|
||||||
transformer_depth=1,
|
# transformer_depth=1,
|
||||||
context_dim=1024,
|
# context_dim=1024,
|
||||||
legacy=False
|
# legacy=False
|
||||||
)
|
# )
|
||||||
unet.eval()
|
# unet.eval()
|
||||||
|
#
|
||||||
with torch.no_grad():
|
# with torch.no_grad():
|
||||||
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||||
unet.load_state_dict(unet_sd, strict=True)
|
# unet.load_state_dict(unet_sd, strict=True)
|
||||||
unet.to(device=device, dtype=devices.dtype_unet)
|
# unet.to(device=device, dtype=devices.dtype_unet)
|
||||||
|
#
|
||||||
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||||
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||||
|
#
|
||||||
with devices.autocast():
|
# with devices.autocast():
|
||||||
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
||||||
|
#
|
||||||
return out < -1
|
# return out < -1
|
||||||
|
#
|
||||||
|
#
|
||||||
def guess_model_config_from_state_dict(sd, filename):
|
# def guess_model_config_from_state_dict(sd, filename):
|
||||||
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||||
|
#
|
||||||
if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
# if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
||||||
return config_sd3
|
# return config_sd3
|
||||||
|
#
|
||||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
# if diffusion_model_input.shape[1] == 9:
|
||||||
return config_sdxl_inpainting
|
# return config_sdxl_inpainting
|
||||||
else:
|
# else:
|
||||||
return config_sdxl
|
# return config_sdxl
|
||||||
|
#
|
||||||
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||||
return config_sdxl_refiner
|
# return config_sdxl_refiner
|
||||||
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||||
return config_depth_model
|
# return config_depth_model
|
||||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||||
return config_unclip
|
# return config_unclip
|
||||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
||||||
return config_unopenclip
|
# return config_unopenclip
|
||||||
|
#
|
||||||
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
# if diffusion_model_input.shape[1] == 9:
|
||||||
return config_sd2_inpainting
|
# return config_sd2_inpainting
|
||||||
# elif is_using_v_parameterization_for_sd2(sd):
|
# # elif is_using_v_parameterization_for_sd2(sd):
|
||||||
# return config_sd2v
|
# # return config_sd2v
|
||||||
else:
|
# else:
|
||||||
return config_sd2v
|
# return config_sd2v
|
||||||
|
#
|
||||||
if diffusion_model_input is not None:
|
# if diffusion_model_input is not None:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
# if diffusion_model_input.shape[1] == 9:
|
||||||
return config_inpainting
|
# return config_inpainting
|
||||||
if diffusion_model_input.shape[1] == 8:
|
# if diffusion_model_input.shape[1] == 8:
|
||||||
return config_instruct_pix2pix
|
# return config_instruct_pix2pix
|
||||||
|
#
|
||||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
# if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||||
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||||
return config_alt_diffusion_m18
|
# return config_alt_diffusion_m18
|
||||||
return config_alt_diffusion
|
# return config_alt_diffusion
|
||||||
|
#
|
||||||
return config_default
|
# return config_default
|
||||||
|
#
|
||||||
|
#
|
||||||
def find_checkpoint_config(state_dict, info):
|
# def find_checkpoint_config(state_dict, info):
|
||||||
if info is None:
|
# if info is None:
|
||||||
return guess_model_config_from_state_dict(state_dict, "")
|
# return guess_model_config_from_state_dict(state_dict, "")
|
||||||
|
#
|
||||||
config = find_checkpoint_config_near_filename(info)
|
# config = find_checkpoint_config_near_filename(info)
|
||||||
if config is not None:
|
# if config is not None:
|
||||||
return config
|
# return config
|
||||||
|
#
|
||||||
return guess_model_config_from_state_dict(state_dict, info.filename)
|
# return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||||
|
#
|
||||||
|
#
|
||||||
def find_checkpoint_config_near_filename(info):
|
# def find_checkpoint_config_near_filename(info):
|
||||||
if info is None:
|
# if info is None:
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
# config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||||
if os.path.exists(config):
|
# if os.path.exists(config):
|
||||||
return config
|
# return config
|
||||||
|
#
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
|
|||||||
@@ -1,115 +1,115 @@
|
|||||||
from __future__ import annotations
|
# from __future__ import annotations
|
||||||
|
#
|
||||||
import torch
|
# import torch
|
||||||
|
#
|
||||||
import sgm.models.diffusion
|
# import sgm.models.diffusion
|
||||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
# import sgm.modules.diffusionmodules.denoiser_scaling
|
||||||
import sgm.modules.diffusionmodules.discretizer
|
# import sgm.modules.diffusionmodules.discretizer
|
||||||
from modules import devices, shared, prompt_parser
|
# from modules import devices, shared, prompt_parser
|
||||||
from modules import torch_utils
|
# from modules import torch_utils
|
||||||
|
#
|
||||||
from backend import memory_management
|
# from backend import memory_management
|
||||||
|
#
|
||||||
|
#
|
||||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||||
|
#
|
||||||
for embedder in self.conditioner.embedders:
|
# for embedder in self.conditioner.embedders:
|
||||||
embedder.ucg_rate = 0.0
|
# embedder.ucg_rate = 0.0
|
||||||
|
#
|
||||||
width = getattr(batch, 'width', 1024) or 1024
|
# width = getattr(batch, 'width', 1024) or 1024
|
||||||
height = getattr(batch, 'height', 1024) or 1024
|
# height = getattr(batch, 'height', 1024) or 1024
|
||||||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
# is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
# aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||||
|
#
|
||||||
devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
|
# devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
|
||||||
|
#
|
||||||
sdxl_conds = {
|
# sdxl_conds = {
|
||||||
"txt": batch,
|
# "txt": batch,
|
||||||
"original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
# "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||||
"crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
# "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
||||||
"target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
# "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||||
"aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
# "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
||||||
}
|
# }
|
||||||
|
#
|
||||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
# force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
||||||
c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
# c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
||||||
|
#
|
||||||
return c
|
# return c
|
||||||
|
#
|
||||||
|
#
|
||||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
|
# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
|
||||||
if self.model.diffusion_model.in_channels == 9:
|
# if self.model.diffusion_model.in_channels == 9:
|
||||||
x = torch.cat([x] + cond['c_concat'], dim=1)
|
# x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||||
|
#
|
||||||
return self.model(x, t, cond, *args, **kwargs)
|
# return self.model(x, t, cond, *args, **kwargs)
|
||||||
|
#
|
||||||
|
#
|
||||||
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
# def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||||
return x
|
# return x
|
||||||
|
#
|
||||||
|
#
|
||||||
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||||
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
||||||
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
||||||
|
#
|
||||||
|
#
|
||||||
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
||||||
res = []
|
# res = []
|
||||||
|
#
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
||||||
encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
# encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
||||||
res.append(encoded)
|
# res.append(encoded)
|
||||||
|
#
|
||||||
return torch.cat(res, dim=1)
|
# return torch.cat(res, dim=1)
|
||||||
|
#
|
||||||
|
#
|
||||||
def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
# def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
||||||
return embedder.tokenize(texts)
|
# return embedder.tokenize(texts)
|
||||||
|
#
|
||||||
raise AssertionError('no tokenizer available')
|
# raise AssertionError('no tokenizer available')
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
def process_texts(self, texts):
|
# def process_texts(self, texts):
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||||
return embedder.process_texts(texts)
|
# return embedder.process_texts(texts)
|
||||||
|
#
|
||||||
|
#
|
||||||
def get_target_prompt_token_count(self, token_count):
|
# def get_target_prompt_token_count(self, token_count):
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
||||||
return embedder.get_target_prompt_token_count(token_count)
|
# return embedder.get_target_prompt_token_count(token_count)
|
||||||
|
#
|
||||||
|
#
|
||||||
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
||||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||||
sgm.modules.GeneralConditioner.tokenize = tokenize
|
# sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
# sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||||
|
#
|
||||||
|
#
|
||||||
def extend_sdxl(model):
|
# def extend_sdxl(model):
|
||||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
# """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||||
|
#
|
||||||
dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
# dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
||||||
model.model.diffusion_model.dtype = dtype
|
# model.model.diffusion_model.dtype = dtype
|
||||||
model.model.conditioning_key = 'crossattn'
|
# model.model.conditioning_key = 'crossattn'
|
||||||
model.cond_stage_key = 'txt'
|
# model.cond_stage_key = 'txt'
|
||||||
# model.cond_stage_model will be set in sd_hijack
|
# # model.cond_stage_model will be set in sd_hijack
|
||||||
|
#
|
||||||
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
# model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||||
|
#
|
||||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
# discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
# model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
||||||
|
#
|
||||||
model.conditioner.wrapped = torch.nn.Module()
|
# model.conditioner.wrapped = torch.nn.Module()
|
||||||
|
#
|
||||||
|
#
|
||||||
sgm.modules.attention.print = shared.ldm_print
|
# sgm.modules.attention.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
# sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||||
sgm.modules.encoders.modules.print = shared.ldm_print
|
# sgm.modules.encoders.modules.print = shared.ldm_print
|
||||||
|
#
|
||||||
# this gets the code to load the vanilla attention that we override
|
# # this gets the code to load the vanilla attention that we override
|
||||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
# sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||||
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
||||||
|
|||||||
@@ -35,9 +35,7 @@ def refresh_vae_list():
|
|||||||
|
|
||||||
|
|
||||||
def cross_attention_optimizations():
|
def cross_attention_optimizations():
|
||||||
import modules.sd_hijack
|
return ["Automatic"]
|
||||||
|
|
||||||
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
|
||||||
|
|
||||||
|
|
||||||
def sd_unet_items():
|
def sd_unet_items():
|
||||||
|
|||||||
Reference in New Issue
Block a user