mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 17:09:49 +00:00
Free WebUI from its Prison
Congratulations WebUI. Say Hello to freedom.
This commit is contained in:
@@ -24,7 +24,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
|
||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import Any
|
||||
@@ -725,7 +724,7 @@ class Api:
|
||||
|
||||
def get_sd_models(self):
|
||||
import modules.sd_models as sd_models
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()]
|
||||
|
||||
def get_sd_vaes(self):
|
||||
import modules.sd_vae as sd_vae
|
||||
|
||||
@@ -9,7 +9,7 @@ import modules.textual_inversion.dataset
|
||||
import torch
|
||||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from backend.nn.unet import default
|
||||
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||
from modules.textual_inversion import textual_inversion, saving_settings
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
@@ -1,25 +1,12 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import os
|
||||
|
||||
from threading import Thread
|
||||
|
||||
from modules.timer import startup_timer
|
||||
|
||||
|
||||
class HiddenPrints:
    """Context manager that silences ``print`` output.

    On entry the current ``sys.stdout`` is remembered and replaced with a
    writable handle to ``os.devnull``. On exit that handle is closed and the
    remembered stream is put back.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference to the sink so __exit__ closes exactly this
        # handle, even if code inside the block rebinds sys.stdout again.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._devnull.close()
        finally:
            # Always restore the caller's stream, even if closing the sink fails.
            sys.stdout = self._original_stdout
|
||||
|
||||
|
||||
def imports():
|
||||
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
@@ -35,16 +22,8 @@ def imports():
|
||||
import gradio # noqa: F401
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
with HiddenPrints():
|
||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||
startup_timer.record("setup paths")
|
||||
|
||||
import ldm.modules.encoders.modules # noqa: F401
|
||||
import ldm.modules.diffusionmodules.model
|
||||
startup_timer.record("import ldm")
|
||||
|
||||
import sgm.modules.encoders.modules # noqa: F401
|
||||
startup_timer.record("import sgm")
|
||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||
startup_timer.record("setup paths")
|
||||
|
||||
from modules import shared_init
|
||||
shared_init.initialize()
|
||||
@@ -137,15 +116,6 @@ def initialize_rest(*, reload_script_modules=False):
|
||||
sd_vae.refresh_vae_list()
|
||||
startup_timer.record("refresh VAE")
|
||||
|
||||
from modules import textual_inversion
|
||||
textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||
startup_timer.record("refresh textual inversion templates")
|
||||
|
||||
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
|
||||
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
|
||||
sd_hijack.list_optimizers()
|
||||
startup_timer.record("scripts list_optimizers")
|
||||
|
||||
from modules import sd_unet
|
||||
sd_unet.list_unets()
|
||||
startup_timer.record("scripts list_unets")
|
||||
|
||||
@@ -391,15 +391,15 @@ def prepare_environment():
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
# stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
# stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
@@ -456,8 +456,8 @@ def prepare_environment():
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
# git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
# git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
@@ -2,45 +2,15 @@ import os
|
||||
import sys
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
|
||||
|
||||
import modules.safe # noqa: F401
|
||||
|
||||
|
||||
def mute_sdxl_imports():
    """Register placeholder modules for imports SDXL performs but we never use.

    Each placeholder is a bare object carrying a single attribute set to None,
    stored in ``sys.modules`` under the module name being faked.
    """

    class Dummy:
        pass

    placeholders = (
        ('taming.modules.losses.lpips', 'LPIPS'),
        ('sgm.data', 'StableDataModuleFromConfig'),
    )

    for module_name, attribute in placeholders:
        stub = Dummy()
        setattr(stub, attribute, None)
        sys.modules[module_name] = stub
|
||||
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
# search for directory of stable diffusion in following places
|
||||
sd_path = None
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
|
||||
for possible_sd_path in possible_sd_paths:
|
||||
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
||||
sd_path = os.path.abspath(possible_sd_path)
|
||||
break
|
||||
|
||||
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
|
||||
|
||||
mute_sdxl_imports()
|
||||
sd_path = os.path.dirname(__file__)
|
||||
|
||||
path_dirs = [
|
||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||
(os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
|
||||
(os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []),
|
||||
(os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||
(os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
|
||||
]
|
||||
|
||||
paths = {}
|
||||
@@ -53,13 +23,6 @@ for d, must_exist, what, options in path_dirs:
|
||||
d = os.path.abspath(d)
|
||||
if "atstart" in options:
|
||||
sys.path.insert(0, d)
|
||||
elif "sgm" in options:
|
||||
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
|
||||
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.
|
||||
|
||||
sys.path.insert(0, d)
|
||||
import sgm # noqa: F401
|
||||
sys.path.pop(0)
|
||||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
|
||||
@@ -28,8 +28,6 @@ import modules.images as images
|
||||
import modules.styles
|
||||
import modules.sd_models as sd_models
|
||||
import modules.sd_vae as sd_vae
|
||||
from ldm.data.util import AddMiDaS
|
||||
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
||||
|
||||
from einops import repeat, rearrange
|
||||
from blendmodes.blend import blendLayers, BlendType
|
||||
@@ -295,23 +293,7 @@ class StableDiffusionProcessing:
|
||||
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
||||
|
||||
def depth2img_image_conditioning(self, source_image):
|
||||
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
||||
transformer = AddMiDaS(model_type="dpt_hybrid")
|
||||
transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
|
||||
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
||||
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
||||
|
||||
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
conditioning = torch.nn.functional.interpolate(
|
||||
self.sd_model.depth_model(midas_in),
|
||||
size=conditioning_image.shape[2:],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
(depth_min, depth_max) = torch.aminmax(conditioning)
|
||||
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
||||
return conditioning
|
||||
raise NotImplementedError('NotImplementedError: depth2img_image_conditioning')
|
||||
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
|
||||
@@ -368,11 +350,6 @@ class StableDiffusionProcessing:
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
|
||||
source_image = devices.cond_cast_float(source_image)
|
||||
|
||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
||||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||
return self.depth2img_image_conditioning(source_image)
|
||||
|
||||
if self.sd_model.cond_stage_key == "edit":
|
||||
return self.edit_image_conditioning(source_image)
|
||||
|
||||
|
||||
390
modules/safe.py
390
modules/safe.py
@@ -1,195 +1,195 @@
|
||||
# this code is adapted from the script contributed by anon from /h/
|
||||
|
||||
import pickle
|
||||
import collections
|
||||
|
||||
import torch
|
||||
import numpy
|
||||
import _codecs
|
||||
import zipfile
|
||||
import re
|
||||
|
||||
|
||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||
from modules import errors
|
||||
|
||||
# Resolve the storage class once at import time: use torch.storage.TypedStorage when this
# PyTorch build exposes it, otherwise fall back to the older private name _TypedStorage.
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||
|
||||
def encode(*args):
    """Forward all positional arguments to ``_codecs.encode`` and return its result."""
    return _codecs.encode(*args)
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves an explicit allow-list of globals.

    Anything not matched by ``extra_handler`` or by the checks in
    ``find_class`` is rejected with an exception, so a pickle that references
    other callables fails to load.
    """

    # Optional callable ``(module, name) -> object | None``; consulted first by find_class.
    extra_handler = None

    def persistent_load(self, saved_id):
        """Return an empty TypedStorage placeholder for a persistent 'storage' id.

        Raises:
            AssertionError: if the persistent id is not a 'storage' record.
        """
        # Raised explicitly rather than via `assert`: asserts are removed when Python
        # runs with -O, which would silently disable this check.
        if saved_id[0] != 'storage':
            raise AssertionError(f"unexpected persistent id type: {saved_id[0]!r}")

        try:
            return TypedStorage(_internal=True)
        except TypeError:
            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument

    def find_class(self, module, name):
        """Resolve ``module``/``name`` to an object if it is allowed.

        The instance's ``extra_handler`` (when set) gets the first chance; a
        non-None result from it is returned as is. Otherwise only the globals
        listed below are accepted.

        Raises:
            Exception: for every global that is not on the allow-list.
        """
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        # pytorch_lightning is imported lazily, only when a checkpoint actually references it.
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")
|
||||
|
||||
|
||||
# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl',
# 'dirname/.data/serialization_id', and 'dirname/data/<number>'.
# The dot in '.data' is escaped so it matches only a literal '.', not any character.
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|\.data/serialization_id|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


def check_zip_filenames(filename, names):
    """Verify that every archive member name matches ``allowed_zip_names_re``.

    Args:
        filename: name of the archive, used only in the error message.
        names: iterable of member names found inside the archive.

    Raises:
        Exception: on the first member name that is not allowed.
    """
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")
|
||||
|
||||
|
||||
def check_pt(filename, extra_handler):
    """Dry-run unpickle ``filename`` through RestrictedUnpickler.

    A zip archive has its member names checked and its single
    '<directory name>/data.pkl' loaded with the restricted unpickler. If the
    file is not a zip archive, it is opened as a raw pickle stream and five
    objects are loaded from it. Nothing is returned; a problem surfaces as an
    exception.
    """
    try:
        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as archive:
            check_zip_filenames(filename, archive.namelist())

            # locate '<directory name>/data.pkl'; exactly one must be present
            pickle_members = [member for member in archive.namelist() if data_pkl_re.match(member)]
            if not pickle_members:
                raise Exception(f"data.pkl not found in {filename}")
            if len(pickle_members) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")

            with archive.open(pickle_members[0]) as stream:
                checker = RestrictedUnpickler(stream)
                checker.extra_handler = extra_handler
                checker.load()

    except zipfile.BadZipfile:
        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as stream:
            checker = RestrictedUnpickler(stream)
            checker.extra_handler = extra_handler
            for _ in range(5):
                checker.load()
|
||||
|
||||
|
||||
def load(filename, *args, **kwargs):
    """Load ``filename`` via load_with_extra, using the current ``global_extra_handler``.

    Extra positional and keyword arguments are forwarded unchanged.
    """
    # load_with_extra's second positional parameter is extra_handler, so the handler is passed
    # in that slot. Passing it by keyword after *args raised TypeError ("multiple values for
    # argument 'extra_handler'") as soon as a caller supplied any positional argument.
    return load_with_extra(filename, global_extra_handler, *args, **kwargs)
|
||||
|
||||
|
||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    Verify a pickled model file and, if verification passes, load it.

    Meant for extensions whose models contain classes the default restricted
    unpickler rejects. ``extra_handler`` is a function that receives the module
    and field name as text and returns that field's value (or None to decline):

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    Verification is skipped when ``shared.cmd_opts.disable_safe_unpickle`` is set.
    When verification raises, the error is reported and None is returned;
    otherwise the result of ``unsafe_torch_load(filename, *args, **kwargs)`` is returned.

    The alternative is to call safe.unsafe_torch_load('model.pt') directly, which as
    the name implies is definitely unsafe.
    """

    from modules import shared

    try:
        verification_enabled = not shared.cmd_opts.disable_safe_unpickle
        if verification_enabled:
            check_pt(filename, extra_handler)
    except pickle.UnpicklingError:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            "-----> !!!! The file is most likely corrupted !!!! <-----\n"
            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
            exc_info=True,
        )
        return None
    except Exception:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            f"The file may be malicious, so the program is not going to read it.\n"
            f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
            exc_info=True,
        )
        return None
    else:
        return unsafe_torch_load(filename, *args, **kwargs)
|
||||
|
||||
|
||||
class Extra:
    """
    Context manager that installs a handler as the module-wide ``global_extra_handler``
    for the duration of a ``with`` block, for cases where the torch.load call is made by
    code you do not control and load_with_extra cannot be called directly:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```

    Blocks cannot be nested: entering while a handler is already installed fails the
    assertion. Leaving the block resets the global handler to None.
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        # only one Extra block may be active at a time
        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None
|
||||
|
||||
|
||||
# Reference to torch.load under an explicit name; load_with_extra calls it once verification is done.
unsafe_torch_load = torch.load
|
||||
# Handler that load() passes to load_with_extra; set by Extra.__enter__ and reset to None by Extra.__exit__.
global_extra_handler = None
|
||||
# # this code is adapted from the script contributed by anon from /h/
|
||||
#
|
||||
# import pickle
|
||||
# import collections
|
||||
#
|
||||
# import torch
|
||||
# import numpy
|
||||
# import _codecs
|
||||
# import zipfile
|
||||
# import re
|
||||
#
|
||||
#
|
||||
# # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||
# from modules import errors
|
||||
#
|
||||
# TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||
#
|
||||
# def encode(*args):
|
||||
# out = _codecs.encode(*args)
|
||||
# return out
|
||||
#
|
||||
#
|
||||
# class RestrictedUnpickler(pickle.Unpickler):
|
||||
# extra_handler = None
|
||||
#
|
||||
# def persistent_load(self, saved_id):
|
||||
# assert saved_id[0] == 'storage'
|
||||
#
|
||||
# try:
|
||||
# return TypedStorage(_internal=True)
|
||||
# except TypeError:
|
||||
# return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
|
||||
#
|
||||
# def find_class(self, module, name):
|
||||
# if self.extra_handler is not None:
|
||||
# res = self.extra_handler(module, name)
|
||||
# if res is not None:
|
||||
# return res
|
||||
#
|
||||
# if module == 'collections' and name == 'OrderedDict':
|
||||
# return getattr(collections, name)
|
||||
# if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
||||
# return getattr(torch._utils, name)
|
||||
# if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
|
||||
# return getattr(torch, name)
|
||||
# if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||
# return getattr(torch.nn.modules.container, name)
|
||||
# if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
|
||||
# return getattr(numpy.core.multiarray, name)
|
||||
# if module == 'numpy' and name in ['dtype', 'ndarray']:
|
||||
# return getattr(numpy, name)
|
||||
# if module == '_codecs' and name == 'encode':
|
||||
# return encode
|
||||
# if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
|
||||
# import pytorch_lightning.callbacks
|
||||
# return pytorch_lightning.callbacks.model_checkpoint
|
||||
# if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
|
||||
# import pytorch_lightning.callbacks.model_checkpoint
|
||||
# return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
|
||||
# if module == "__builtin__" and name == 'set':
|
||||
# return set
|
||||
#
|
||||
# # Forbid everything else.
|
||||
# raise Exception(f"global '{module}/{name}' is forbidden")
|
||||
#
|
||||
#
|
||||
# # Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
|
||||
# allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
|
||||
# data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
||||
#
|
||||
# def check_zip_filenames(filename, names):
|
||||
# for name in names:
|
||||
# if allowed_zip_names_re.match(name):
|
||||
# continue
|
||||
#
|
||||
# raise Exception(f"bad file inside {filename}: {name}")
|
||||
#
|
||||
#
|
||||
# def check_pt(filename, extra_handler):
|
||||
# try:
|
||||
#
|
||||
# # new pytorch format is a zip file
|
||||
# with zipfile.ZipFile(filename) as z:
|
||||
# check_zip_filenames(filename, z.namelist())
|
||||
#
|
||||
# # find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||
# data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||
# if len(data_pkl_filenames) == 0:
|
||||
# raise Exception(f"data.pkl not found in {filename}")
|
||||
# if len(data_pkl_filenames) > 1:
|
||||
# raise Exception(f"Multiple data.pkl found in {filename}")
|
||||
# with z.open(data_pkl_filenames[0]) as file:
|
||||
# unpickler = RestrictedUnpickler(file)
|
||||
# unpickler.extra_handler = extra_handler
|
||||
# unpickler.load()
|
||||
#
|
||||
# except zipfile.BadZipfile:
|
||||
#
|
||||
# # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
|
||||
# with open(filename, "rb") as file:
|
||||
# unpickler = RestrictedUnpickler(file)
|
||||
# unpickler.extra_handler = extra_handler
|
||||
# for _ in range(5):
|
||||
# unpickler.load()
|
||||
#
|
||||
#
|
||||
# def load(filename, *args, **kwargs):
|
||||
# return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
||||
#
|
||||
#
|
||||
# def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||
# """
|
||||
# this function is intended to be used by extensions that want to load models with
|
||||
# some extra classes in them that the usual unpickler would find suspicious.
|
||||
#
|
||||
# Use the extra_handler argument to specify a function that takes module and field name as text,
|
||||
# and returns that field's value:
|
||||
#
|
||||
# ```python
|
||||
# def extra(module, name):
|
||||
# if module == 'collections' and name == 'OrderedDict':
|
||||
# return collections.OrderedDict
|
||||
#
|
||||
# return None
|
||||
#
|
||||
# safe.load_with_extra('model.pt', extra_handler=extra)
|
||||
# ```
|
||||
#
|
||||
# The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
|
||||
# definitely unsafe.
|
||||
# """
|
||||
#
|
||||
# from modules import shared
|
||||
#
|
||||
# try:
|
||||
# if not shared.cmd_opts.disable_safe_unpickle:
|
||||
# check_pt(filename, extra_handler)
|
||||
#
|
||||
# except pickle.UnpicklingError:
|
||||
# errors.report(
|
||||
# f"Error verifying pickled file from {filename}\n"
|
||||
# "-----> !!!! The file is most likely corrupted !!!! <-----\n"
|
||||
# "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
|
||||
# exc_info=True,
|
||||
# )
|
||||
# return None
|
||||
# except Exception:
|
||||
# errors.report(
|
||||
# f"Error verifying pickled file from {filename}\n"
|
||||
# f"The file may be malicious, so the program is not going to read it.\n"
|
||||
# f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
|
||||
# exc_info=True,
|
||||
# )
|
||||
# return None
|
||||
#
|
||||
# return unsafe_torch_load(filename, *args, **kwargs)
|
||||
#
|
||||
#
|
||||
# class Extra:
|
||||
# """
|
||||
# A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
|
||||
# (because it's not your code making the torch.load call). The intended use is like this:
|
||||
#
|
||||
# ```
|
||||
# import torch
|
||||
# from modules import safe
|
||||
#
|
||||
# def handler(module, name):
|
||||
# if module == 'torch' and name in ['float64', 'float16']:
|
||||
# return getattr(torch, name)
|
||||
#
|
||||
# return None
|
||||
#
|
||||
# with safe.Extra(handler):
|
||||
# x = torch.load('model.pt')
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __init__(self, handler):
|
||||
# self.handler = handler
|
||||
#
|
||||
# def __enter__(self):
|
||||
# global global_extra_handler
|
||||
#
|
||||
# assert global_extra_handler is None, 'already inside an Extra() block'
|
||||
# global_extra_handler = self.handler
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# global global_extra_handler
|
||||
#
|
||||
# global_extra_handler = None
|
||||
#
|
||||
#
|
||||
# unsafe_torch_load = torch.load
|
||||
# global_extra_handler = None
|
||||
|
||||
@@ -1,232 +1,232 @@
|
||||
import ldm.modules.encoders.modules
|
||||
import open_clip
|
||||
import torch
|
||||
import transformers.utils.hub
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
class ReplaceHelper:
    """Records attribute replacements so they can all be undone later."""

    def __init__(self):
        # (obj, field, original value) tuples, in the order the replacements were made
        self.replaced = []

    def replace(self, obj, field, func):
        """Set ``obj.field`` to ``func`` and remember the previous value.

        Returns:
            The previous value, or None when ``obj`` has no such attribute (or it
            is None); in that case nothing is replaced or recorded.
        """
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def restore(self):
        """Undo every recorded replacement and clear the record."""
        # Undo newest-first: when the same field was replaced more than once, replaying
        # oldest-first would finish by re-installing an intermediate replacement.
        for obj, field, original in reversed(self.replaced):
            setattr(obj, field, original)

        self.replaced.clear()
|
||||
|
||||
|
||||
class DisableInitialization(ReplaceHelper):
    """
    When an object of this class enters a `with` block, it starts:
    - preventing torch's layer initialization functions from working
    - changes CLIP and OpenCLIP to not download model weights
    - changes CLIP to not make requests to check if there is a new version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```

    Args:
        disable_clip: when false, only the torch initialization functions are
            replaced and the CLIP/OpenCLIP/transformers hooks are left alone.
    """

    def __init__(self, disable_clip=True):
        super().__init__()
        self.disable_clip = disable_clip

    # replace() is inherited from ReplaceHelper; the copy that used to live here was identical.

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            # the caller's `pretrained` value is deliberately discarded
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
            res.name_or_path = pretrained_model_name_or_path
            return res

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making request
            is_direct_url = url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json'
            # `args` may be empty, so check it before indexing instead of raising IndexError
            is_repo_lookup = url == 'openai/clip-vit-large-patch14' and bool(args) and args[0] == 'added_tokens.json'
            if is_direct_url or is_repo_lookup:
                return None

            # try the local cache first; fall back to a normal lookup if that yields nothing or fails
            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        # The wrappers below name `local_files_only` explicitly so the caller's value is dropped
        # from **kwargs; the helper above supplies its own value for that argument.
        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
|
||||
|
||||
|
||||
class InitializeOnMeta(ReplaceHelper):
    """
    Context manager that forces linear/conv2d/mha layers to be constructed with device="meta",
    so their parameters carry no values and take no memory. While it is active, model.to() is
    replaced with a function that does nothing and returns None; the model has to be repaired
    afterwards by loading its parameters through LoadStateDictOnMeta.

    Usage:
    ```
    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)
    ```
    """

    def __enter__(self):
        # the command line flag turns the whole optimization off; nothing gets patched in that case
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        def construct_on_meta(layer_cls):
            """Patch layer_cls.__init__ so that every instance is built with device="meta"."""

            def patched_init(*args, **kwargs):
                # overrides any device the caller asked for
                kwargs["device"] = "meta"
                return original_init(*args, **kwargs)

            original_init = self.replace(layer_cls, '__init__', patched_init)

        for layer_cls in (torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention):
            construct_on_meta(layer_cls)

        def disabled_to(*args, **kwargs):
            return None

        self.replace(torch.nn.Module, 'to', disabled_to)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # put back every attribute replaced in __enter__
        self.restore()
|
||||
|
||||
|
||||
class LoadStateDictOnMeta(ReplaceHelper):
    """
    Context manager that allows reading parameters from state_dict into a model that has some of its parameters on the meta device.
    As those parameters are read from state_dict, they are deleted from it, so by the end state_dict will be mostly empty, to save memory.
    Meant to be used together with InitializeOnMeta above.

    Arguments:
        state_dict: dict with the real weights; entries are popped from it as they are consumed.
        device: device on which replacements for meta parameters are allocated.
        weight_dtype_conversion: optional dict mapping the first dotted term of a key to the dtype
            its weights are converted to; the '' entry, if present, is the default dtype.

    Usage:
    ```
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=device):
        model.load_state_dict(state_dict, strict=False)
    ```
    """

    def __init__(self, state_dict, device, weight_dtype_conversion=None):
        super().__init__()
        self.state_dict = state_dict
        self.device = device
        self.weight_dtype_conversion = weight_dtype_conversion or {}
        self.default_dtype = self.weight_dtype_conversion.get('')

    def get_weight_dtype(self, key):
        """Return the dtype configured for key's first dotted term, or the default dtype (may be None)."""
        # partition() rather than unpacking split('.', 1): a key without any dot must not raise ValueError
        key_first_term = key.partition('.')[0]
        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

    def __enter__(self):
        # the command line flag turns the whole optimization off; nothing gets patched in that case
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
            """Move this module's weights from sd into state_dict, materialize meta params, then call original."""
            used_param_keys = []

            for name, param in module._parameters.items():
                if param is None:
                    continue

                key = prefix + name
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                    used_param_keys.append(key)

                if param.is_meta:
                    # a meta parameter has no storage to copy into, so swap in a zero-filled one on the target device
                    dtype = sd_param.dtype if sd_param is not None else param.dtype
                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

            for name in module._buffers:
                key = prefix + name

                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param
                    used_param_keys.append(key)

            original(module, state_dict, prefix, *args, **kwargs)

            # drop the real weights again so that no copy of the dict keeps them alive
            for key in used_param_keys:
                state_dict.pop(key, None)

        def load_state_dict(original, module, state_dict, strict=True, **kwargs):
            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.

            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).

            The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
            the function and does not call the original) the state dict will just fail to load because weights
            would be on the meta device.

            Extra keyword arguments are forwarded to the original load_state_dict, and its result is returned.
            """

            if state_dict is sd:
                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

            # hand back the original's result so callers that inspect it keep working inside the context
            return original(module, state_dict, strict=strict, **kwargs)

        def patch_load_from_state_dict(layer_cls):
            """Route layer_cls._load_from_state_dict through load_from_state_dict above."""

            def patched(*args, **kwargs):
                return load_from_state_dict(original, *args, **kwargs)

            original = self.replace(layer_cls, '_load_from_state_dict', patched)

        def patched_load_state_dict(*args, **kwargs):
            return load_state_dict(module_load_state_dict, *args, **kwargs)

        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', patched_load_state_dict)

        # same classes and same order as before: the base Module first, then the specific layer types
        for layer_cls in (
            torch.nn.Module,
            torch.nn.Linear,
            torch.nn.Conv2d,
            torch.nn.MultiheadAttention,
            torch.nn.LayerNorm,
            torch.nn.GroupNorm,
        ):
            patch_load_from_state_dict(layer_cls)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # put back every attribute replaced in __enter__
        self.restore()
|
||||
# import ldm.modules.encoders.modules
|
||||
# import open_clip
|
||||
# import torch
|
||||
# import transformers.utils.hub
|
||||
#
|
||||
# from modules import shared
|
||||
#
|
||||
#
|
||||
# class ReplaceHelper:
|
||||
# def __init__(self):
|
||||
# self.replaced = []
|
||||
#
|
||||
# def replace(self, obj, field, func):
|
||||
# original = getattr(obj, field, None)
|
||||
# if original is None:
|
||||
# return None
|
||||
#
|
||||
# self.replaced.append((obj, field, original))
|
||||
# setattr(obj, field, func)
|
||||
#
|
||||
# return original
|
||||
#
|
||||
# def restore(self):
|
||||
# for obj, field, original in self.replaced:
|
||||
# setattr(obj, field, original)
|
||||
#
|
||||
# self.replaced.clear()
|
||||
#
|
||||
#
|
||||
# class DisableInitialization(ReplaceHelper):
|
||||
# """
|
||||
# When an object of this class enters a `with` block, it starts:
|
||||
# - preventing torch's layer initialization functions from working
|
||||
# - changes CLIP and OpenCLIP to not download model weights
|
||||
# - changes CLIP to not make requests to check if there is a new version of a file you already have
|
||||
#
|
||||
# When it leaves the block, it reverts everything to how it was before.
|
||||
#
|
||||
# Use it like this:
|
||||
# ```
|
||||
# with DisableInitialization():
|
||||
# do_things()
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __init__(self, disable_clip=True):
|
||||
# super().__init__()
|
||||
# self.disable_clip = disable_clip
|
||||
#
|
||||
# def replace(self, obj, field, func):
|
||||
# original = getattr(obj, field, None)
|
||||
# if original is None:
|
||||
# return None
|
||||
#
|
||||
# self.replaced.append((obj, field, original))
|
||||
# setattr(obj, field, func)
|
||||
#
|
||||
# return original
|
||||
#
|
||||
# def __enter__(self):
|
||||
# def do_nothing(*args, **kwargs):
|
||||
# pass
|
||||
#
|
||||
# def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
|
||||
# return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
|
||||
#
|
||||
# def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
||||
# res.name_or_path = pretrained_model_name_or_path
|
||||
# return res
|
||||
#
|
||||
# def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
|
||||
# args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
|
||||
# return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
|
||||
#
|
||||
# def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
||||
#
|
||||
# # this file is always 404, prevent making request
|
||||
# if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
||||
# return None
|
||||
#
|
||||
# try:
|
||||
# res = original(url, *args, local_files_only=True, **kwargs)
|
||||
# if res is None:
|
||||
# res = original(url, *args, local_files_only=False, **kwargs)
|
||||
# return res
|
||||
# except Exception:
|
||||
# return original(url, *args, local_files_only=False, **kwargs)
|
||||
#
|
||||
# def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
|
||||
#
|
||||
# def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
|
||||
#
|
||||
# def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
||||
#
|
||||
# self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
# self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
# self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
#
|
||||
# if self.disable_clip:
|
||||
# self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
# self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
# self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
# self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
# self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
# self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
#
|
||||
#
|
||||
# class InitializeOnMeta(ReplaceHelper):
|
||||
# """
|
||||
# Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
||||
# which results in those parameters having no values and taking no memory. model.to() will be broken and
|
||||
# will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
||||
#
|
||||
# Usage:
|
||||
# ```
|
||||
# with sd_disable_initialization.InitializeOnMeta():
|
||||
# sd_model = instantiate_from_config(sd_config.model)
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __enter__(self):
|
||||
# if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
# return
|
||||
#
|
||||
# def set_device(x):
|
||||
# x["device"] = "meta"
|
||||
# return x
|
||||
#
|
||||
# linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
||||
# conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
||||
# mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
||||
# self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
#
|
||||
#
|
||||
# class LoadStateDictOnMeta(ReplaceHelper):
|
||||
# """
|
||||
# Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
|
||||
# As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
||||
# Meant to be used together with InitializeOnMeta above.
|
||||
#
|
||||
# Usage:
|
||||
# ```
|
||||
# with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
||||
# model.load_state_dict(state_dict, strict=False)
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __init__(self, state_dict, device, weight_dtype_conversion=None):
|
||||
# super().__init__()
|
||||
# self.state_dict = state_dict
|
||||
# self.device = device
|
||||
# self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||
# self.default_dtype = self.weight_dtype_conversion.get('')
|
||||
#
|
||||
# def get_weight_dtype(self, key):
|
||||
# key_first_term, _ = key.split('.', 1)
|
||||
# return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
|
||||
#
|
||||
# def __enter__(self):
|
||||
# if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
# return
|
||||
#
|
||||
# sd = self.state_dict
|
||||
# device = self.device
|
||||
#
|
||||
# def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
||||
# used_param_keys = []
|
||||
#
|
||||
# for name, param in module._parameters.items():
|
||||
# if param is None:
|
||||
# continue
|
||||
#
|
||||
# key = prefix + name
|
||||
# sd_param = sd.pop(key, None)
|
||||
# if sd_param is not None:
|
||||
# state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||
# used_param_keys.append(key)
|
||||
#
|
||||
# if param.is_meta:
|
||||
# dtype = sd_param.dtype if sd_param is not None else param.dtype
|
||||
# module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||
#
|
||||
# for name in module._buffers:
|
||||
# key = prefix + name
|
||||
#
|
||||
# sd_param = sd.pop(key, None)
|
||||
# if sd_param is not None:
|
||||
# state_dict[key] = sd_param
|
||||
# used_param_keys.append(key)
|
||||
#
|
||||
# original(module, state_dict, prefix, *args, **kwargs)
|
||||
#
|
||||
# for key in used_param_keys:
|
||||
# state_dict.pop(key, None)
|
||||
#
|
||||
# def load_state_dict(original, module, state_dict, strict=True):
|
||||
# """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
|
||||
# because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
|
||||
# all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
|
||||
#
|
||||
# In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
|
||||
#
|
||||
# The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
|
||||
# the function and does not call the original) the state dict will just fail to load because weights
|
||||
# would be on the meta device.
|
||||
# """
|
||||
#
|
||||
# if state_dict is sd:
|
||||
# state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||
#
|
||||
# original(module, state_dict, strict=strict)
|
||||
#
|
||||
# module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
|
||||
# module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
|
||||
# linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||
# conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||
# mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||
# layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
|
||||
# group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
|
||||
@@ -1,124 +1,3 @@
|
||||
import torch
|
||||
from torch.nn.functional import silu
|
||||
from types import MethodType
|
||||
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
import ldm.modules.encoders.modules
|
||||
|
||||
import sgm.modules.attention
|
||||
import sgm.modules.diffusionmodules.model
|
||||
import sgm.modules.diffusionmodules.openaimodel
|
||||
import sgm.modules.encoders.modules
|
||||
|
||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
# new memory efficient cross attention blocks do not support hypernets and we already
|
||||
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
|
||||
# silence new console spam from SD2
|
||||
ldm.modules.attention.print = shared.ldm_print
|
||||
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
ldm.util.print = shared.ldm_print
|
||||
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||
|
||||
optimizers = []
|
||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
|
||||
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
||||
|
||||
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
||||
|
||||
|
||||
def list_optimizers():
    """Refresh the module-level `optimizers` list in place.

    Candidates come from script_callbacks.list_optimizers_callback(); only those whose
    is_available() is truthy are kept, ordered by priority from highest to lowest.
    """
    candidates = script_callbacks.list_optimizers_callback()

    usable = [candidate for candidate in candidates if candidate.is_available()]
    usable.sort(key=lambda candidate: candidate.priority, reverse=True)

    # replace the contents but keep the same list object
    optimizers[:] = usable
|
||||
|
||||
|
||||
def apply_optimizations(option=None):
    """Do nothing; `option` is accepted and ignored. Always returns None."""
    return None
|
||||
|
||||
|
||||
def undo_optimizations():
    """Do nothing. Always returns None."""
    return None
|
||||
|
||||
|
||||
def fix_checkpoint():
    """Do nothing.

    Checkpoints are now added and removed in the embedding/hypernet code, because torch
    does not want checkpoints to be added when not training (it emits a warning).
    """
    return None
|
||||
|
||||
|
||||
def weighted_loss(sd_model, pred, target, mean=True):
    """Compute the loss through sd_model._old_get_loss and scale it by an optional weight.

    The loss is always requested with mean=False. If sd_model has a `_custom_loss_weight`
    attribute that is not None, the loss is multiplied by it in place. The result is reduced
    with .mean() when `mean` is truthy and returned unreduced otherwise.
    """
    # ask for the unreduced loss so the weight can be applied before averaging
    unreduced = sd_model._old_get_loss(pred, target, mean=False)

    # the weight is optional
    custom_weight = getattr(sd_model, '_custom_loss_weight', None)
    if custom_weight is not None:
        unreduced *= custom_weight

    if mean:
        return unreduced.mean()

    return unreduced
|
||||
|
||||
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    """Run sd_model.forward(x, c, ...) with get_loss temporarily swapped for weighted_loss.

    `w` is stored on the model as `_custom_loss_weight` for the duration of the call. The
    temporary attributes are removed and get_loss is put back afterwards, whether forward
    returns or raises.
    """

    def cleanup():
        # drop the temporary weight, if it was set
        try:
            delattr(sd_model, '_custom_loss_weight')
        except AttributeError:
            pass

        # put the saved loss function back, if one was saved
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            delattr(sd_model, '_old_get_loss')

    try:
        # make the weight reachable from weighted_loss
        sd_model._custom_loss_weight = w

        # swapping get_loss avoids reimplementing forward completely;
        # an already saved _old_get_loss is never overwritten
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # the standard forward, now using the patched get_loss
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        cleanup()
|
||||
|
||||
def apply_weighted_forward(sd_model):
    """Attach `weighted_forward` to sd_model as a bound method, for computing a weighted loss."""
    bound = MethodType(weighted_forward, sd_model)
    sd_model.weighted_forward = bound
|
||||
|
||||
def undo_weighted_forward(sd_model):
    """Remove the `weighted_forward` attribute from sd_model; a missing attribute is ignored."""
    try:
        delattr(sd_model, 'weighted_forward')
    except AttributeError:
        return
|
||||
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
fixes = None
|
||||
layers = None
|
||||
@@ -156,74 +35,234 @@ class StableDiffusionModelHijack:
|
||||
pass
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
    """Wraps an embedding layer and splices textual inversion vectors into its output.

    The pending fixes are read from `embeddings.fixes` on every forward call and reset to
    None. Each entry is expected to hold (offset, embedding) pairs for one batch item.
    """

    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key
        # expose the wrapped layer's weight under the usual attribute name
        self.weight = self.wrapped.weight

    def forward(self, input_ids):
        # take the pending fixes and clear them before doing anything else
        pending = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        # nothing to splice in: hand back the plain embeddings
        if pending is None or len(pending) == 0 or all(len(item_fixes) == 0 for item_fixes in pending):
            return inputs_embeds

        patched = []
        for item_fixes, row in zip(pending, inputs_embeds):
            for offset, embedding in item_fixes:
                vec = embedding.vec
                if isinstance(vec, dict):
                    # a dict holds one vector per key; pick the one for this layer
                    vec = vec[self.textual_inversion_key]

                emb = devices.cond_cast_unet(vec)

                # the vectors go in right after position `offset`, cut short if they would run past the end
                start = offset + 1
                emb_len = min(row.shape[0] - start, emb.shape[0])
                pieces = [row[:start], emb[:emb_len], row[start + emb_len:]]
                row = torch.cat(pieces).to(dtype=inputs_embeds.dtype)

            patched.append(row)

        return torch.stack(patched)
|
||||
|
||||
|
||||
class TextualInversionEmbeddings(torch.nn.Embedding):
    """torch.nn.Embedding whose forward goes through EmbeddingsWithFixes.forward.

    The fixes are read from the module-level `model_hijack` object, stored as `self.embeddings`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)

        self.textual_inversion_key = textual_inversion_key
        self.embeddings = model_hijack

    @property
    def wrapped(self):
        # the parent's forward, i.e. the plain embedding lookup
        plain_forward = super().forward
        return plain_forward

    def forward(self, input_ids):
        # reuse the splicing logic; it only needs `embeddings`, `wrapped` and `textual_inversion_key` on self
        return EmbeddingsWithFixes.forward(self, input_ids)
|
||||
|
||||
|
||||
def add_circular_option_to_conv_2d():
    """Replace torch.nn.Conv2d.__init__ with a wrapper that always passes padding_mode='circular'.

    The change is global and is not undone by this function.
    """
    original_init = torch.nn.Conv2d.__init__

    def circular_init(self, *args, **kwargs):
        return original_init(self, *args, padding_mode='circular', **kwargs)

    setattr(torch.nn.Conv2d, '__init__', circular_init)
|
||||
|
||||
|
||||
model_hijack = StableDiffusionModelHijack()
|
||||
|
||||
|
||||
def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.

    Stores `attr` on `self` under `name`. A plain torch.Tensor that is not already on
    devices.device is moved there first, and cast to float32 when that device is 'mps'.
    """
    on_wrong_device = type(attr) is torch.Tensor and attr.device != devices.device

    if on_wrong_device:
        target_dtype = torch.float32 if devices.device.type == 'mps' else None
        attr = attr.to(device=devices.device, dtype=target_dtype)

    setattr(self, name, attr)
|
||||
|
||||
|
||||
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||
# import torch
|
||||
# from torch.nn.functional import silu
|
||||
# from types import MethodType
|
||||
#
|
||||
# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||
# from modules.hypernetworks import hypernetwork
|
||||
# from modules.shared import cmd_opts
|
||||
# from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||
#
|
||||
# import ldm.modules.attention
|
||||
# import ldm.modules.diffusionmodules.model
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
# import ldm.models.diffusion.ddpm
|
||||
# import ldm.models.diffusion.ddim
|
||||
# import ldm.models.diffusion.plms
|
||||
# import ldm.modules.encoders.modules
|
||||
#
|
||||
# import sgm.modules.attention
|
||||
# import sgm.modules.diffusionmodules.model
|
||||
# import sgm.modules.diffusionmodules.openaimodel
|
||||
# import sgm.modules.encoders.modules
|
||||
#
|
||||
# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
#
|
||||
# # new memory efficient cross attention blocks do not support hypernets and we already
|
||||
# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||
# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||
# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
#
|
||||
# # silence new console spam from SD2
|
||||
# ldm.modules.attention.print = shared.ldm_print
|
||||
# ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
# ldm.util.print = shared.ldm_print
|
||||
# ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||
#
|
||||
# optimizers = []
|
||||
# current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
#
|
||||
# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||
# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
||||
#
|
||||
# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||
# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
||||
#
|
||||
#
|
||||
# def list_optimizers():
|
||||
# new_optimizers = script_callbacks.list_optimizers_callback()
|
||||
#
|
||||
# new_optimizers = [x for x in new_optimizers if x.is_available()]
|
||||
#
|
||||
# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
|
||||
#
|
||||
# optimizers.clear()
|
||||
# optimizers.extend(new_optimizers)
|
||||
#
|
||||
#
|
||||
# def apply_optimizations(option=None):
|
||||
# return
|
||||
#
|
||||
#
|
||||
# def undo_optimizations():
|
||||
# return
|
||||
#
|
||||
#
|
||||
# def fix_checkpoint():
|
||||
# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||
# checkpoints to be added when not training (there's a warning)"""
|
||||
#
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def weighted_loss(sd_model, pred, target, mean=True):
|
||||
# #Calculate the weight normally, but ignore the mean
|
||||
# loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||
#
|
||||
# #Check if we have weights available
|
||||
# weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||
# if weight is not None:
|
||||
# loss *= weight
|
||||
#
|
||||
# #Return the loss, as mean if specified
|
||||
# return loss.mean() if mean else loss
|
||||
#
|
||||
# def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
||||
# try:
|
||||
# #Temporarily append weights to a place accessible during loss calc
|
||||
# sd_model._custom_loss_weight = w
|
||||
#
|
||||
# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||
# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||
# if not hasattr(sd_model, '_old_get_loss'):
|
||||
# sd_model._old_get_loss = sd_model.get_loss
|
||||
# sd_model.get_loss = MethodType(weighted_loss, sd_model)
|
||||
#
|
||||
# #Run the standard forward function, but with the patched 'get_loss'
|
||||
# return sd_model.forward(x, c, *args, **kwargs)
|
||||
# finally:
|
||||
# try:
|
||||
# #Delete temporary weights if appended
|
||||
# del sd_model._custom_loss_weight
|
||||
# except AttributeError:
|
||||
# pass
|
||||
#
|
||||
# #If we have an old loss function, reset the loss function to the original one
|
||||
# if hasattr(sd_model, '_old_get_loss'):
|
||||
# sd_model.get_loss = sd_model._old_get_loss
|
||||
# del sd_model._old_get_loss
|
||||
#
|
||||
# def apply_weighted_forward(sd_model):
|
||||
# #Add new function 'weighted_forward' that can be called to calc weighted loss
|
||||
# sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
|
||||
#
|
||||
# def undo_weighted_forward(sd_model):
|
||||
# try:
|
||||
# del sd_model.weighted_forward
|
||||
# except AttributeError:
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# class StableDiffusionModelHijack:
|
||||
# fixes = None
|
||||
# layers = None
|
||||
# circular_enabled = False
|
||||
# clip = None
|
||||
# optimization_method = None
|
||||
#
|
||||
# def __init__(self):
|
||||
# self.extra_generation_params = {}
|
||||
# self.comments = []
|
||||
#
|
||||
# def apply_optimizations(self, option=None):
|
||||
# pass
|
||||
#
|
||||
# def convert_sdxl_to_ssd(self, m):
|
||||
# pass
|
||||
#
|
||||
# def hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
# def undo_hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
# def apply_circular(self, enable):
|
||||
# pass
|
||||
#
|
||||
# def clear_comments(self):
|
||||
# self.comments = []
|
||||
# self.extra_generation_params = {}
|
||||
#
|
||||
# def get_prompt_lengths(self, text, cond_stage_model):
|
||||
# pass
|
||||
#
|
||||
# def redo_hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# class EmbeddingsWithFixes(torch.nn.Module):
|
||||
# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||
# super().__init__()
|
||||
# self.wrapped = wrapped
|
||||
# self.embeddings = embeddings
|
||||
# self.textual_inversion_key = textual_inversion_key
|
||||
# self.weight = self.wrapped.weight
|
||||
#
|
||||
# def forward(self, input_ids):
|
||||
# batch_fixes = self.embeddings.fixes
|
||||
# self.embeddings.fixes = None
|
||||
#
|
||||
# inputs_embeds = self.wrapped(input_ids)
|
||||
#
|
||||
# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||
# return inputs_embeds
|
||||
#
|
||||
# vecs = []
|
||||
# for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
# for offset, embedding in fixes:
|
||||
# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||
# emb = devices.cond_cast_unet(vec)
|
||||
# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
||||
#
|
||||
# vecs.append(tensor)
|
||||
#
|
||||
# return torch.stack(vecs)
|
||||
#
|
||||
#
|
||||
# class TextualInversionEmbeddings(torch.nn.Embedding):
|
||||
# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
|
||||
# super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||
#
|
||||
# self.embeddings = model_hijack
|
||||
# self.textual_inversion_key = textual_inversion_key
|
||||
#
|
||||
# @property
|
||||
# def wrapped(self):
|
||||
# return super().forward
|
||||
#
|
||||
# def forward(self, input_ids):
|
||||
# return EmbeddingsWithFixes.forward(self, input_ids)
|
||||
#
|
||||
#
|
||||
# def add_circular_option_to_conv_2d():
|
||||
# conv2d_constructor = torch.nn.Conv2d.__init__
|
||||
#
|
||||
# def conv2d_constructor_circular(self, *args, **kwargs):
|
||||
# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
||||
#
|
||||
# torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# def register_buffer(self, name, attr):
|
||||
# """
|
||||
# Fix register buffer bug for Mac OS.
|
||||
# """
|
||||
#
|
||||
# if type(attr) == torch.Tensor:
|
||||
# if attr.device != devices.device:
|
||||
# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
||||
#
|
||||
# setattr(self, name, attr)
|
||||
#
|
||||
#
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||
# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||
|
||||
@@ -1,46 +1,46 @@
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
|
||||
|
||||
def BasicTransformerBlock_forward(self, x, context=None):
    """Replacement ``forward`` that runs ``self._forward`` under activation checkpointing.

    :param x: input forwarded unchanged to ``self._forward``.
    :param context: optional second argument forwarded to ``self._forward``.
    :return: whatever ``self._forward(x, context)`` returns.
    """
    # use_reentrant is passed explicitly: omitting it is deprecated in recent
    # PyTorch releases, and True is the historical default this code relied on.
    return checkpoint(self._forward, x, context, use_reentrant=True)
|
||||
|
||||
|
||||
def AttentionBlock_forward(self, x):
    """Replacement ``forward`` that runs ``self._forward`` under activation checkpointing.

    :param x: input forwarded unchanged to ``self._forward``.
    :return: whatever ``self._forward(x)`` returns.
    """
    # use_reentrant is passed explicitly: omitting it is deprecated in recent
    # PyTorch releases, and True is the historical default this code relied on.
    return checkpoint(self._forward, x, use_reentrant=True)
|
||||
|
||||
|
||||
def ResBlock_forward(self, x, emb):
    """Replacement ``forward`` that runs ``self._forward`` under activation checkpointing.

    :param x: first argument forwarded unchanged to ``self._forward``.
    :param emb: second argument forwarded unchanged to ``self._forward``.
    :return: whatever ``self._forward(x, emb)`` returns.
    """
    # use_reentrant is passed explicitly: omitting it is deprecated in recent
    # PyTorch releases, and True is the historical default this code relied on.
    return checkpoint(self._forward, x, emb, use_reentrant=True)
|
||||
|
||||
|
||||
stored = []
|
||||
|
||||
|
||||
def add():
    """Install the checkpointing ``forward`` wrappers on the three ldm block classes.

    The current ``forward`` attributes are appended to ``stored`` first so they
    can be put back later. Does nothing if ``stored`` already holds entries.
    """
    if stored:
        return

    transformer_cls = ldm.modules.attention.BasicTransformerBlock
    res_cls = ldm.modules.diffusionmodules.openaimodel.ResBlock
    attention_cls = ldm.modules.diffusionmodules.openaimodel.AttentionBlock

    # Order is significant: remove() restores these by position.
    originals = [transformer_cls.forward, res_cls.forward, attention_cls.forward]
    stored.extend(originals)

    transformer_cls.forward = BasicTransformerBlock_forward
    res_cls.forward = ResBlock_forward
    attention_cls.forward = AttentionBlock_forward
|
||||
|
||||
|
||||
def remove():
    """Put back the ``forward`` attributes saved in ``stored`` and empty the list.

    Does nothing if ``stored`` is empty.
    """
    if not stored:
        return

    # Positions match the order in which add() saved the originals.
    transformer_forward, res_forward, attention_forward = stored[0], stored[1], stored[2]

    ldm.modules.attention.BasicTransformerBlock.forward = transformer_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = res_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = attention_forward

    stored.clear()
|
||||
|
||||
# from torch.utils.checkpoint import checkpoint
|
||||
#
|
||||
# import ldm.modules.attention
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
#
|
||||
#
|
||||
# def BasicTransformerBlock_forward(self, x, context=None):
|
||||
# return checkpoint(self._forward, x, context)
|
||||
#
|
||||
#
|
||||
# def AttentionBlock_forward(self, x):
|
||||
# return checkpoint(self._forward, x)
|
||||
#
|
||||
#
|
||||
# def ResBlock_forward(self, x, emb):
|
||||
# return checkpoint(self._forward, x, emb)
|
||||
#
|
||||
#
|
||||
# stored = []
|
||||
#
|
||||
#
|
||||
# def add():
|
||||
# if len(stored) != 0:
|
||||
# return
|
||||
#
|
||||
# stored.extend([
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward,
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
|
||||
# ])
|
||||
#
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
||||
#
|
||||
#
|
||||
# def remove():
|
||||
# if len(stored) == 0:
|
||||
# return
|
||||
#
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
|
||||
#
|
||||
# stored.clear()
|
||||
#
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,154 +1,154 @@
|
||||
import torch
|
||||
from packaging import version
|
||||
from einops import repeat
|
||||
import math
|
||||
|
||||
from modules import devices
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
|
||||
|
||||
class TorchHijackForUnet:
    """
    Proxy for the ``torch`` module whose ``cat`` resizes a mismatched pair of tensors;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        # Invoked only for names not found on the instance or class.
        if item == 'cat':
            return self.cat

        missing = object()
        attribute = getattr(torch, item, missing)
        if attribute is missing:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

        return attribute

    def cat(self, tensors, *args, **kwargs):
        """Concatenate like ``torch.cat``; for exactly two tensors, first resize
        the first one (nearest-neighbour) to the last two dimensions of the second."""
        if len(tensors) != 2:
            return torch.cat(tensors, *args, **kwargs)

        first, second = tensors
        target_size = second.shape[-2:]
        if first.shape[-2:] != target_size:
            first = torch.nn.functional.interpolate(first, target_size, mode="nearest")

        return torch.cat((first, second), *args, **kwargs)
|
||||
|
||||
|
||||
th = TorchHijackForUnet()
|
||||
|
||||
|
||||
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
    """Always make sure inputs to unet are in correct dtype.

    Tensors in ``cond`` (when it is a dict, including tensors inside list
    values) are cast in place to ``devices.dtype_unet``; ``x_noisy`` and ``t``
    are cast the same way before ``orig_func`` is called under
    ``devices.autocast()``. The result is converted with ``.float()`` when
    ``devices.unet_needs_upcast`` is truthy.
    """
    def cast(value):
        # Non-tensor entries are passed through untouched.
        return value.to(devices.dtype_unet) if isinstance(value, torch.Tensor) else value

    if isinstance(cond, dict):
        for key in cond.keys():
            entry = cond[key]
            if isinstance(entry, list):
                cond[key] = [cast(element) for element in entry]
            else:
                cond[key] = cast(entry)

    with devices.autocast():
        result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
        return result.float() if devices.unet_needs_upcast else result
|
||||
|
||||
|
||||
# Monkey patch to create timestep embed tensor on device, avoiding a block.
|
||||
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings directly on ``timesteps.device``.

    :param _: unused (slot for the original function being replaced).
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if true, just repeat ``timesteps`` ``dim`` times
                        instead of computing sinusoids.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    steps = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * steps / half)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

    if dim % 2:
        # Odd dim: append one zero column so the width equals dim.
        padding = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, padding], dim=-1)

    return embedding
|
||||
|
||||
|
||||
# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
||||
# Prevents a lot of unnecessary aten::copy_ calls
|
||||
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
    """Forward pass for a spatial transformer without ``contiguous()`` calls.

    ``x`` is unpacked as four dimensions; it is normalised, projected in,
    flattened to ``(b, h * w, c)``, passed through each transformer block with
    the matching ``context`` entry, restored to four dimensions, projected
    out, and added to the original input.
    """
    # note: if no context is given, cross-attention defaults to self-attention
    contexts = context if isinstance(context, list) else [context]

    batch, channels, height, width = x.shape
    residual = x
    linear_projection = self.use_linear

    hidden = self.norm(x)
    if linear_projection:
        # Linear projection acts on the flattened layout.
        hidden = hidden.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        hidden = self.proj_in(hidden)
    else:
        # Non-linear projection acts on the 4-D layout before flattening.
        hidden = self.proj_in(hidden)
        hidden = hidden.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

    for index, block in enumerate(self.transformer_blocks):
        hidden = block(hidden, context=contexts[index])

    if linear_projection:
        hidden = self.proj_out(hidden)
        hidden = hidden.view(batch, height, width, channels).permute(0, 3, 1, 2)
    else:
        hidden = hidden.view(batch, height, width, channels).permute(0, 3, 1, 2)
        hidden = self.proj_out(hidden)

    return hidden + residual
|
||||
|
||||
|
||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
    """GELU that computes in float32 when ``devices.unet_needs_upcast`` is set,
    casting the result back to ``devices.dtype_unet``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        if not devices.unet_needs_upcast:
            return torch.nn.GELU.forward(self, x)

        upcast_module = self.float()
        activated = torch.nn.GELU.forward(upcast_module, x.float())
        return activated.to(devices.dtype_unet)
|
||||
|
||||
|
||||
ddpm_edit_hijack = None
|
||||
def hijack_ddpm_edit():
    """Register the first-stage and ``apply_model`` hooks for ``ddpm_edit`` once.

    The ``CondFunc`` result for ``apply_model`` is kept in the module-level
    ``ddpm_edit_hijack``; while that value is truthy, further calls do nothing.
    """
    global ddpm_edit_hijack
    if ddpm_edit_hijack:
        return

    CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
    CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
    ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
||||
|
||||
|
||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
||||
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
||||
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
||||
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
||||
|
||||
|
||||
def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
    """Call ``orig_func`` and cast its result.

    The target dtype is ``torch.float32`` when ``devices.unet_needs_upcast`` is
    truthy and ``timesteps`` is int64; otherwise it is ``devices.dtype_unet``.
    """
    upcast_to_float32 = devices.unet_needs_upcast and timesteps.dtype == torch.int64
    dtype = torch.float32 if upcast_to_float32 else devices.dtype_unet
    return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||
|
||||
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||
# import torch
|
||||
# from packaging import version
|
||||
# from einops import repeat
|
||||
# import math
|
||||
#
|
||||
# from modules import devices
|
||||
# from modules.sd_hijack_utils import CondFunc
|
||||
#
|
||||
#
|
||||
# class TorchHijackForUnet:
|
||||
# """
|
||||
# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||
# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
||||
# """
|
||||
#
|
||||
# def __getattr__(self, item):
|
||||
# if item == 'cat':
|
||||
# return self.cat
|
||||
#
|
||||
# if hasattr(torch, item):
|
||||
# return getattr(torch, item)
|
||||
#
|
||||
# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
#
|
||||
# def cat(self, tensors, *args, **kwargs):
|
||||
# if len(tensors) == 2:
|
||||
# a, b = tensors
|
||||
# if a.shape[-2:] != b.shape[-2:]:
|
||||
# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
||||
#
|
||||
# tensors = (a, b)
|
||||
#
|
||||
# return torch.cat(tensors, *args, **kwargs)
|
||||
#
|
||||
#
|
||||
# th = TorchHijackForUnet()
|
||||
#
|
||||
#
|
||||
# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||
# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||
# """Always make sure inputs to unet are in correct dtype."""
|
||||
# if isinstance(cond, dict):
|
||||
# for y in cond.keys():
|
||||
# if isinstance(cond[y], list):
|
||||
# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
# else:
|
||||
# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||
#
|
||||
# with devices.autocast():
|
||||
# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
||||
# if devices.unet_needs_upcast:
|
||||
# return result.float()
|
||||
# else:
|
||||
# return result
|
||||
#
|
||||
#
|
||||
# # Monkey patch to create timestep embed tensor on device, avoiding a block.
|
||||
# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
||||
# """
|
||||
# Create sinusoidal timestep embeddings.
|
||||
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
# These may be fractional.
|
||||
# :param dim: the dimension of the output.
|
||||
# :param max_period: controls the minimum frequency of the embeddings.
|
||||
# :return: an [N x dim] Tensor of positional embeddings.
|
||||
# """
|
||||
# if not repeat_only:
|
||||
# half = dim // 2
|
||||
# freqs = torch.exp(
|
||||
# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
||||
# )
|
||||
# args = timesteps[:, None].float() * freqs[None]
|
||||
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
# if dim % 2:
|
||||
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
# else:
|
||||
# embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
# return embedding
|
||||
#
|
||||
#
|
||||
# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
||||
# # Prevents a lot of unnecessary aten::copy_ calls
|
||||
# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
||||
# # note: if no context is given, cross-attention defaults to self-attention
|
||||
# if not isinstance(context, list):
|
||||
# context = [context]
|
||||
# b, c, h, w = x.shape
|
||||
# x_in = x
|
||||
# x = self.norm(x)
|
||||
# if not self.use_linear:
|
||||
# x = self.proj_in(x)
|
||||
# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||
# if self.use_linear:
|
||||
# x = self.proj_in(x)
|
||||
# for i, block in enumerate(self.transformer_blocks):
|
||||
# x = block(x, context=context[i])
|
||||
# if self.use_linear:
|
||||
# x = self.proj_out(x)
|
||||
# x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
||||
# if not self.use_linear:
|
||||
# x = self.proj_out(x)
|
||||
# return x + x_in
|
||||
#
|
||||
#
|
||||
# class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
# def forward(self, x):
|
||||
# if devices.unet_needs_upcast:
|
||||
# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||
# else:
|
||||
# return torch.nn.GELU.forward(self, x)
|
||||
#
|
||||
#
|
||||
# ddpm_edit_hijack = None
|
||||
# def hijack_ddpm_edit():
|
||||
# global ddpm_edit_hijack
|
||||
# if not ddpm_edit_hijack:
|
||||
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
||||
#
|
||||
#
|
||||
# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
||||
# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
#
|
||||
# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
#
|
||||
# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||
#
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
||||
# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
||||
#
|
||||
#
|
||||
# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
||||
# if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
||||
# dtype = torch.float32
|
||||
# else:
|
||||
# dtype = devices.dtype_unet
|
||||
# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||
#
|
||||
#
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||
# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||
|
||||
@@ -10,7 +10,6 @@ import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf, ListConfig
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
import gc
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
@@ -415,89 +414,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
|
||||
def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.

    Safe to call more than once: the original loader is captured only the
    first time, so the wrapper never ends up wrapping itself.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    # Capture the pristine loader once. Re-capturing on a later call would
    # store the wrapper itself and make it recurse without end.
    if not hasattr(midas.api, 'load_model_inner'):
        midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            # makedirs also creates missing parents and tolerates an existing directory.
            os.makedirs(midas_path, exist_ok=True)

            print(f"Downloading midas model weights for {model_type} to {path}")
            # Download to a side file and rename on success, so an interrupted
            # transfer never leaves a truncated file at the final path.
            partial_path = f"{path}.part"
            try:
                request.urlretrieve(midas_urls[model_type], partial_path)
                os.replace(partial_path, path)
            finally:
                if os.path.exists(partial_path):
                    os.remove(partial_path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
def patch_given_betas():
    """Patch ``DDPM.register_schedule`` so a ``ListConfig`` second positional
    argument is converted to a numpy array before the original is called."""
    import ldm.models.diffusion.ddpm

    def patched_register_schedule(*args, **kwargs):
        """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""

        positional = list(args)
        if isinstance(positional[1], ListConfig):
            positional[1] = np.array(positional[1])

        original_register_schedule(*positional, **kwargs)

    # patches.patch returns the function it replaced; the closure above calls it.
    original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
||||
|
||||
|
||||
def repair_config(sd_config, state_dict=None):
    """Adjust ``sd_config.model.params`` in place before the model is built.

    Sets a default ``use_ema``, applies the fp16 command-line options to the
    unet config, downgrades an unavailable xformers VAE attention type,
    redirects the karlo stats path, and turns off ``use_checkpoint``.
    ``state_dict`` is accepted but not used. Returns ``None``.
    """
    params = sd_config.model.params

    if not hasattr(params, "use_ema"):
        params.use_ema = False

    if hasattr(params, 'unet_config'):
        unet_params = params.unet_config.params
        if shared.cmd_opts.no_half:
            unet_params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
            unet_params.use_fp16 = True

    if hasattr(params, 'first_stage_config'):
        ddconfig = params.first_stage_config.params.ddconfig
        wants_xformers = getattr(ddconfig, "attn_type", None) == "vanilla-xformers"
        if wants_xformers and not shared.xformers_available:
            ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(params, "noise_aug_config") and hasattr(params.noise_aug_config.params, "clip_stats_path"):
        noise_params = params.noise_aug_config.params
        karlo_path = os.path.join(paths.models_path, 'karlo')
        noise_params.clip_stats_path = noise_params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)

    # Do not use checkpoint for inference.
    # This helps prevent extra performance overhead on checking parameters.
    # The perf overhead is about 100ms/it on 4090 for SDXL.
    for section in ("network_config", "unet_config"):
        if hasattr(params, section):
            getattr(params, section).params.use_checkpoint = False
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
|
||||
@@ -1,137 +1,137 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared, paths, sd_disable_initialization, devices
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||
|
||||
|
||||
config_default = shared.sd_default_config
|
||||
# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||
|
||||
|
||||
def is_using_v_parameterization_for_sd2(state_dict):
    """
    Detects whether unet in state_dict is using v-parameterization. Returns True if it is.

    Builds a UNetModel with a fixed configuration, loads the
    ``model.diffusion_model.`` weights from ``state_dict`` into it, runs one
    forward pass on constant inputs at timestep 999, and reports whether the
    mean of (output - input) is below -1.
    """

    import ldm.modules.diffusionmodules.openaimodel

    device = devices.device

    unet_kwargs = dict(
        use_checkpoint=False,
        use_fp16=False,
        image_size=32,
        in_channels=4,
        out_channels=4,
        model_channels=320,
        attention_resolutions=[4, 2, 1],
        num_res_blocks=2,
        channel_mult=[1, 2, 4, 4],
        num_head_channels=64,
        use_spatial_transformer=True,
        use_linear_in_transformer=True,
        transformer_depth=1,
        context_dim=1024,
        legacy=False,
    )

    with sd_disable_initialization.DisableInitialization():
        unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(**unet_kwargs)
        unet.eval()

    with torch.no_grad():
        prefix = "model.diffusion_model."
        # Keys are selected by substring match and have the prefix removed.
        unet_sd = {name.replace(prefix, ""): weight for name, weight in state_dict.items() if prefix in name}
        unet.load_state_dict(unet_sd, strict=True)
        unet.to(device=device, dtype=devices.dtype_unet)

        test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
        x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5

        with devices.autocast():
            timestep = torch.asarray([999], device=device)
            prediction = unet(x_test, timestep, context=test_cond)
            out = (prediction - x_test).mean().cpu().item()

    return out < -1
|
||||
|
||||
|
||||
def guess_model_config_from_state_dict(sd, filename):
    """Pick a config path for a checkpoint by looking at which keys it contains.

    :param sd: the checkpoint's state dict (key -> tensor).
    :param filename: accepted for interface compatibility; not used here.
    :return: one of the module-level ``config_*`` paths, falling back to
             ``config_default`` when nothing more specific matches.

    A missing UNet input weight or alt-diffusion transformation weight is
    treated as "does not match" rather than raising AttributeError.
    """
    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
    sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

    # Second dimension of the first UNet conv weight; None when the
    # checkpoint does not contain that weight at all.
    unet_in_channels = diffusion_model_input.shape[1] if diffusion_model_input is not None else None

    if "model.diffusion_model.x_embedder.proj.weight" in sd:
        return config_sd3

    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
        if unet_in_channels == 9:
            return config_sdxl_inpainting
        else:
            return config_sdxl

    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
        return config_sdxl_refiner
    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
        return config_depth_model
    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
        return config_unclip
    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
        return config_unopenclip

    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
        if unet_in_channels == 9:
            return config_sd2_inpainting
        # elif is_using_v_parameterization_for_sd2(sd):
        #     return config_sd2v
        else:
            return config_sd2v

    if unet_in_channels == 9:
        return config_inpainting
    if unet_in_channels == 8:
        return config_instruct_pix2pix

    if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
        transformation_weight = sd.get('cond_stage_model.transformation.weight')
        if transformation_weight is not None and transformation_weight.size()[0] == 1024:
            return config_alt_diffusion_m18
        return config_alt_diffusion

    return config_default
|
||||
|
||||
|
||||
def find_checkpoint_config(state_dict, info):
    """Return the config for a checkpoint.

    A ``.yaml`` file found next to ``info.filename`` takes precedence;
    otherwise the config is guessed from ``state_dict``. With ``info`` set to
    ``None`` the guess is made with an empty filename.
    """
    if info is None:
        return guess_model_config_from_state_dict(state_dict, "")

    sidecar_config = find_checkpoint_config_near_filename(info)
    if sidecar_config is None:
        return guess_model_config_from_state_dict(state_dict, info.filename)

    return sidecar_config
|
||||
|
||||
|
||||
def find_checkpoint_config_near_filename(info):
    """Look for a ``.yaml`` file that shares the checkpoint's path and base name.

    :param info: checkpoint info object exposing ``filename``, or ``None``.
    :return: the path of the ``.yaml`` file if it exists on disk, else ``None``.
    """
    if info is not None:
        stem, _extension = os.path.splitext(info.filename)
        candidate = f"{stem}.yaml"
        if os.path.exists(candidate):
            return candidate

    return None
|
||||
|
||||
# import os
|
||||
#
|
||||
# import torch
|
||||
#
|
||||
# from modules import shared, paths, sd_disable_initialization, devices
|
||||
#
|
||||
# sd_configs_path = shared.sd_configs_path
|
||||
# # sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
# # sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||
#
|
||||
#
|
||||
# config_default = shared.sd_default_config
|
||||
# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||
# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||
# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||
# config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||
# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||
#
|
||||
#
|
||||
# def is_using_v_parameterization_for_sd2(state_dict):
|
||||
# """
|
||||
# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||
# """
|
||||
#
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
#
|
||||
# device = devices.device
|
||||
#
|
||||
# with sd_disable_initialization.DisableInitialization():
|
||||
# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||
# use_checkpoint=False,
|
||||
# use_fp16=False,
|
||||
# image_size=32,
|
||||
# in_channels=4,
|
||||
# out_channels=4,
|
||||
# model_channels=320,
|
||||
# attention_resolutions=[4, 2, 1],
|
||||
# num_res_blocks=2,
|
||||
# channel_mult=[1, 2, 4, 4],
|
||||
# num_head_channels=64,
|
||||
# use_spatial_transformer=True,
|
||||
# use_linear_in_transformer=True,
|
||||
# transformer_depth=1,
|
||||
# context_dim=1024,
|
||||
# legacy=False
|
||||
# )
|
||||
# unet.eval()
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||
# unet.load_state_dict(unet_sd, strict=True)
|
||||
# unet.to(device=device, dtype=devices.dtype_unet)
|
||||
#
|
||||
# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||
# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||
#
|
||||
# with devices.autocast():
|
||||
# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
||||
#
|
||||
# return out < -1
|
||||
#
|
||||
#
|
||||
# def guess_model_config_from_state_dict(sd, filename):
|
||||
# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
#
|
||||
# if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
||||
# return config_sd3
|
||||
#
|
||||
# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_sdxl_inpainting
|
||||
# else:
|
||||
# return config_sdxl
|
||||
#
|
||||
# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||
# return config_sdxl_refiner
|
||||
# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
# return config_depth_model
|
||||
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||
# return config_unclip
|
||||
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
||||
# return config_unopenclip
|
||||
#
|
||||
# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_sd2_inpainting
|
||||
# # elif is_using_v_parameterization_for_sd2(sd):
|
||||
# # return config_sd2v
|
||||
# else:
|
||||
# return config_sd2v
|
||||
#
|
||||
# if diffusion_model_input is not None:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_inpainting
|
||||
# if diffusion_model_input.shape[1] == 8:
|
||||
# return config_instruct_pix2pix
|
||||
#
|
||||
# if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||
# return config_alt_diffusion_m18
|
||||
# return config_alt_diffusion
|
||||
#
|
||||
# return config_default
|
||||
#
|
||||
#
|
||||
# def find_checkpoint_config(state_dict, info):
|
||||
# if info is None:
|
||||
# return guess_model_config_from_state_dict(state_dict, "")
|
||||
#
|
||||
# config = find_checkpoint_config_near_filename(info)
|
||||
# if config is not None:
|
||||
# return config
|
||||
#
|
||||
# return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||
#
|
||||
#
|
||||
# def find_checkpoint_config_near_filename(info):
|
||||
# if info is None:
|
||||
# return None
|
||||
#
|
||||
# config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||
# if os.path.exists(config):
|
||||
# return config
|
||||
#
|
||||
# return None
|
||||
#
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -6,7 +5,7 @@ if TYPE_CHECKING:
|
||||
from modules.sd_models import CheckpointInfo
|
||||
|
||||
|
||||
class WebuiSdModel(LatentDiffusion):
|
||||
class WebuiSdModel:
|
||||
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
|
||||
|
||||
lowvram: bool
|
||||
|
||||
@@ -1,115 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import sgm.models.diffusion
|
||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
from modules import torch_utils
|
||||
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
    """Run the conditioner on a batch of prompts with the SDXL size/crop/score inputs.

    ``width``, ``height`` and ``is_negative_prompt`` are read from ``batch``
    when it carries them; sizes fall back to 1024 when absent or falsy. The
    aesthetic score comes from the low or high refiner option depending on
    whether the batch is a negative prompt. When a negative-prompt batch
    consists only of empty strings, the ``txt`` embeddings are forced to zero.
    """
    # every embedder gets its ucg_rate reset before conditioning
    for embedder in self.conditioner.embedders:
        embedder.ucg_rate = 0.0

    width = getattr(batch, 'width', 1024) or 1024
    height = getattr(batch, 'height', 1024) or 1024
    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
    if is_negative_prompt:
        aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score
    else:
        aesthetic_score = shared.opts.sdxl_refiner_high_aesthetic_score

    devices_args = {
        "device": self.forge_objects.clip.patcher.current_device,
        "dtype": memory_management.text_encoder_dtype(),
    }
    prompt_count = len(batch)

    def per_prompt(values):
        # one row of `values` for each prompt in the batch
        return torch.tensor(values, **devices_args).repeat(prompt_count, 1)

    sdxl_conds = {
        "txt": batch,
        "original_size_as_tuple": per_prompt([height, width]),
        "crop_coords_top_left": per_prompt([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]),
        "target_size_as_tuple": per_prompt([height, width]),
        "aesthetic_score": per_prompt([aesthetic_score]),
    }

    if is_negative_prompt and all(x == '' for x in batch):
        zeroed_keys = ['txt']
    else:
        zeroed_keys = []

    return self.conditioner(sdxl_conds, force_zero_embeddings=zeroed_keys)
|
||||
|
||||
|
||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
    """Call the wrapped model, first appending ``cond['c_concat']`` when it takes 9 input channels.

    When ``self.model.diffusion_model.in_channels`` is 9, ``x`` is concatenated
    with the tensors in ``cond['c_concat']`` along dim 1; otherwise ``x`` is
    passed through untouched. All other arguments are forwarded as given.
    """
    wants_concat = self.model.diffusion_model.in_channels == 9
    model_input = torch.cat([x] + cond['c_concat'], dim=1) if wants_concat else x
    return self.model(model_input, t, cond, *args, **kwargs)
|
||||
|
||||
|
||||
def get_first_stage_encoding(self, x):
    """Return ``x`` unchanged.

    Present for compatibility only: SDXL's encode_first_stage already does
    everything, so there is nothing left to do here.
    """
    return x
|
||||
|
||||
|
||||
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
||||
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
||||
|
||||
|
||||
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
    """Encode ``init_text`` with every capable embedder and join the results along dim 1.

    Embedders that do not define ``encode_embedding_init_text`` are skipped.
    """
    capable = [candidate for candidate in self.embedders if hasattr(candidate, 'encode_embedding_init_text')]
    encoded_parts = [candidate.encode_embedding_init_text(init_text, nvpt) for candidate in capable]
    return torch.cat(encoded_parts, dim=1)
|
||||
|
||||
|
||||
def tokenize(self: sgm.modules.GeneralConditioner, texts):
    """Tokenize ``texts`` with the first embedder that provides ``tokenize``.

    :raises AssertionError: if no embedder has a ``tokenize`` attribute.
    """
    tokenizer_owner = next((candidate for candidate in self.embedders if hasattr(candidate, 'tokenize')), None)
    if tokenizer_owner is None:
        raise AssertionError('no tokenizer available')

    return tokenizer_owner.tokenize(texts)
|
||||
|
||||
|
||||
|
||||
def process_texts(self, texts):
    """Delegate to the first embedder that provides ``process_texts``.

    Returns ``None`` when no embedder has that attribute.
    """
    handler = next((candidate for candidate in self.embedders if hasattr(candidate, 'process_texts')), None)
    if handler is None:
        return None

    return handler.process_texts(texts)
|
||||
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
    """Delegate to the first embedder that provides ``get_target_prompt_token_count``.

    Returns ``None`` when no embedder has that attribute.
    """
    handler = next((candidate for candidate in self.embedders if hasattr(candidate, 'get_target_prompt_token_count')), None)
    if handler is None:
        return None

    return handler.get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
# these additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
|
||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||
sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||
|
||||
|
||||
def extend_sdxl(model):
    """Attach the SD1.5-style attributes the rest of the codebase expects to an SDXL model.

    Sets the diffusion model's ``dtype``, the conditioning keys, the
    parameterization ("v" when the denoiser uses ``VScaling``, otherwise
    "eps"), ``alphas_cumprod`` from a ``LegacyDDPMDiscretization``, and an
    empty ``torch.nn.Module`` as ``model.conditioner.wrapped``.
    """
    unet = model.model.diffusion_model
    unet.dtype = torch_utils.get_param(unet).dtype

    model.model.conditioning_key = 'crossattn'
    model.cond_stage_key = 'txt'
    # model.cond_stage_model will be set in sd_hijack

    uses_v_scaling = isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling)
    model.parameterization = "v" if uses_v_scaling else "eps"

    ddpm = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
    model.alphas_cumprod = torch.asarray(ddpm.alphas_cumprod, device=devices.device, dtype=torch.float32)

    model.conditioner.wrapped = torch.nn.Module()
|
||||
|
||||
|
||||
sgm.modules.attention.print = shared.ldm_print
|
||||
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||
sgm.modules.encoders.modules.print = shared.ldm_print
|
||||
|
||||
# this gets the code to load the vanilla attention that we override
|
||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
||||
# from __future__ import annotations
|
||||
#
|
||||
# import torch
|
||||
#
|
||||
# import sgm.models.diffusion
|
||||
# import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
# import sgm.modules.diffusionmodules.discretizer
|
||||
# from modules import devices, shared, prompt_parser
|
||||
# from modules import torch_utils
|
||||
#
|
||||
# from backend import memory_management
|
||||
#
|
||||
#
|
||||
# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
#
|
||||
# for embedder in self.conditioner.embedders:
|
||||
# embedder.ucg_rate = 0.0
|
||||
#
|
||||
# width = getattr(batch, 'width', 1024) or 1024
|
||||
# height = getattr(batch, 'height', 1024) or 1024
|
||||
# is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||
# aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||
#
|
||||
# devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
|
||||
#
|
||||
# sdxl_conds = {
|
||||
# "txt": batch,
|
||||
# "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
# "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
||||
# "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
# "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
||||
# }
|
||||
#
|
||||
# force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
||||
# c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
||||
#
|
||||
# return c
|
||||
#
|
||||
#
|
||||
# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
|
||||
# if self.model.diffusion_model.in_channels == 9:
|
||||
# x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||
#
|
||||
# return self.model(x, t, cond, *args, **kwargs)
|
||||
#
|
||||
#
|
||||
# def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||
# return x
|
||||
#
|
||||
#
|
||||
# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||
# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
||||
# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
||||
#
|
||||
#
|
||||
# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
||||
# res = []
|
||||
#
|
||||
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
||||
# encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
||||
# res.append(encoded)
|
||||
#
|
||||
# return torch.cat(res, dim=1)
|
||||
#
|
||||
#
|
||||
# def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
||||
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
||||
# return embedder.tokenize(texts)
|
||||
#
|
||||
# raise AssertionError('no tokenizer available')
|
||||
#
|
||||
#
|
||||
#
|
||||
# def process_texts(self, texts):
|
||||
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||
# return embedder.process_texts(texts)
|
||||
#
|
||||
#
|
||||
# def get_target_prompt_token_count(self, token_count):
|
||||
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
||||
# return embedder.get_target_prompt_token_count(token_count)
|
||||
#
|
||||
#
|
||||
# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
||||
# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||
# sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||
# sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||
# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||
#
|
||||
#
|
||||
# def extend_sdxl(model):
|
||||
# """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||
#
|
||||
# dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
||||
# model.model.diffusion_model.dtype = dtype
|
||||
# model.model.conditioning_key = 'crossattn'
|
||||
# model.cond_stage_key = 'txt'
|
||||
# # model.cond_stage_model will be set in sd_hijack
|
||||
#
|
||||
# model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||
#
|
||||
# discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||
# model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
||||
#
|
||||
# model.conditioner.wrapped = torch.nn.Module()
|
||||
#
|
||||
#
|
||||
# sgm.modules.attention.print = shared.ldm_print
|
||||
# sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||
# sgm.modules.encoders.modules.print = shared.ldm_print
|
||||
#
|
||||
# # this gets the code to load the vanilla attention that we override
|
||||
# sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||
# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
||||
|
||||
@@ -35,9 +35,7 @@ def refresh_vae_list():
|
||||
|
||||
|
||||
def cross_attention_optimizations():
|
||||
import modules.sd_hijack
|
||||
|
||||
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
||||
return ["Automatic"]
|
||||
|
||||
|
||||
def sd_unet_items():
|
||||
|
||||
@@ -1,245 +1,243 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
from torchvision import transforms
|
||||
from collections import defaultdict
|
||||
from random import shuffle, choices
|
||||
|
||||
import random
|
||||
import tqdm
|
||||
from modules import devices, shared, images
|
||||
import re
|
||||
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
    """Plain record for one training image; every field is optional and stored as given."""

    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
        # source file and the caption text associated with it, plus the loss weight
        self.filename, self.filename_text, self.weight = filename, filename_text, weight
        # encoder output and the latent drawn from it
        self.latent_dist, self.latent_sample = latent_dist, latent_sample
        # conditioning and the prompt text it was built from
        self.cond, self.cond_text = cond, cond_text
        self.pixel_values = pixel_values
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
    """Dataset of pre-encoded training images read from a directory.

    Every readable image in ``data_root`` is encoded to a latent once, at
    construction time, and stored as a ``DatasetEntry``. Entries are also
    grouped by image size (``self.groups``) so batches can be drawn from
    same-sized images.
    """

    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
        """Read, encode and index every image in ``data_root``.

        :param data_root: directory holding the images (and optional ``.txt`` captions).
        :param width: target width images are resized to unless ``varsize`` is set.
        :param height: target height images are resized to unless ``varsize`` is set.
        :param repeats: accepted for interface compatibility; not used here.
        :param flip_p: probability given to the ``RandomHorizontalFlip`` stored as ``self.flip``.
        :param placeholder_token: text substituted for ``[name]`` in templates.
        :param model: object providing ``encode_first_stage`` and ``get_first_stage_encoding``.
        :param cond_model: callable applied to ``[cond_text]`` when ``include_cond`` is set.
        :param device: device the image tensor is moved to before encoding.
        :param template_file: text file with one prompt template per line.
        :param include_cond: precompute ``entry.cond`` (only when tags are neither dropped nor shuffled).
        :param batch_size: requested batch size, clamped to the dataset length.
        :param gradient_step: requested gradient step, clamped to ``length // batch_size``.
        :param shuffle_tags: shuffle the comma-separated tags each time a prompt is built.
        :param tag_drop_out: probability of dropping each tag when a prompt is built.
        :param latent_sampling_method: ``'once'``, ``'deterministic'`` or ``'random'``.
        :param varsize: keep each image's own size instead of resizing.
        :param use_weight: build a per-element loss weight from the image's alpha channel.
        :raises AssertionError: if the directory is missing/empty or no image could be loaded.
        :raises Exception: with message ``"interrupted"`` if ``shared.state.interrupted`` is set.
        """
        re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None

        self.placeholder_token = placeholder_token

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'
        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
        assert os.listdir(data_root), "Dataset directory is empty"

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]

        self.shuffle_tags = shuffle_tags
        self.tag_drop_out = tag_drop_out
        groups = defaultdict(list)

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            alpha_channel = None
            if shared.state.interrupted:
                raise Exception("interrupted")
            try:
                image = images.read(path)
                # Currently does not work for single color transparency
                # We would need to read image.info['transparency'] for that
                if use_weight and 'A' in image.getbands():
                    alpha_channel = image.getchannel('A')
                image = image.convert('RGB')
                if not varsize:
                    image = image.resize((width, height), PIL.Image.BICUBIC)
            except Exception:
                # deliberate best effort: anything that cannot be read as an
                # image (caption files included) is skipped
                continue

            text_filename = f"{os.path.splitext(path)[0]}.txt"
            filename = os.path.basename(path)

            if os.path.exists(text_filename):
                with open(text_filename, "r", encoding="utf8") as file:
                    filename_text = file.read()
            else:
                # no caption file: derive the text from the file name
                filename_text = os.path.splitext(filename)[0]
                filename_text = re.sub(re_numbers_at_start, '', filename_text)
                if re_word:
                    tokens = re_word.findall(filename_text)
                    filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)

            # scale 0..255 pixel values to -1..1
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
            latent_sample = None

            with devices.autocast():
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            # Perform latent sampling, even for random sampling.
            # We need the sample dimensions for the weights
            if latent_sampling_method == "deterministic":
                if isinstance(latent_dist, DiagonalGaussianDistribution):
                    # Works only for DiagonalGaussianDistribution
                    latent_dist.std = 0
                else:
                    latent_sampling_method = "once"
            latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)

            if use_weight and alpha_channel is not None:
                weight = self._alpha_to_weight(alpha_channel, latent_sample.shape)
            elif use_weight:
                # If an image does not have an alpha channel, add a ones weight map anyway so we can stack it later
                weight = torch.ones(latent_sample.shape)
            else:
                weight = None

            if latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
            else:
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)

            # the prompt (and its conditioning) can only be fixed up front when
            # it does not change between epochs
            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                entry.cond_text = self.create_text(filename_text)

            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with devices.autocast():
                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
            groups[image.size].append(len(self.dataset))
            self.dataset.append(entry)
            del torchdata
            del latent_dist
            del latent_sample
            del weight

        self.length = len(self.dataset)
        self.groups = list(groups.values())
        assert self.length > 0, "No images have been found in the dataset."
        self.batch_size = min(batch_size, self.length)
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
        self.latent_sampling_method = latent_sampling_method

        if len(groups) > 1:
            print("Buckets:")
            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
                print(f" {w}x{h}: {len(ids)}")
            print()

    @staticmethod
    def _alpha_to_weight(alpha_channel, latent_shape):
        """Turn an image's alpha channel into a loss-weight tensor of shape ``latent_shape``.

        ``latent_shape`` is assumed to be ``(channels, height, width)``,
        mirroring the (C, H, W) image tensor fed to the encoder.
        """
        channels, *latent_size = latent_shape
        # PIL sizes are (width, height), the reverse of the latent's trailing dims
        weight_img = alpha_channel.resize(tuple(latent_size[::-1]))
        npweight = np.array(weight_img).astype(np.float32)
        # Repeat for every channel in the latent sample
        weight = torch.from_numpy(npweight).repeat(channels, 1, 1)
        # Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
        weight -= weight.min()
        mean = weight.mean()
        if mean > 0:
            weight /= mean
        else:
            # uniform alpha (e.g. fully opaque): dividing by the zero mean would
            # give NaN, so weight every element equally instead
            weight = torch.ones_like(weight)
        return weight

    def create_text(self, filename_text):
        """Build a prompt from a randomly chosen template line.

        ``[filewords]`` is replaced by the comma-separated tags of
        ``filename_text`` (after optional drop-out and shuffling) and
        ``[name]`` by the placeholder token.
        """
        text = random.choice(self.lines)
        tags = filename_text.split(',')
        if self.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > self.tag_drop_out]
        if self.shuffle_tags:
            random.shuffle(tags)
        text = text.replace("[filewords]", ','.join(tags))
        text = text.replace("[name]", self.placeholder_token)
        return text

    def __len__(self):
        """Number of entries that were loaded successfully."""
        return self.length

    def __getitem__(self, i):
        """Return entry ``i``, refreshing its prompt and/or latent sample when they vary per access."""
        entry = self.dataset[i]
        if self.tag_drop_out != 0 or self.shuffle_tags:
            entry.cond_text = self.create_text(entry.filename_text)
        if self.latent_sampling_method == "random":
            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
        return entry
|
||||
|
||||
|
||||
class GroupedBatchSampler(Sampler):
    """Batch sampler whose batches never mix indices from different groups.

    ``data_source.groups`` is a list of index lists. Each iteration shuffles
    every group, cuts it into full batches, tops up with ``n_rand_batches``
    batches drawn at random from groups chosen with ``probs``, and yields the
    batches in shuffled order.
    """

    def __init__(self, data_source: "PersonalizedBase", batch_size: int):
        """Precompute the batch counts for ``data_source``.

        :param data_source: sized object exposing ``groups`` (list of index lists).
        :param batch_size: number of indices per batch.
        """
        # Sampler.__init__ is deliberately not called: it ignores data_source,
        # and recent torch versions warn when one is passed.
        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
        # share of the n_batch * batch_size samples each group is expected to supply
        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
        self.base = [int(e) // batch_size for e in expected]
        # batches left over once every group has supplied its whole batches
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        self.probs = [e % batch_size / nrb / batch_size if nrb > 0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        """Number of whole batches: ``len(data_source) // batch_size``."""
        return self.len

    def __iter__(self):
        """Yield lists of ``batch_size`` indices, each taken from a single group."""
        b = self.batch_size

        # shuffles the group lists in place
        for g in self.groups:
            shuffle(g)

        batches = []
        for g in self.groups:
            batches.extend(g[i * b:(i + 1) * b] for i in range(len(g) // b))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            # sampled with replacement from the chosen group
            batches.append(choices(rand_group, k=b))

        shuffle(batches)

        yield from batches
|
||||
|
||||
|
||||
class PersonalizedDataLoader(DataLoader):
    """DataLoader that batches with ``GroupedBatchSampler`` and collates into batch-loader objects."""

    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
        """Wire up the grouped sampler, then pick the collate function for the sampling method."""
        grouped_sampler = GroupedBatchSampler(dataset, batch_size)
        super().__init__(dataset, batch_sampler=grouped_sampler, pin_memory=pin_memory)
        # "random" sampling uses the collate wrapper whose batches skip pinning
        use_random = latent_sampling_method == "random"
        self.collate_fn = collate_wrapper_random if use_random else collate_wrapper
|
||||
|
||||
|
||||
class BatchLoader:
    """Collated batch: per-entry prompt texts and conds as lists, latents (and weights) stacked."""

    def __init__(self, data):
        """Gather the fields of every entry in ``data``.

        ``weight`` is a stacked tensor only when every entry has one;
        otherwise it is ``None``.
        """
        self.cond_text = [entry.cond_text for entry in data]
        self.cond = [entry.cond for entry in data]

        latents = [entry.latent_sample for entry in data]
        self.latent_sample = torch.stack(latents).squeeze(1)

        weights = [entry.weight for entry in data]
        if any(weight is None for weight in weights):
            self.weight = None
        else:
            self.weight = torch.stack(weights).squeeze(1)

    def pin_memory(self):
        """Pin the stacked latents and return ``self``."""
        self.latent_sample = self.latent_sample.pin_memory()
        return self
|
||||
|
||||
def collate_wrapper(batch):
    """Collate function: wrap the list of entries in a ``BatchLoader``."""
    collated = BatchLoader(batch)
    return collated
|
||||
|
||||
class BatchLoaderRandom(BatchLoader):
    """``BatchLoader`` variant whose ``pin_memory`` leaves the batch untouched."""

    def __init__(self, data):
        """Collate ``data`` exactly as the base class does."""
        BatchLoader.__init__(self, data)

    def pin_memory(self):
        """Return ``self`` without pinning anything."""
        return self
|
||||
|
||||
def collate_wrapper_random(batch):
    """Wrap the entries in ``batch`` in a ``BatchLoaderRandom`` and return it."""
    wrapped = BatchLoaderRandom(batch)
    return wrapped
|
||||
# import os
|
||||
# import numpy as np
|
||||
# import PIL
|
||||
# import torch
|
||||
# from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
# from torchvision import transforms
|
||||
# from collections import defaultdict
|
||||
# from random import shuffle, choices
|
||||
#
|
||||
# import random
|
||||
# import tqdm
|
||||
# from modules import devices, shared, images
|
||||
# import re
|
||||
#
|
||||
# re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
#
|
||||
#
|
||||
# class DatasetEntry:
|
||||
# def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
|
||||
# self.filename = filename
|
||||
# self.filename_text = filename_text
|
||||
# self.weight = weight
|
||||
# self.latent_dist = latent_dist
|
||||
# self.latent_sample = latent_sample
|
||||
# self.cond = cond
|
||||
# self.cond_text = cond_text
|
||||
# self.pixel_values = pixel_values
|
||||
#
|
||||
#
|
||||
# class PersonalizedBase(Dataset):
|
||||
# def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
|
||||
# re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
|
||||
#
|
||||
# self.placeholder_token = placeholder_token
|
||||
#
|
||||
# self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
#
|
||||
# self.dataset = []
|
||||
#
|
||||
# with open(template_file, "r") as file:
|
||||
# lines = [x.strip() for x in file.readlines()]
|
||||
#
|
||||
# self.lines = lines
|
||||
#
|
||||
# assert data_root, 'dataset directory not specified'
|
||||
# assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
# assert os.listdir(data_root), "Dataset directory is empty"
|
||||
#
|
||||
# self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
#
|
||||
# self.shuffle_tags = shuffle_tags
|
||||
# self.tag_drop_out = tag_drop_out
|
||||
# groups = defaultdict(list)
|
||||
#
|
||||
# print("Preparing dataset...")
|
||||
# for path in tqdm.tqdm(self.image_paths):
|
||||
# alpha_channel = None
|
||||
# if shared.state.interrupted:
|
||||
# raise Exception("interrupted")
|
||||
# try:
|
||||
# image = images.read(path)
|
||||
# #Currently does not work for single color transparency
|
||||
# #We would need to read image.info['transparency'] for that
|
||||
# if use_weight and 'A' in image.getbands():
|
||||
# alpha_channel = image.getchannel('A')
|
||||
# image = image.convert('RGB')
|
||||
# if not varsize:
|
||||
# image = image.resize((width, height), PIL.Image.BICUBIC)
|
||||
# except Exception:
|
||||
# continue
|
||||
#
|
||||
# text_filename = f"{os.path.splitext(path)[0]}.txt"
|
||||
# filename = os.path.basename(path)
|
||||
#
|
||||
# if os.path.exists(text_filename):
|
||||
# with open(text_filename, "r", encoding="utf8") as file:
|
||||
# filename_text = file.read()
|
||||
# else:
|
||||
# filename_text = os.path.splitext(filename)[0]
|
||||
# filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||
# if re_word:
|
||||
# tokens = re_word.findall(filename_text)
|
||||
# filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||
#
|
||||
# npimage = np.array(image).astype(np.uint8)
|
||||
# npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
#
|
||||
# torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||
# latent_sample = None
|
||||
#
|
||||
# with devices.autocast():
|
||||
# latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||
#
|
||||
# #Perform latent sampling, even for random sampling.
|
||||
# #We need the sample dimensions for the weights
|
||||
# if latent_sampling_method == "deterministic":
|
||||
# if isinstance(latent_dist, DiagonalGaussianDistribution):
|
||||
# # Works only for DiagonalGaussianDistribution
|
||||
# latent_dist.std = 0
|
||||
# else:
|
||||
# latent_sampling_method = "once"
|
||||
# latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
#
|
||||
# if use_weight and alpha_channel is not None:
|
||||
# channels, *latent_size = latent_sample.shape
|
||||
# weight_img = alpha_channel.resize(latent_size)
|
||||
# npweight = np.array(weight_img).astype(np.float32)
|
||||
# #Repeat for every channel in the latent sample
|
||||
# weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
|
||||
# #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
|
||||
# weight -= weight.min()
|
||||
# weight /= weight.mean()
|
||||
# elif use_weight:
|
||||
# #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
|
||||
# weight = torch.ones(latent_sample.shape)
|
||||
# else:
|
||||
# weight = None
|
||||
#
|
||||
# if latent_sampling_method == "random":
|
||||
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
|
||||
# else:
|
||||
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
|
||||
#
|
||||
# if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
# entry.cond_text = self.create_text(filename_text)
|
||||
#
|
||||
# if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
# with devices.autocast():
|
||||
# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
# groups[image.size].append(len(self.dataset))
|
||||
# self.dataset.append(entry)
|
||||
# del torchdata
|
||||
# del latent_dist
|
||||
# del latent_sample
|
||||
# del weight
|
||||
#
|
||||
# self.length = len(self.dataset)
|
||||
# self.groups = list(groups.values())
|
||||
# assert self.length > 0, "No images have been found in the dataset."
|
||||
# self.batch_size = min(batch_size, self.length)
|
||||
# self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||
# self.latent_sampling_method = latent_sampling_method
|
||||
#
|
||||
# if len(groups) > 1:
|
||||
# print("Buckets:")
|
||||
# for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
||||
# print(f" {w}x{h}: {len(ids)}")
|
||||
# print()
|
||||
#
|
||||
# def create_text(self, filename_text):
|
||||
# text = random.choice(self.lines)
|
||||
# tags = filename_text.split(',')
|
||||
# if self.tag_drop_out != 0:
|
||||
# tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||
# if self.shuffle_tags:
|
||||
# random.shuffle(tags)
|
||||
# text = text.replace("[filewords]", ','.join(tags))
|
||||
# text = text.replace("[name]", self.placeholder_token)
|
||||
# return text
|
||||
#
|
||||
# def __len__(self):
|
||||
# return self.length
|
||||
#
|
||||
# def __getitem__(self, i):
|
||||
# entry = self.dataset[i]
|
||||
# if self.tag_drop_out != 0 or self.shuffle_tags:
|
||||
# entry.cond_text = self.create_text(entry.filename_text)
|
||||
# if self.latent_sampling_method == "random":
|
||||
# entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||
# return entry
|
||||
#
|
||||
#
|
||||
# class GroupedBatchSampler(Sampler):
|
||||
# def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
||||
# super().__init__(data_source)
|
||||
#
|
||||
# n = len(data_source)
|
||||
# self.groups = data_source.groups
|
||||
# self.len = n_batch = n // batch_size
|
||||
# expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
||||
# self.base = [int(e) // batch_size for e in expected]
|
||||
# self.n_rand_batches = nrb = n_batch - sum(self.base)
|
||||
# self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
||||
# self.batch_size = batch_size
|
||||
#
|
||||
# def __len__(self):
|
||||
# return self.len
|
||||
#
|
||||
# def __iter__(self):
|
||||
# b = self.batch_size
|
||||
#
|
||||
# for g in self.groups:
|
||||
# shuffle(g)
|
||||
#
|
||||
# batches = []
|
||||
# for g in self.groups:
|
||||
# batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
||||
# for _ in range(self.n_rand_batches):
|
||||
# rand_group = choices(self.groups, self.probs)[0]
|
||||
# batches.append(choices(rand_group, k=b))
|
||||
#
|
||||
# shuffle(batches)
|
||||
#
|
||||
# yield from batches
|
||||
#
|
||||
#
|
||||
# class PersonalizedDataLoader(DataLoader):
|
||||
# def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||
# super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
||||
# if latent_sampling_method == "random":
|
||||
# self.collate_fn = collate_wrapper_random
|
||||
# else:
|
||||
# self.collate_fn = collate_wrapper
|
||||
#
|
||||
#
|
||||
# class BatchLoader:
|
||||
# def __init__(self, data):
|
||||
# self.cond_text = [entry.cond_text for entry in data]
|
||||
# self.cond = [entry.cond for entry in data]
|
||||
# self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||
# if all(entry.weight is not None for entry in data):
|
||||
# self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
|
||||
# else:
|
||||
# self.weight = None
|
||||
# #self.emb_index = [entry.emb_index for entry in data]
|
||||
# #print(self.latent_sample.device)
|
||||
#
|
||||
# def pin_memory(self):
|
||||
# self.latent_sample = self.latent_sample.pin_memory()
|
||||
# return self
|
||||
#
|
||||
# def collate_wrapper(batch):
|
||||
# return BatchLoader(batch)
|
||||
#
|
||||
# class BatchLoaderRandom(BatchLoader):
|
||||
# def __init__(self, data):
|
||||
# super().__init__(data)
|
||||
#
|
||||
# def pin_memory(self):
|
||||
# return self
|
||||
#
|
||||
# def collate_wrapper_random(batch):
|
||||
# return BatchLoaderRandom(batch)
|
||||
|
||||
Reference in New Issue
Block a user