mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
Free WebUI from its Prison
Congratulations WebUI. Say Hello to freedom.
This commit is contained in:
@@ -1,250 +0,0 @@
|
|||||||
import os
|
|
||||||
import gc
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
from PIL import Image
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
import safetensors.torch
|
|
||||||
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.util import instantiate_from_config, ismap
|
|
||||||
from modules import shared, sd_hijack, devices
|
|
||||||
|
|
||||||
# Module-level cache for the loaded LDSR network. Stays None until
# LDSR.load_model_from_config stores a model here, which it only does when
# the "ldsr_cached" option is enabled.
cached_ldsr_model: torch.nn.Module = None
|
|
||||||
|
|
||||||
|
|
||||||
# Create LDSR Class
|
|
||||||
class LDSR:
    """Wrapper that loads an LDSR checkpoint and upscales PIL images with it.

    The instance only remembers where the checkpoint and its YAML config
    live; the network itself is (re)loaded on every ``super_resolution``
    call unless the ``ldsr_cached`` option keeps it in the module-level
    ``cached_ldsr_model``.
    """

    def load_model_from_config(self, half_attention):
        """Return ``{"model": model}`` with the network ready for sampling.

        Args:
            half_attention: when truthy, a freshly loaded model is converted
                with ``model.half()``. Has no effect on a model served from
                the cache.
        """
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print("Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            _, extension = os.path.splitext(self.modelPath)
            if extension.lower() == ".safetensors":
                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
            else:
                # NOTE(security): torch.load unpickles the file, so a
                # non-safetensors checkpoint must come from a trusted source.
                pl_sd = torch.load(self.modelPath, map_location="cpu")
            # Some checkpoints nest the weights under "state_dict".
            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
            config = OmegaConf.load(self.yamlPath)
            # Force the V1 latent-diffusion class regardless of the YAML target.
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model)  # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}

    def __init__(self, model_path, yaml_path):
        """Remember the checkpoint path and the YAML config path."""
        self.modelPath = model_path
        self.yamlPath = yaml_path

    @staticmethod
    def _pad_amounts(size, multiple=64, min_blocks=2):
        """Return ``(pad_w, pad_h)`` needed to grow *size* to the sampling grid.

        Each dimension is rounded up to a multiple of *multiple* and to at
        least ``min_blocks * multiple`` pixels.

        Args:
            size: ``(width, height)`` in pixels (integers).
        """
        def pad_for(dim):
            blocks = max(min_blocks, -(-dim // multiple))  # integer ceil-div
            return blocks * multiple - dim

        width, height = size
        return pad_for(width), pad_for(height)

    @staticmethod
    def run(model, selected_path, custom_steps, eta):
        """Sample one upscaled image and return the sampler's log dict.

        Args:
            model: the loaded diffusion model.
            selected_path: PIL image to upscale (despite the name, it is
                handed straight to ``get_cond``, which calls ``.convert``).
            custom_steps: number of DDIM steps.
            eta: DDIM eta forwarded to the sampler.
        """
        example = get_cond(selected_path)

        # Patch-wise (split) processing is enabled only when the upsampled
        # conditioning image is at least one 128x128 kernel in both axes.
        height, width = example["image"].shape[1:3]
        if height >= 128 and width >= 128:
            ks = 128
            stride = 64
            vqf = 4
            model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                        "vqf": vqf,
                                        "patch_distributed_vq": True,
                                        "tie_braker": False,
                                        "clip_max_weight": 0.5,
                                        "clip_min_weight": 0.01,
                                        "clip_max_tie_weight": 0.5,
                                        "clip_min_tie_weight": 0.01}
        elif hasattr(model, "split_input_params"):
            # A cached model may still carry the params from a larger image.
            delattr(model, "split_input_params")

        return make_convolutional_sample(example, model,
                                         custom_steps=custom_steps,
                                         eta=eta, quantize_x0=False,
                                         custom_shape=None,
                                         temperature=1., noise_dropout=0.,
                                         corrector=None, corrector_kwargs=None, x_T=None,
                                         ddim_use_x0_pred=False)

    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
        """Upscale *image* and return the result as a PIL image.

        The model output is cropped to 4x the size of the image that is fed
        to it, so the input is first resized by ``target_scale / 4`` to end
        up at roughly ``target_scale`` times the original size.

        Args:
            image: source PIL image.
            steps: number of DDIM steps (converted with ``int``).
            target_scale: desired overall magnification.
            half_attention: forwarded to ``load_model_from_config``.
        """
        model = self.load_model_from_config(half_attention)

        # Run settings
        diffusion_steps = int(steps)
        eta = 1.0

        gc.collect()
        devices.torch_gc()

        im_og = image
        width_og, height_og = im_og.size
        # If we can adjust the max upscale size, then the 4 below should be our variable
        down_sample_rate = target_scale / 4
        width_downsampled_pre = int(np.ceil(width_og * down_sample_rate))
        height_downsampled_pre = int(np.ceil(height_og * down_sample_rate))

        if down_sample_rate != 1:
            print(
                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
        else:
            print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")

        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
        pad_w, pad_h = self._pad_amounts(im_og.size)
        # Convert to RGB first: single-channel/palette images yield a 2-D
        # array, which the 3-axis pad spec below would reject. get_cond
        # converts to RGB anyway, so RGB/RGBA inputs are unaffected.
        pixels = np.array(im_og.convert("RGB"))
        im_padded = Image.fromarray(np.pad(pixels, ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))

        logs = self.run(model["model"], im_padded, diffusion_steps, eta)

        # Map the sample from [-1, 1] to 8-bit and move channels last.
        sample = logs["sample"]
        sample = sample.detach().cpu()
        sample = torch.clamp(sample, -1., 1.)
        sample = (sample + 1.) / 2. * 255
        sample = sample.numpy().astype(np.uint8)
        sample = np.transpose(sample, (0, 2, 3, 1))
        a = Image.fromarray(sample[0])

        # remove padding
        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))

        del model
        gc.collect()
        devices.torch_gc()

        return a
|
|
||||||
|
|
||||||
|
|
||||||
def get_cond(selected_path):
    """Build the conditioning batch for the sampler from a PIL image.

    Args:
        selected_path: an image object providing ``convert('RGB')``
            (despite the name it is not a filesystem path).

    Returns:
        dict with two channel-last tensors carrying a leading batch axis of 1:
        ``"LR_image"`` - the input rescaled with ``2 * x - 1`` and moved to
        ``shared.device``; ``"image"`` - the input resized 4x with
        antialiasing (left unscaled and not moved to the device here).
    """
    upscale = 4
    low_res = torchvision.transforms.ToTensor()(selected_path.convert('RGB'))
    low_res = torch.unsqueeze(low_res, 0)

    target_size = [upscale * low_res.shape[2], upscale * low_res.shape[3]]
    high_res = torchvision.transforms.functional.resize(low_res, size=target_size,
                                                        antialias=True)

    # Both tensors go from channel-first to channel-last.
    high_res = rearrange(high_res, '1 c h w -> 1 h w c')
    low_res = rearrange(low_res, '1 c h w -> 1 h w c')
    low_res = (2. * low_res - 1.).to(shared.device)

    return {"LR_image": low_res, "image": high_res}
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
                    corrector_kwargs=None, x_t=None
                    ):
    """Run a ``DDIMSampler`` over *model* and return ``(samples, intermediates)``.

    Args:
        shape: full sample shape; its first entry is used as the batch size
            and the remainder as the per-sample shape.
        cond: conditioning passed to the sampler as ``conditioning``.
        steps: number of sampling steps.

    All remaining arguments are forwarded to ``DDIMSampler.sample`` unchanged.
    """
    sampler = DDIMSampler(model)
    batch_size, sample_shape = shape[0], shape[1:]
    print(f"Sampling with eta = {eta}; steps: {steps}")

    sampler_kwargs = dict(
        batch_size=batch_size, shape=sample_shape, conditioning=cond, callback=callback,
        normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
        mask=mask, x0=x0, temperature=temperature, verbose=False,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs, x_t=x_t,
    )
    samples, intermediates = sampler.sample(steps, **sampler_kwargs)

    return samples, intermediates
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
    """Sample from *model* conditioned on *batch* and collect the results.

    Returns a dict with the keys ``"input"``, ``"reconstruction"``,
    ``"original_conditioning"``, ``"sample"`` (decoded result) and ``"time"``
    (seconds spent in the sampler). Depending on the model it may also hold
    an entry under ``model.cond_stage_key`` and, if decoding without
    quantisation works, ``"sample_noquant"`` / ``"sample_diff"``.

    ``noise_dropout`` is accepted but not used by this function.
    """
    result = {}

    # Conditioning is force-encoded unless the model is in split-input mode
    # with bounding-box coordinates as its conditioning key.
    skip_c_encode = hasattr(model, 'split_input_params') and model.cond_stage_key == 'coordinates_bbox'
    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not skip_c_encode,
                                        return_original_cond=True)

    if custom_shape is not None:
        # Only the shape of z is used below; this replaces it with a stand-in.
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    result["input"] = x
    result["reconstruction"] = xrec

    if ismap(xc):
        result["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            result[model.cond_stage_key] = model.to_rgb(xc)
    else:
        result["original_conditioning"] = torch.zeros_like(x) if xc is None else xc
        if model.cond_stage_model:
            result[model.cond_stage_key] = torch.zeros_like(x) if xc is None else xc
            if model.cond_stage_key == 'class_label':
                result[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        started = time.time()
        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, mask=None, x0=None,
                                                temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
                                                x_t=x_T)
        finished = time.time()

        if ddim_use_x0_pred:
            sample = intermediates['pred_x0'][-1]

    decoded = model.decode_first_stage(sample)

    # Best effort: the un-quantised decode is optional extra output, so any
    # failure here is deliberately ignored.
    try:
        decoded_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        result["sample_noquant"] = decoded_noquant
        result["sample_diff"] = torch.abs(decoded_noquant - decoded)
    except Exception:
        pass

    result["sample"] = decoded
    result["time"] = finished - started

    return result
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
import os
|
|
||||||
from modules import paths
|
|
||||||
|
|
||||||
|
|
||||||
def preload(parser):
    """Register the ``--ldsr-models-path`` command-line option on *parser*.

    The default is the ``LDSR`` sub-directory of ``paths.models_path``.
    """
    default_dir = os.path.join(paths.models_path, 'LDSR')
    parser.add_argument(
        "--ldsr-models-path",
        type=str,
        help="Path to directory with LDSR model file(s).",
        default=default_dir,
    )
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
from modules.modelloader import load_file_from_url
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
|
||||||
from modules_forge.utils import prepare_free_memory
|
|
||||||
from ldsr_model_arch import LDSR
|
|
||||||
from modules import shared, script_callbacks, errors
|
|
||||||
import sd_hijack_autoencoder # noqa: F401
|
|
||||||
import sd_hijack_ddpm_v1 # noqa: F401
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(Upscaler):
    """Upscaler entry that runs images through the LDSR model.

    ``model_path``, ``model_download_path``, ``scale`` and ``find_models``
    are not defined here; presumably they come from the ``Upscaler`` base
    class - confirm there.
    """

    def __init__(self, user_path):
        """Set the name and download URLs, then register the single scaler."""
        # These attributes are assigned before the base initialiser runs.
        self.name = "LDSR"
        self.user_path = user_path
        self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
        self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
        super().__init__()
        scaler_data = UpscalerData("LDSR", None, self)
        self.scalers = [scaler_data]

    def load_model(self, path: str):
        """Locate (or download) the checkpoint and YAML and return an ``LDSR``.

        A local ``model.safetensors`` wins over ``model.ckpt``; anything
        missing is fetched with ``load_file_from_url``. *path* is unused.
        """
        yaml_path = os.path.join(self.model_path, "project.yaml")
        old_model_path = os.path.join(self.model_path, "model.pth")
        new_model_path = os.path.join(self.model_path, "model.ckpt")

        # Housekeeping runs BEFORE discovery so the lookups below see the
        # directory in its final state.

        # Remove incorrect project.yaml file if too big
        if os.path.exists(yaml_path):
            statinfo = os.stat(yaml_path)
            if statinfo.st_size >= 10485760:
                print("Removing invalid LDSR YAML file.")
                os.remove(yaml_path)

        if os.path.exists(old_model_path):
            print("Renaming model from model.pth to model.ckpt")
            os.rename(old_model_path, new_model_path)

        local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
        local_ckpt_path = next((p for p in local_model_paths if p.endswith("model.ckpt")), None)
        local_safetensors_path = next((p for p in local_model_paths if p.endswith("model.safetensors")), None)
        local_yaml_path = next((p for p in local_model_paths if p.endswith("project.yaml")), None)
        # The extension filter above does not include YAML files, so that
        # lookup is not expected to match; fall back to the known location
        # instead of always going through the downloader.
        if local_yaml_path is None and os.path.exists(yaml_path):
            local_yaml_path = yaml_path

        if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
            model = local_safetensors_path
        else:
            model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")

        yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")

        return LDSR(model, yaml)

    def do_upscale(self, img, path):
        """Upscale *img*; return it unchanged if the model cannot be loaded."""
        prepare_free_memory(aggressive=True)
        try:
            ldsr = self.load_model(path)
        except Exception:
            # Loading failures are reported, not raised, so the pipeline
            # continues with the original image.
            errors.report(f"Failed loading LDSR model {path}", exc_info=True)
            return img
        ddim_steps = shared.opts.ldsr_steps
        return ldsr.super_resolution(img, ddim_steps, self.scale)
|
|
||||||
|
|
||||||
|
|
||||||
def on_ui_settings():
    """Add the two LDSR options (step count, model caching) to ``shared.opts``."""
    import gradio as gr

    section = ('upscaling', "Upscaling")

    steps_option = shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=section)
    shared.opts.add_option("ldsr_steps", steps_option)

    cache_option = shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=section)
    shared.opts.add_option("ldsr_cached", cache_option)
|
|
||||||
|
|
||||||
|
|
||||||
# Hand on_ui_settings to script_callbacks so the LDSR options get registered
# (presumably invoked when the settings UI is built - confirm in script_callbacks).
script_callbacks.on_ui_settings(on_ui_settings)
|
|
||||||
@@ -1,293 +0,0 @@
|
|||||||
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
|
||||||
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
|
||||||
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
|
||||||
|
|
||||||
from ldm.modules.ema import LitEma
|
|
||||||
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
|
|
||||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
|
||||||
from ldm.util import instantiate_from_config
|
|
||||||
|
|
||||||
import ldm.models.autoencoder
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
class VQModel(pl.LightningModule):
    """Vector-quantised autoencoder (encoder, codebook, decoder).

    Restored here because the LDSR upscaler depends on it; the module
    re-publishes the class on ``ldm.models.autoencoder``.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        """Build the sub-modules and optionally restore weights.

        Args:
            ddconfig: keyword arguments for ``Encoder``/``Decoder``; must
                contain ``"z_channels"``.
            lossconfig: config passed to ``instantiate_from_config`` for the loss.
            n_embed: number of codebook entries.
            embed_dim: dimensionality of each codebook entry.
            ckpt_path: if given, weights are loaded via ``init_from_ckpt``.
            ignore_keys: state-dict key prefixes to drop when loading.
            image_key: batch key read by ``get_input``.
            colorize_nlabels: if given, size of the random ``colorize`` buffer.
            monitor: stored as ``self.monitor`` only when not ``None``.
            batch_resize_range: ``(lower, upper)`` for per-batch resizing.
            scheduler_config: optional LR-scheduler config.
            lr_g_factor: multiplier applied to the autoencoder learning rate.
            remap, sane_index_shape: forwarded to the vector quantiser.
            use_ema: keep an EMA copy of the weights.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        # 1x1 convolutions mapping between encoder channels and codebook dim.
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in the EMA weights (no-op when ``use_ema`` is off).

        The stored weights are restored on exit even if the body raises.
        *context*, when given, is used as a prefix for progress messages.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load weights from the ``"state_dict"`` entry of the checkpoint at *path*.

        Keys starting with any prefix in *ignore_keys* are dropped first.
        Loading is non-strict; missing/unexpected keys are printed.
        """
        # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys or []:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
                    # Stop after the first match: a second matching prefix
                    # would otherwise delete the same key again (KeyError).
                    break
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if missing:
            print(f"Missing Keys: {missing}")
        if unexpected:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after each training batch (if enabled)."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode and quantise *x*; return ``(quant, emb_loss, info)``."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode *x* up to, but not including, quantisation."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode a quantised latent."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Look up codebook entries for *code_b* and decode them."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Autoencode *input*.

        Returns ``(dec, diff)`` or, with *return_pred_indices*,
        ``(dec, diff, ind)`` where ``diff`` is the quantiser loss and
        ``ind`` the selected codebook indices.
        """
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a float, channel-first, contiguous tensor.

        A 3-D input gets a trailing channel axis added before the permute.
        When ``batch_resize_range`` is set the batch is bicubically resized.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the autoencoder loss (optimizer 0) or discriminator loss (optimizer 1)."""
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Validate with the current weights, then again inside ``ema_scope``."""
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Compute and log both losses under the split name ``"val" + suffix``."""
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            # rec_loss was already logged explicitly above; drop it so
            # log_dict does not log the same key a second time.
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # NOTE(review): this returns the bound ``log_dict`` method, not a
        # dict. Kept as-is to preserve existing behaviour - confirm intent.
        return self.log_dict

    def configure_optimizers(self):
        """Create Adam optimizers for the autoencoder and the discriminator.

        Reads ``self.learning_rate``, which is not set in ``__init__`` and
        must be assigned before this is called. Returns
        ``([opt_ae, opt_disc], schedulers)`` with an empty scheduler list
        when no ``scheduler_config`` was given.
        """
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return a dict of inputs and reconstructions for visual logging.

        Inputs with more than three channels are projected to RGB with
        ``to_rgb``. With *plot_ema*, an EMA reconstruction is added too.
        """
        log = {}
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project *x* to 3 channels with the ``colorize`` buffer, scaled to [-1, 1].

        Only valid when ``image_key == "segmentation"``. The buffer is
        created with random weights on first use if it does not exist.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        # Min-max normalisation (a constant input would divide by zero).
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
|
|
||||||
|
|
||||||
|
|
||||||
class VQModelInterface(VQModel):
    """``VQModel`` variant whose ``encode`` stops before quantisation.

    Quantisation is instead applied inside ``decode`` unless the caller
    opts out.
    """

    def __init__(self, embed_dim, *args, **kwargs):
        """Forward everything to ``VQModel`` with *embed_dim* as a keyword."""
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Return the un-quantised latent for *x*."""
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        """Decode latent *h*, quantising it first unless *force_not_quantize*."""
        if force_not_quantize:
            latent = h
        else:
            # also go through quantization layer
            latent, _emb_loss, _info = self.quantize(h)
        return self.decoder(self.post_quant_conv(latent))
|
|
||||||
|
|
||||||
# Put the two classes back onto ldm.models.autoencoder so code that looks
# them up there can find them.
setattr(ldm.models.autoencoder, "VQModel", VQModel)
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,147 +0,0 @@
|
|||||||
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
|
|
||||||
# where the license is as follows:
|
|
||||||
#
|
|
||||||
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
# of this software and associated documentation files (the "Software"), to deal
|
|
||||||
# in the Software without restriction, including without limitation the rights
|
|
||||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
# copies of the Software, and to permit persons to whom the Software is
|
|
||||||
# furnished to do so, subject to the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be included in all
|
|
||||||
# copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
||||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
|
||||||
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
|
||||||
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
|
||||||
# OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
|
|
||||||
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        """Create the codebook.

        Args:
            n_e: number of codebook entries.
            e_dim: dimensionality of each entry.
            beta: weight of one of the two loss terms (which one depends on
                *legacy*, see ``forward``).
            remap: optional path to a ``.npy`` file listing the codebook
                indices that are in use; enables index remapping.
            unknown_index: what to map unused indices to - ``"random"``,
                ``"extra"`` (one additional index) or an integer.
            sane_index_shape: return indices shaped (batch, height, width).
            legacy: keep the historical placement of *beta*.
        """
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """Map full-codebook indices to positions within ``self.used``.

        Indices absent from ``used`` become ``unknown_index`` (or a random
        position when it is ``"random"``). *inds* needs at least 2 dims.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Map positions within ``self.used`` back to full-codebook indices.

        *inds* needs at least 2 dims and is left unmodified.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # Simply map the extra token to position zero. torch.where builds
            # a new tensor; writing into `inds` would alter the caller's
            # tensor, because reshape can return a view.
            inds = torch.where(inds >= self.used.shape[0], torch.zeros_like(inds), inds)
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantise *z* (batch, channel, height, width) to nearest codebook entries.

        Returns:
            ``(z_q, loss, (None, None, indices))`` where ``z_q`` has the
            shape of *z* and passes gradients straight through to *z*.
            ``temp``, ``rescale_logits`` and ``return_logits`` exist only for
            interface compatibility and must keep their neutral values.
        """
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, self.embedding.weight.t())

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """Return the codebook vectors for *indices*.

        Args:
            indices: codebook indices (positions within ``used`` when
                remapping is active).
            shape: ``(batch, height, width, channel)`` to reshape to, giving
                a channel-first result; ``None`` returns the flat lookup.
        """
        if self.remap is not None:
            # unmap_to_all needs a batch axis; the lookup itself is
            # element-wise, so a single row is fine when no shape is given.
            batch = shape[0] if shape is not None else 1
            indices = indices.reshape(batch, -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
|
|
||||||
@@ -11,9 +11,12 @@ from transformers.models.clip.modeling_clip import CLIPVisionModelOutput
|
|||||||
from annotator.util import HWC3
|
from annotator.util import HWC3
|
||||||
from typing import Callable, Tuple, Union
|
from typing import Callable, Tuple, Union
|
||||||
|
|
||||||
from modules.safe import Extra
|
|
||||||
from modules import devices
|
from modules import devices
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
Extra = lambda x: contextlib.nullcontext()
|
||||||
|
|
||||||
|
|
||||||
def torch_handler(module: str, name: str):
|
def torch_handler(module: str, name: str):
|
||||||
""" Allow all torch access. Bypass A1111 safety whitelist. """
|
""" Allow all torch access. Bypass A1111 safety whitelist. """
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import cv2
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List
|
from typing import Any, Callable, Dict, List
|
||||||
from modules.safe import unsafe_torch_load
|
|
||||||
from lib_controlnet.logging import logger
|
from lib_controlnet.logging import logger
|
||||||
|
|
||||||
|
|
||||||
@@ -28,7 +27,7 @@ def load_state_dict(ckpt_path, location="cpu"):
|
|||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
|
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
|
||||||
else:
|
else:
|
||||||
state_dict = unsafe_torch_load(ckpt_path, map_location=torch.device(location))
|
state_dict = torch.load(ckpt_path, map_location=torch.device(location))
|
||||||
state_dict = get_state_dict(state_dict)
|
state_dict = get_state_dict(state_dict)
|
||||||
logger.info(f"Loaded state_dict from [{ckpt_path}]")
|
logger.info(f"Loaded state_dict from [{ckpt_path}]")
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
|
|||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin
|
from PIL import PngImagePlugin
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -725,7 +724,7 @@ class Api:
|
|||||||
|
|
||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
import modules.sd_models as sd_models
|
import modules.sd_models as sd_models
|
||||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
|
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()]
|
||||||
|
|
||||||
def get_sd_vaes(self):
|
def get_sd_vaes(self):
|
||||||
import modules.sd_vae as sd_vae
|
import modules.sd_vae as sd_vae
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import modules.textual_inversion.dataset
|
|||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.util import default
|
from backend.nn.unet import default
|
||||||
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||||
from modules.textual_inversion import textual_inversion, saving_settings
|
from modules.textual_inversion import textual_inversion, saving_settings
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|||||||
@@ -1,25 +1,12 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
from modules.timer import startup_timer
|
from modules.timer import startup_timer
|
||||||
|
|
||||||
|
|
||||||
class HiddenPrints:
|
|
||||||
def __enter__(self):
|
|
||||||
self._original_stdout = sys.stdout
|
|
||||||
sys.stdout = open(os.devnull, 'w')
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
sys.stdout.close()
|
|
||||||
sys.stdout = self._original_stdout
|
|
||||||
|
|
||||||
|
|
||||||
def imports():
|
def imports():
|
||||||
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
@@ -35,16 +22,8 @@ def imports():
|
|||||||
import gradio # noqa: F401
|
import gradio # noqa: F401
|
||||||
startup_timer.record("import gradio")
|
startup_timer.record("import gradio")
|
||||||
|
|
||||||
with HiddenPrints():
|
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
startup_timer.record("setup paths")
|
||||||
startup_timer.record("setup paths")
|
|
||||||
|
|
||||||
import ldm.modules.encoders.modules # noqa: F401
|
|
||||||
import ldm.modules.diffusionmodules.model
|
|
||||||
startup_timer.record("import ldm")
|
|
||||||
|
|
||||||
import sgm.modules.encoders.modules # noqa: F401
|
|
||||||
startup_timer.record("import sgm")
|
|
||||||
|
|
||||||
from modules import shared_init
|
from modules import shared_init
|
||||||
shared_init.initialize()
|
shared_init.initialize()
|
||||||
@@ -137,15 +116,6 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
sd_vae.refresh_vae_list()
|
sd_vae.refresh_vae_list()
|
||||||
startup_timer.record("refresh VAE")
|
startup_timer.record("refresh VAE")
|
||||||
|
|
||||||
from modules import textual_inversion
|
|
||||||
textual_inversion.textual_inversion.list_textual_inversion_templates()
|
|
||||||
startup_timer.record("refresh textual inversion templates")
|
|
||||||
|
|
||||||
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
|
|
||||||
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
|
|
||||||
sd_hijack.list_optimizers()
|
|
||||||
startup_timer.record("scripts list_optimizers")
|
|
||||||
|
|
||||||
from modules import sd_unet
|
from modules import sd_unet
|
||||||
sd_unet.list_unets()
|
sd_unet.list_unets()
|
||||||
startup_timer.record("scripts list_unets")
|
startup_timer.record("scripts list_unets")
|
||||||
|
|||||||
@@ -391,15 +391,15 @@ def prepare_environment():
|
|||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
# stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
# stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||||
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
|
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
|
||||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||||
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
|
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
@@ -456,8 +456,8 @@ def prepare_environment():
|
|||||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||||
|
|
||||||
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
||||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
# git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
# git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
|
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
|
||||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
|||||||
@@ -2,45 +2,15 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
|
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
|
||||||
|
|
||||||
import modules.safe # noqa: F401
|
|
||||||
|
|
||||||
|
|
||||||
def mute_sdxl_imports():
|
|
||||||
"""create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
|
|
||||||
|
|
||||||
class Dummy:
|
|
||||||
pass
|
|
||||||
|
|
||||||
module = Dummy()
|
|
||||||
module.LPIPS = None
|
|
||||||
sys.modules['taming.modules.losses.lpips'] = module
|
|
||||||
|
|
||||||
module = Dummy()
|
|
||||||
module.StableDataModuleFromConfig = None
|
|
||||||
sys.modules['sgm.data'] = module
|
|
||||||
|
|
||||||
|
|
||||||
# data_path = cmd_opts_pre.data
|
|
||||||
sys.path.insert(0, script_path)
|
sys.path.insert(0, script_path)
|
||||||
|
|
||||||
# search for directory of stable diffusion in following places
|
sd_path = os.path.dirname(__file__)
|
||||||
sd_path = None
|
|
||||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
|
|
||||||
for possible_sd_path in possible_sd_paths:
|
|
||||||
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
|
||||||
sd_path = os.path.abspath(possible_sd_path)
|
|
||||||
break
|
|
||||||
|
|
||||||
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
|
|
||||||
|
|
||||||
mute_sdxl_imports()
|
|
||||||
|
|
||||||
path_dirs = [
|
path_dirs = [
|
||||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
(os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []),
|
||||||
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
(os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
(os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
|
||||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
|
||||||
(os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
paths = {}
|
paths = {}
|
||||||
@@ -53,13 +23,6 @@ for d, must_exist, what, options in path_dirs:
|
|||||||
d = os.path.abspath(d)
|
d = os.path.abspath(d)
|
||||||
if "atstart" in options:
|
if "atstart" in options:
|
||||||
sys.path.insert(0, d)
|
sys.path.insert(0, d)
|
||||||
elif "sgm" in options:
|
|
||||||
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
|
|
||||||
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
|
|
||||||
|
|
||||||
sys.path.insert(0, d)
|
|
||||||
import sgm # noqa: F401
|
|
||||||
sys.path.pop(0)
|
|
||||||
else:
|
else:
|
||||||
sys.path.append(d)
|
sys.path.append(d)
|
||||||
paths[what] = d
|
paths[what] = d
|
||||||
|
|||||||
@@ -28,8 +28,6 @@ import modules.images as images
|
|||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.sd_models as sd_models
|
import modules.sd_models as sd_models
|
||||||
import modules.sd_vae as sd_vae
|
import modules.sd_vae as sd_vae
|
||||||
from ldm.data.util import AddMiDaS
|
|
||||||
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
|
||||||
|
|
||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
from blendmodes.blend import blendLayers, BlendType
|
from blendmodes.blend import blendLayers, BlendType
|
||||||
@@ -295,23 +293,7 @@ class StableDiffusionProcessing:
|
|||||||
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
||||||
|
|
||||||
def depth2img_image_conditioning(self, source_image):
|
def depth2img_image_conditioning(self, source_image):
|
||||||
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
raise NotImplementedError('NotImplementedError: depth2img_image_conditioning')
|
||||||
transformer = AddMiDaS(model_type="dpt_hybrid")
|
|
||||||
transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
|
|
||||||
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
|
||||||
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
|
||||||
|
|
||||||
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
|
|
||||||
conditioning = torch.nn.functional.interpolate(
|
|
||||||
self.sd_model.depth_model(midas_in),
|
|
||||||
size=conditioning_image.shape[2:],
|
|
||||||
mode="bicubic",
|
|
||||||
align_corners=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
(depth_min, depth_max) = torch.aminmax(conditioning)
|
|
||||||
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
|
||||||
return conditioning
|
|
||||||
|
|
||||||
def edit_image_conditioning(self, source_image):
|
def edit_image_conditioning(self, source_image):
|
||||||
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
|
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
|
||||||
@@ -368,11 +350,6 @@ class StableDiffusionProcessing:
|
|||||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
|
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
|
||||||
source_image = devices.cond_cast_float(source_image)
|
source_image = devices.cond_cast_float(source_image)
|
||||||
|
|
||||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
|
||||||
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
|
||||||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
|
||||||
return self.depth2img_image_conditioning(source_image)
|
|
||||||
|
|
||||||
if self.sd_model.cond_stage_key == "edit":
|
if self.sd_model.cond_stage_key == "edit":
|
||||||
return self.edit_image_conditioning(source_image)
|
return self.edit_image_conditioning(source_image)
|
||||||
|
|
||||||
|
|||||||
390
modules/safe.py
390
modules/safe.py
@@ -1,195 +1,195 @@
|
|||||||
# this code is adapted from the script contributed by anon from /h/
|
# # this code is adapted from the script contributed by anon from /h/
|
||||||
|
#
|
||||||
import pickle
|
# import pickle
|
||||||
import collections
|
# import collections
|
||||||
|
#
|
||||||
import torch
|
# import torch
|
||||||
import numpy
|
# import numpy
|
||||||
import _codecs
|
# import _codecs
|
||||||
import zipfile
|
# import zipfile
|
||||||
import re
|
# import re
|
||||||
|
#
|
||||||
|
#
|
||||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
# # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||||
from modules import errors
|
# from modules import errors
|
||||||
|
#
|
||||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
# TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||||
|
#
|
||||||
def encode(*args):
|
# def encode(*args):
|
||||||
out = _codecs.encode(*args)
|
# out = _codecs.encode(*args)
|
||||||
return out
|
# return out
|
||||||
|
#
|
||||||
|
#
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
# class RestrictedUnpickler(pickle.Unpickler):
|
||||||
extra_handler = None
|
# extra_handler = None
|
||||||
|
#
|
||||||
def persistent_load(self, saved_id):
|
# def persistent_load(self, saved_id):
|
||||||
assert saved_id[0] == 'storage'
|
# assert saved_id[0] == 'storage'
|
||||||
|
#
|
||||||
try:
|
# try:
|
||||||
return TypedStorage(_internal=True)
|
# return TypedStorage(_internal=True)
|
||||||
except TypeError:
|
# except TypeError:
|
||||||
return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
|
# return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
|
||||||
|
#
|
||||||
def find_class(self, module, name):
|
# def find_class(self, module, name):
|
||||||
if self.extra_handler is not None:
|
# if self.extra_handler is not None:
|
||||||
res = self.extra_handler(module, name)
|
# res = self.extra_handler(module, name)
|
||||||
if res is not None:
|
# if res is not None:
|
||||||
return res
|
# return res
|
||||||
|
#
|
||||||
if module == 'collections' and name == 'OrderedDict':
|
# if module == 'collections' and name == 'OrderedDict':
|
||||||
return getattr(collections, name)
|
# return getattr(collections, name)
|
||||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
# if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
||||||
return getattr(torch._utils, name)
|
# return getattr(torch._utils, name)
|
||||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
|
# if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
|
||||||
return getattr(torch, name)
|
# return getattr(torch, name)
|
||||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
# if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||||
return getattr(torch.nn.modules.container, name)
|
# return getattr(torch.nn.modules.container, name)
|
||||||
if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
|
# if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
|
||||||
return getattr(numpy.core.multiarray, name)
|
# return getattr(numpy.core.multiarray, name)
|
||||||
if module == 'numpy' and name in ['dtype', 'ndarray']:
|
# if module == 'numpy' and name in ['dtype', 'ndarray']:
|
||||||
return getattr(numpy, name)
|
# return getattr(numpy, name)
|
||||||
if module == '_codecs' and name == 'encode':
|
# if module == '_codecs' and name == 'encode':
|
||||||
return encode
|
# return encode
|
||||||
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
|
# if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
|
||||||
import pytorch_lightning.callbacks
|
# import pytorch_lightning.callbacks
|
||||||
return pytorch_lightning.callbacks.model_checkpoint
|
# return pytorch_lightning.callbacks.model_checkpoint
|
||||||
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
|
# if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
|
||||||
import pytorch_lightning.callbacks.model_checkpoint
|
# import pytorch_lightning.callbacks.model_checkpoint
|
||||||
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
|
# return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
|
||||||
if module == "__builtin__" and name == 'set':
|
# if module == "__builtin__" and name == 'set':
|
||||||
return set
|
# return set
|
||||||
|
#
|
||||||
# Forbid everything else.
|
# # Forbid everything else.
|
||||||
raise Exception(f"global '{module}/{name}' is forbidden")
|
# raise Exception(f"global '{module}/{name}' is forbidden")
|
||||||
|
#
|
||||||
|
#
|
||||||
# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
|
# # Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
|
||||||
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
|
# allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
|
||||||
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
# data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
||||||
|
#
|
||||||
def check_zip_filenames(filename, names):
|
# def check_zip_filenames(filename, names):
|
||||||
for name in names:
|
# for name in names:
|
||||||
if allowed_zip_names_re.match(name):
|
# if allowed_zip_names_re.match(name):
|
||||||
continue
|
# continue
|
||||||
|
#
|
||||||
raise Exception(f"bad file inside {filename}: {name}")
|
# raise Exception(f"bad file inside {filename}: {name}")
|
||||||
|
#
|
||||||
|
#
|
||||||
def check_pt(filename, extra_handler):
|
# def check_pt(filename, extra_handler):
|
||||||
try:
|
# try:
|
||||||
|
#
|
||||||
# new pytorch format is a zip file
|
# # new pytorch format is a zip file
|
||||||
with zipfile.ZipFile(filename) as z:
|
# with zipfile.ZipFile(filename) as z:
|
||||||
check_zip_filenames(filename, z.namelist())
|
# check_zip_filenames(filename, z.namelist())
|
||||||
|
#
|
||||||
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
# # find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||||
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
# data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||||
if len(data_pkl_filenames) == 0:
|
# if len(data_pkl_filenames) == 0:
|
||||||
raise Exception(f"data.pkl not found in {filename}")
|
# raise Exception(f"data.pkl not found in {filename}")
|
||||||
if len(data_pkl_filenames) > 1:
|
# if len(data_pkl_filenames) > 1:
|
||||||
raise Exception(f"Multiple data.pkl found in {filename}")
|
# raise Exception(f"Multiple data.pkl found in {filename}")
|
||||||
with z.open(data_pkl_filenames[0]) as file:
|
# with z.open(data_pkl_filenames[0]) as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
# unpickler = RestrictedUnpickler(file)
|
||||||
unpickler.extra_handler = extra_handler
|
# unpickler.extra_handler = extra_handler
|
||||||
unpickler.load()
|
# unpickler.load()
|
||||||
|
#
|
||||||
except zipfile.BadZipfile:
|
# except zipfile.BadZipfile:
|
||||||
|
#
|
||||||
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
|
# # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
|
||||||
with open(filename, "rb") as file:
|
# with open(filename, "rb") as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
# unpickler = RestrictedUnpickler(file)
|
||||||
unpickler.extra_handler = extra_handler
|
# unpickler.extra_handler = extra_handler
|
||||||
for _ in range(5):
|
# for _ in range(5):
|
||||||
unpickler.load()
|
# unpickler.load()
|
||||||
|
#
|
||||||
|
#
|
||||||
def load(filename, *args, **kwargs):
|
# def load(filename, *args, **kwargs):
|
||||||
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
# return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
||||||
|
#
|
||||||
|
#
|
||||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
# def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||||
"""
|
# """
|
||||||
this function is intended to be used by extensions that want to load models with
|
# this function is intended to be used by extensions that want to load models with
|
||||||
some extra classes in them that the usual unpickler would find suspicious.
|
# some extra classes in them that the usual unpickler would find suspicious.
|
||||||
|
#
|
||||||
Use the extra_handler argument to specify a function that takes module and field name as text,
|
# Use the extra_handler argument to specify a function that takes module and field name as text,
|
||||||
and returns that field's value:
|
# and returns that field's value:
|
||||||
|
#
|
||||||
```python
|
# ```python
|
||||||
def extra(module, name):
|
# def extra(module, name):
|
||||||
if module == 'collections' and name == 'OrderedDict':
|
# if module == 'collections' and name == 'OrderedDict':
|
||||||
return collections.OrderedDict
|
# return collections.OrderedDict
|
||||||
|
#
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
safe.load_with_extra('model.pt', extra_handler=extra)
|
# safe.load_with_extra('model.pt', extra_handler=extra)
|
||||||
```
|
# ```
|
||||||
|
#
|
||||||
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
|
# The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
|
||||||
definitely unsafe.
|
# definitely unsafe.
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
from modules import shared
|
# from modules import shared
|
||||||
|
#
|
||||||
try:
|
# try:
|
||||||
if not shared.cmd_opts.disable_safe_unpickle:
|
# if not shared.cmd_opts.disable_safe_unpickle:
|
||||||
check_pt(filename, extra_handler)
|
# check_pt(filename, extra_handler)
|
||||||
|
#
|
||||||
except pickle.UnpicklingError:
|
# except pickle.UnpicklingError:
|
||||||
errors.report(
|
# errors.report(
|
||||||
f"Error verifying pickled file from {filename}\n"
|
# f"Error verifying pickled file from {filename}\n"
|
||||||
"-----> !!!! The file is most likely corrupted !!!! <-----\n"
|
# "-----> !!!! The file is most likely corrupted !!!! <-----\n"
|
||||||
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
|
# "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
|
||||||
exc_info=True,
|
# exc_info=True,
|
||||||
)
|
# )
|
||||||
return None
|
# return None
|
||||||
except Exception:
|
# except Exception:
|
||||||
errors.report(
|
# errors.report(
|
||||||
f"Error verifying pickled file from {filename}\n"
|
# f"Error verifying pickled file from {filename}\n"
|
||||||
f"The file may be malicious, so the program is not going to read it.\n"
|
# f"The file may be malicious, so the program is not going to read it.\n"
|
||||||
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
|
# f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
|
||||||
exc_info=True,
|
# exc_info=True,
|
||||||
)
|
# )
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
return unsafe_torch_load(filename, *args, **kwargs)
|
# return unsafe_torch_load(filename, *args, **kwargs)
|
||||||
|
#
|
||||||
|
#
|
||||||
class Extra:
|
# class Extra:
|
||||||
"""
|
# """
|
||||||
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
|
# A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
|
||||||
(because it's not your code making the torch.load call). The intended use is like this:
|
# (because it's not your code making the torch.load call). The intended use is like this:
|
||||||
|
#
|
||||||
```
|
# ```
|
||||||
import torch
|
# import torch
|
||||||
from modules import safe
|
# from modules import safe
|
||||||
|
#
|
||||||
def handler(module, name):
|
# def handler(module, name):
|
||||||
if module == 'torch' and name in ['float64', 'float16']:
|
# if module == 'torch' and name in ['float64', 'float16']:
|
||||||
return getattr(torch, name)
|
# return getattr(torch, name)
|
||||||
|
#
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
with safe.Extra(handler):
|
# with safe.Extra(handler):
|
||||||
x = torch.load('model.pt')
|
# x = torch.load('model.pt')
|
||||||
```
|
# ```
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
def __init__(self, handler):
|
# def __init__(self, handler):
|
||||||
self.handler = handler
|
# self.handler = handler
|
||||||
|
#
|
||||||
def __enter__(self):
|
# def __enter__(self):
|
||||||
global global_extra_handler
|
# global global_extra_handler
|
||||||
|
#
|
||||||
assert global_extra_handler is None, 'already inside an Extra() block'
|
# assert global_extra_handler is None, 'already inside an Extra() block'
|
||||||
global_extra_handler = self.handler
|
# global_extra_handler = self.handler
|
||||||
|
#
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
global global_extra_handler
|
# global global_extra_handler
|
||||||
|
#
|
||||||
global_extra_handler = None
|
# global_extra_handler = None
|
||||||
|
#
|
||||||
|
#
|
||||||
unsafe_torch_load = torch.load
|
# unsafe_torch_load = torch.load
|
||||||
global_extra_handler = None
|
# global_extra_handler = None
|
||||||
|
|||||||
@@ -1,232 +1,232 @@
|
|||||||
import ldm.modules.encoders.modules
|
# import ldm.modules.encoders.modules
|
||||||
import open_clip
|
# import open_clip
|
||||||
import torch
|
# import torch
|
||||||
import transformers.utils.hub
|
# import transformers.utils.hub
|
||||||
|
#
|
||||||
from modules import shared
|
# from modules import shared
|
||||||
|
#
|
||||||
|
#
|
||||||
class ReplaceHelper:
|
# class ReplaceHelper:
|
||||||
def __init__(self):
|
# def __init__(self):
|
||||||
self.replaced = []
|
# self.replaced = []
|
||||||
|
#
|
||||||
def replace(self, obj, field, func):
|
# def replace(self, obj, field, func):
|
||||||
original = getattr(obj, field, None)
|
# original = getattr(obj, field, None)
|
||||||
if original is None:
|
# if original is None:
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
self.replaced.append((obj, field, original))
|
# self.replaced.append((obj, field, original))
|
||||||
setattr(obj, field, func)
|
# setattr(obj, field, func)
|
||||||
|
#
|
||||||
return original
|
# return original
|
||||||
|
#
|
||||||
def restore(self):
|
# def restore(self):
|
||||||
for obj, field, original in self.replaced:
|
# for obj, field, original in self.replaced:
|
||||||
setattr(obj, field, original)
|
# setattr(obj, field, original)
|
||||||
|
#
|
||||||
self.replaced.clear()
|
# self.replaced.clear()
|
||||||
|
#
|
||||||
|
#
|
||||||
class DisableInitialization(ReplaceHelper):
|
# class DisableInitialization(ReplaceHelper):
|
||||||
"""
|
# """
|
||||||
When an object of this class enters a `with` block, it starts:
|
# When an object of this class enters a `with` block, it starts:
|
||||||
- preventing torch's layer initialization functions from working
|
# - preventing torch's layer initialization functions from working
|
||||||
- changes CLIP and OpenCLIP to not download model weights
|
# - changes CLIP and OpenCLIP to not download model weights
|
||||||
- changes CLIP to not make requests to check if there is a new version of a file you already have
|
# - changes CLIP to not make requests to check if there is a new version of a file you already have
|
||||||
|
#
|
||||||
When it leaves the block, it reverts everything to how it was before.
|
# When it leaves the block, it reverts everything to how it was before.
|
||||||
|
#
|
||||||
Use it like this:
|
# Use it like this:
|
||||||
```
|
# ```
|
||||||
with DisableInitialization():
|
# with DisableInitialization():
|
||||||
do_things()
|
# do_things()
|
||||||
```
|
# ```
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
def __init__(self, disable_clip=True):
|
# def __init__(self, disable_clip=True):
|
||||||
super().__init__()
|
# super().__init__()
|
||||||
self.disable_clip = disable_clip
|
# self.disable_clip = disable_clip
|
||||||
|
#
|
||||||
def replace(self, obj, field, func):
|
# def replace(self, obj, field, func):
|
||||||
original = getattr(obj, field, None)
|
# original = getattr(obj, field, None)
|
||||||
if original is None:
|
# if original is None:
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
self.replaced.append((obj, field, original))
|
# self.replaced.append((obj, field, original))
|
||||||
setattr(obj, field, func)
|
# setattr(obj, field, func)
|
||||||
|
#
|
||||||
return original
|
# return original
|
||||||
|
#
|
||||||
def __enter__(self):
|
# def __enter__(self):
|
||||||
def do_nothing(*args, **kwargs):
|
# def do_nothing(*args, **kwargs):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
|
# def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
|
||||||
return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
|
# return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
|
||||||
|
#
|
||||||
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
|
# def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
# res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
||||||
res.name_or_path = pretrained_model_name_or_path
|
# res.name_or_path = pretrained_model_name_or_path
|
||||||
return res
|
# return res
|
||||||
|
#
|
||||||
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
|
# def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
|
||||||
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
|
# args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
|
||||||
return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
|
# return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
|
||||||
|
#
|
||||||
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
# def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
||||||
|
#
|
||||||
# this file is always 404, prevent making request
|
# # this file is always 404, prevent making request
|
||||||
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
# if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
try:
|
# try:
|
||||||
res = original(url, *args, local_files_only=True, **kwargs)
|
# res = original(url, *args, local_files_only=True, **kwargs)
|
||||||
if res is None:
|
# if res is None:
|
||||||
res = original(url, *args, local_files_only=False, **kwargs)
|
# res = original(url, *args, local_files_only=False, **kwargs)
|
||||||
return res
|
# return res
|
||||||
except Exception:
|
# except Exception:
|
||||||
return original(url, *args, local_files_only=False, **kwargs)
|
# return original(url, *args, local_files_only=False, **kwargs)
|
||||||
|
#
|
||||||
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
# def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||||
return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
|
# return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
|
||||||
|
#
|
||||||
def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
|
# def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||||
return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
|
# return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
|
||||||
|
#
|
||||||
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
# def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||||
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
# return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
||||||
|
#
|
||||||
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
# self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||||
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
# self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||||
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
# self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||||
|
#
|
||||||
if self.disable_clip:
|
# if self.disable_clip:
|
||||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
# self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
# self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
# self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
# self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
# self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
# self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||||
|
#
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.restore()
|
# self.restore()
|
||||||
|
#
|
||||||
|
#
|
||||||
class InitializeOnMeta(ReplaceHelper):
|
# class InitializeOnMeta(ReplaceHelper):
|
||||||
"""
|
# """
|
||||||
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
# Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
||||||
which results in those parameters having no values and taking no memory. model.to() will be broken and
|
# which results in those parameters having no values and taking no memory. model.to() will be broken and
|
||||||
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
# will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
||||||
|
#
|
||||||
Usage:
|
# Usage:
|
||||||
```
|
# ```
|
||||||
with sd_disable_initialization.InitializeOnMeta():
|
# with sd_disable_initialization.InitializeOnMeta():
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
# sd_model = instantiate_from_config(sd_config.model)
|
||||||
```
|
# ```
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
def __enter__(self):
|
# def __enter__(self):
|
||||||
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
# if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||||
return
|
# return
|
||||||
|
#
|
||||||
def set_device(x):
|
# def set_device(x):
|
||||||
x["device"] = "meta"
|
# x["device"] = "meta"
|
||||||
return x
|
# return x
|
||||||
|
#
|
||||||
linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
# linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
||||||
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
# conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
||||||
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
# mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
||||||
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
# self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
||||||
|
#
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.restore()
|
# self.restore()
|
||||||
|
#
|
||||||
|
#
|
||||||
class LoadStateDictOnMeta(ReplaceHelper):
|
# class LoadStateDictOnMeta(ReplaceHelper):
|
||||||
"""
|
# """
|
||||||
Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
|
# Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
|
||||||
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
# As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
||||||
Meant to be used together with InitializeOnMeta above.
|
# Meant to be used together with InitializeOnMeta above.
|
||||||
|
#
|
||||||
Usage:
|
# Usage:
|
||||||
```
|
# ```
|
||||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
# with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
||||||
model.load_state_dict(state_dict, strict=False)
|
# model.load_state_dict(state_dict, strict=False)
|
||||||
```
|
# ```
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
def __init__(self, state_dict, device, weight_dtype_conversion=None):
|
# def __init__(self, state_dict, device, weight_dtype_conversion=None):
|
||||||
super().__init__()
|
# super().__init__()
|
||||||
self.state_dict = state_dict
|
# self.state_dict = state_dict
|
||||||
self.device = device
|
# self.device = device
|
||||||
self.weight_dtype_conversion = weight_dtype_conversion or {}
|
# self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||||
self.default_dtype = self.weight_dtype_conversion.get('')
|
# self.default_dtype = self.weight_dtype_conversion.get('')
|
||||||
|
#
|
||||||
def get_weight_dtype(self, key):
|
# def get_weight_dtype(self, key):
|
||||||
key_first_term, _ = key.split('.', 1)
|
# key_first_term, _ = key.split('.', 1)
|
||||||
return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
|
# return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
|
||||||
|
#
|
||||||
def __enter__(self):
|
# def __enter__(self):
|
||||||
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
# if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||||
return
|
# return
|
||||||
|
#
|
||||||
sd = self.state_dict
|
# sd = self.state_dict
|
||||||
device = self.device
|
# device = self.device
|
||||||
|
#
|
||||||
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
# def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
||||||
used_param_keys = []
|
# used_param_keys = []
|
||||||
|
#
|
||||||
for name, param in module._parameters.items():
|
# for name, param in module._parameters.items():
|
||||||
if param is None:
|
# if param is None:
|
||||||
continue
|
# continue
|
||||||
|
#
|
||||||
key = prefix + name
|
# key = prefix + name
|
||||||
sd_param = sd.pop(key, None)
|
# sd_param = sd.pop(key, None)
|
||||||
if sd_param is not None:
|
# if sd_param is not None:
|
||||||
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
# state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||||
used_param_keys.append(key)
|
# used_param_keys.append(key)
|
||||||
|
#
|
||||||
if param.is_meta:
|
# if param.is_meta:
|
||||||
dtype = sd_param.dtype if sd_param is not None else param.dtype
|
# dtype = sd_param.dtype if sd_param is not None else param.dtype
|
||||||
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
# module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||||
|
#
|
||||||
for name in module._buffers:
|
# for name in module._buffers:
|
||||||
key = prefix + name
|
# key = prefix + name
|
||||||
|
#
|
||||||
sd_param = sd.pop(key, None)
|
# sd_param = sd.pop(key, None)
|
||||||
if sd_param is not None:
|
# if sd_param is not None:
|
||||||
state_dict[key] = sd_param
|
# state_dict[key] = sd_param
|
||||||
used_param_keys.append(key)
|
# used_param_keys.append(key)
|
||||||
|
#
|
||||||
original(module, state_dict, prefix, *args, **kwargs)
|
# original(module, state_dict, prefix, *args, **kwargs)
|
||||||
|
#
|
||||||
for key in used_param_keys:
|
# for key in used_param_keys:
|
||||||
state_dict.pop(key, None)
|
# state_dict.pop(key, None)
|
||||||
|
#
|
||||||
def load_state_dict(original, module, state_dict, strict=True):
|
# def load_state_dict(original, module, state_dict, strict=True):
|
||||||
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
|
# """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
|
||||||
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
|
# because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
|
||||||
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
|
# all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
|
||||||
|
#
|
||||||
In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
|
# In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
|
||||||
|
#
|
||||||
The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
|
# The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
|
||||||
the function and does not call the original) the state dict will just fail to load because weights
|
# the function and does not call the original) the state dict will just fail to load because weights
|
||||||
would be on the meta device.
|
# would be on the meta device.
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
if state_dict is sd:
|
# if state_dict is sd:
|
||||||
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
# state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||||
|
#
|
||||||
original(module, state_dict, strict=strict)
|
# original(module, state_dict, strict=strict)
|
||||||
|
#
|
||||||
module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
|
# module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
|
||||||
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
|
# module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
|
||||||
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
# linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||||
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
# conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||||
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
# mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||||
layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
|
# layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
|
||||||
group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
|
# group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
|
||||||
|
#
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.restore()
|
# self.restore()
|
||||||
|
|||||||
@@ -1,124 +1,3 @@
|
|||||||
import torch
|
|
||||||
from torch.nn.functional import silu
|
|
||||||
from types import MethodType
|
|
||||||
|
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
|
||||||
from modules.hypernetworks import hypernetwork
|
|
||||||
from modules.shared import cmd_opts
|
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
|
||||||
|
|
||||||
import ldm.modules.attention
|
|
||||||
import ldm.modules.diffusionmodules.model
|
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
|
||||||
import ldm.models.diffusion.ddpm
|
|
||||||
import ldm.models.diffusion.ddim
|
|
||||||
import ldm.models.diffusion.plms
|
|
||||||
import ldm.modules.encoders.modules
|
|
||||||
|
|
||||||
import sgm.modules.attention
|
|
||||||
import sgm.modules.diffusionmodules.model
|
|
||||||
import sgm.modules.diffusionmodules.openaimodel
|
|
||||||
import sgm.modules.encoders.modules
|
|
||||||
|
|
||||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
|
||||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
|
||||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
|
||||||
|
|
||||||
# new memory efficient cross attention blocks do not support hypernets and we already
|
|
||||||
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
|
||||||
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
|
||||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
|
||||||
|
|
||||||
# silence new console spam from SD2
|
|
||||||
ldm.modules.attention.print = shared.ldm_print
|
|
||||||
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
|
||||||
ldm.util.print = shared.ldm_print
|
|
||||||
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
|
||||||
|
|
||||||
optimizers = []
|
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
|
||||||
|
|
||||||
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
|
||||||
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
|
||||||
|
|
||||||
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
|
||||||
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
|
||||||
|
|
||||||
|
|
||||||
def list_optimizers():
|
|
||||||
new_optimizers = script_callbacks.list_optimizers_callback()
|
|
||||||
|
|
||||||
new_optimizers = [x for x in new_optimizers if x.is_available()]
|
|
||||||
|
|
||||||
new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
|
|
||||||
|
|
||||||
optimizers.clear()
|
|
||||||
optimizers.extend(new_optimizers)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations(option=None):
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def undo_optimizations():
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def fix_checkpoint():
|
|
||||||
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
|
||||||
checkpoints to be added when not training (there's a warning)"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def weighted_loss(sd_model, pred, target, mean=True):
|
|
||||||
#Calculate the weight normally, but ignore the mean
|
|
||||||
loss = sd_model._old_get_loss(pred, target, mean=False)
|
|
||||||
|
|
||||||
#Check if we have weights available
|
|
||||||
weight = getattr(sd_model, '_custom_loss_weight', None)
|
|
||||||
if weight is not None:
|
|
||||||
loss *= weight
|
|
||||||
|
|
||||||
#Return the loss, as mean if specified
|
|
||||||
return loss.mean() if mean else loss
|
|
||||||
|
|
||||||
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
#Temporarily append weights to a place accessible during loss calc
|
|
||||||
sd_model._custom_loss_weight = w
|
|
||||||
|
|
||||||
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
|
||||||
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
|
||||||
if not hasattr(sd_model, '_old_get_loss'):
|
|
||||||
sd_model._old_get_loss = sd_model.get_loss
|
|
||||||
sd_model.get_loss = MethodType(weighted_loss, sd_model)
|
|
||||||
|
|
||||||
#Run the standard forward function, but with the patched 'get_loss'
|
|
||||||
return sd_model.forward(x, c, *args, **kwargs)
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
#Delete temporary weights if appended
|
|
||||||
del sd_model._custom_loss_weight
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
#If we have an old loss function, reset the loss function to the original one
|
|
||||||
if hasattr(sd_model, '_old_get_loss'):
|
|
||||||
sd_model.get_loss = sd_model._old_get_loss
|
|
||||||
del sd_model._old_get_loss
|
|
||||||
|
|
||||||
def apply_weighted_forward(sd_model):
|
|
||||||
#Add new function 'weighted_forward' that can be called to calc weighted loss
|
|
||||||
sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
|
|
||||||
|
|
||||||
def undo_weighted_forward(sd_model):
|
|
||||||
try:
|
|
||||||
del sd_model.weighted_forward
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
fixes = None
|
fixes = None
|
||||||
layers = None
|
layers = None
|
||||||
@@ -156,74 +35,234 @@ class StableDiffusionModelHijack:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsWithFixes(torch.nn.Module):
|
|
||||||
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
|
||||||
super().__init__()
|
|
||||||
self.wrapped = wrapped
|
|
||||||
self.embeddings = embeddings
|
|
||||||
self.textual_inversion_key = textual_inversion_key
|
|
||||||
self.weight = self.wrapped.weight
|
|
||||||
|
|
||||||
def forward(self, input_ids):
|
|
||||||
batch_fixes = self.embeddings.fixes
|
|
||||||
self.embeddings.fixes = None
|
|
||||||
|
|
||||||
inputs_embeds = self.wrapped(input_ids)
|
|
||||||
|
|
||||||
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
vecs = []
|
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
|
||||||
for offset, embedding in fixes:
|
|
||||||
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
|
||||||
emb = devices.cond_cast_unet(vec)
|
|
||||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
|
||||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
|
||||||
|
|
||||||
vecs.append(tensor)
|
|
||||||
|
|
||||||
return torch.stack(vecs)
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionEmbeddings(torch.nn.Embedding):
|
|
||||||
def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
|
|
||||||
super().__init__(num_embeddings, embedding_dim, **kwargs)
|
|
||||||
|
|
||||||
self.embeddings = model_hijack
|
|
||||||
self.textual_inversion_key = textual_inversion_key
|
|
||||||
|
|
||||||
@property
|
|
||||||
def wrapped(self):
|
|
||||||
return super().forward
|
|
||||||
|
|
||||||
def forward(self, input_ids):
|
|
||||||
return EmbeddingsWithFixes.forward(self, input_ids)
|
|
||||||
|
|
||||||
|
|
||||||
def add_circular_option_to_conv_2d():
|
|
||||||
conv2d_constructor = torch.nn.Conv2d.__init__
|
|
||||||
|
|
||||||
def conv2d_constructor_circular(self, *args, **kwargs):
|
|
||||||
return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
|
||||||
|
|
||||||
torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
|
||||||
|
|
||||||
|
|
||||||
model_hijack = StableDiffusionModelHijack()
|
model_hijack = StableDiffusionModelHijack()
|
||||||
|
|
||||||
|
# import torch
|
||||||
def register_buffer(self, name, attr):
|
# from torch.nn.functional import silu
|
||||||
"""
|
# from types import MethodType
|
||||||
Fix register buffer bug for Mac OS.
|
#
|
||||||
"""
|
# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||||
|
# from modules.hypernetworks import hypernetwork
|
||||||
if type(attr) == torch.Tensor:
|
# from modules.shared import cmd_opts
|
||||||
if attr.device != devices.device:
|
# from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||||
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
#
|
||||||
|
# import ldm.modules.attention
|
||||||
setattr(self, name, attr)
|
# import ldm.modules.diffusionmodules.model
|
||||||
|
# import ldm.modules.diffusionmodules.openaimodel
|
||||||
|
# import ldm.models.diffusion.ddpm
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
# import ldm.models.diffusion.ddim
|
||||||
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
# import ldm.models.diffusion.plms
|
||||||
|
# import ldm.modules.encoders.modules
|
||||||
|
#
|
||||||
|
# import sgm.modules.attention
|
||||||
|
# import sgm.modules.diffusionmodules.model
|
||||||
|
# import sgm.modules.diffusionmodules.openaimodel
|
||||||
|
# import sgm.modules.encoders.modules
|
||||||
|
#
|
||||||
|
# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||||
|
# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
|
# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
|
#
|
||||||
|
# # new memory efficient cross attention blocks do not support hypernets and we already
|
||||||
|
# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||||
|
# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||||
|
# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||||
|
#
|
||||||
|
# # silence new console spam from SD2
|
||||||
|
# ldm.modules.attention.print = shared.ldm_print
|
||||||
|
# ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
|
# ldm.util.print = shared.ldm_print
|
||||||
|
# ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||||
|
#
|
||||||
|
# optimizers = []
|
||||||
|
# current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
|
#
|
||||||
|
# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||||
|
# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
||||||
|
#
|
||||||
|
# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||||
|
# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def list_optimizers():
|
||||||
|
# new_optimizers = script_callbacks.list_optimizers_callback()
|
||||||
|
#
|
||||||
|
# new_optimizers = [x for x in new_optimizers if x.is_available()]
|
||||||
|
#
|
||||||
|
# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
|
||||||
|
#
|
||||||
|
# optimizers.clear()
|
||||||
|
# optimizers.extend(new_optimizers)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def apply_optimizations(option=None):
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def undo_optimizations():
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def fix_checkpoint():
|
||||||
|
# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||||
|
# checkpoints to be added when not training (there's a warning)"""
|
||||||
|
#
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def weighted_loss(sd_model, pred, target, mean=True):
|
||||||
|
# #Calculate the weight normally, but ignore the mean
|
||||||
|
# loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||||
|
#
|
||||||
|
# #Check if we have weights available
|
||||||
|
# weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||||
|
# if weight is not None:
|
||||||
|
# loss *= weight
|
||||||
|
#
|
||||||
|
# #Return the loss, as mean if specified
|
||||||
|
# return loss.mean() if mean else loss
|
||||||
|
#
|
||||||
|
# def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
||||||
|
# try:
|
||||||
|
# #Temporarily append weights to a place accessible during loss calc
|
||||||
|
# sd_model._custom_loss_weight = w
|
||||||
|
#
|
||||||
|
# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||||
|
# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||||
|
# if not hasattr(sd_model, '_old_get_loss'):
|
||||||
|
# sd_model._old_get_loss = sd_model.get_loss
|
||||||
|
# sd_model.get_loss = MethodType(weighted_loss, sd_model)
|
||||||
|
#
|
||||||
|
# #Run the standard forward function, but with the patched 'get_loss'
|
||||||
|
# return sd_model.forward(x, c, *args, **kwargs)
|
||||||
|
# finally:
|
||||||
|
# try:
|
||||||
|
# #Delete temporary weights if appended
|
||||||
|
# del sd_model._custom_loss_weight
|
||||||
|
# except AttributeError:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# #If we have an old loss function, reset the loss function to the original one
|
||||||
|
# if hasattr(sd_model, '_old_get_loss'):
|
||||||
|
# sd_model.get_loss = sd_model._old_get_loss
|
||||||
|
# del sd_model._old_get_loss
|
||||||
|
#
|
||||||
|
# def apply_weighted_forward(sd_model):
|
||||||
|
# #Add new function 'weighted_forward' that can be called to calc weighted loss
|
||||||
|
# sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
|
||||||
|
#
|
||||||
|
# def undo_weighted_forward(sd_model):
|
||||||
|
# try:
|
||||||
|
# del sd_model.weighted_forward
|
||||||
|
# except AttributeError:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# class StableDiffusionModelHijack:
|
||||||
|
# fixes = None
|
||||||
|
# layers = None
|
||||||
|
# circular_enabled = False
|
||||||
|
# clip = None
|
||||||
|
# optimization_method = None
|
||||||
|
#
|
||||||
|
# def __init__(self):
|
||||||
|
# self.extra_generation_params = {}
|
||||||
|
# self.comments = []
|
||||||
|
#
|
||||||
|
# def apply_optimizations(self, option=None):
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def convert_sdxl_to_ssd(self, m):
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def hijack(self, m):
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def undo_hijack(self, m):
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def apply_circular(self, enable):
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def clear_comments(self):
|
||||||
|
# self.comments = []
|
||||||
|
# self.extra_generation_params = {}
|
||||||
|
#
|
||||||
|
# def get_prompt_lengths(self, text, cond_stage_model):
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def redo_hijack(self, m):
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
|
# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||||
|
# super().__init__()
|
||||||
|
# self.wrapped = wrapped
|
||||||
|
# self.embeddings = embeddings
|
||||||
|
# self.textual_inversion_key = textual_inversion_key
|
||||||
|
# self.weight = self.wrapped.weight
|
||||||
|
#
|
||||||
|
# def forward(self, input_ids):
|
||||||
|
# batch_fixes = self.embeddings.fixes
|
||||||
|
# self.embeddings.fixes = None
|
||||||
|
#
|
||||||
|
# inputs_embeds = self.wrapped(input_ids)
|
||||||
|
#
|
||||||
|
# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||||
|
# return inputs_embeds
|
||||||
|
#
|
||||||
|
# vecs = []
|
||||||
|
# for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
|
# for offset, embedding in fixes:
|
||||||
|
# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||||
|
# emb = devices.cond_cast_unet(vec)
|
||||||
|
# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||||
|
# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
||||||
|
#
|
||||||
|
# vecs.append(tensor)
|
||||||
|
#
|
||||||
|
# return torch.stack(vecs)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# class TextualInversionEmbeddings(torch.nn.Embedding):
|
||||||
|
# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
|
||||||
|
# super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||||
|
#
|
||||||
|
# self.embeddings = model_hijack
|
||||||
|
# self.textual_inversion_key = textual_inversion_key
|
||||||
|
#
|
||||||
|
# @property
|
||||||
|
# def wrapped(self):
|
||||||
|
# return super().forward
|
||||||
|
#
|
||||||
|
# def forward(self, input_ids):
|
||||||
|
# return EmbeddingsWithFixes.forward(self, input_ids)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def add_circular_option_to_conv_2d():
|
||||||
|
# conv2d_constructor = torch.nn.Conv2d.__init__
|
||||||
|
#
|
||||||
|
# def conv2d_constructor_circular(self, *args, **kwargs):
|
||||||
|
# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
||||||
|
#
|
||||||
|
# torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def register_buffer(self, name, attr):
|
||||||
|
# """
|
||||||
|
# Fix register buffer bug for Mac OS.
|
||||||
|
# """
|
||||||
|
#
|
||||||
|
# if type(attr) == torch.Tensor:
|
||||||
|
# if attr.device != devices.device:
|
||||||
|
# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
||||||
|
#
|
||||||
|
# setattr(self, name, attr)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||||
|
# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||||
|
|||||||
@@ -1,46 +1,46 @@
|
|||||||
from torch.utils.checkpoint import checkpoint
|
# from torch.utils.checkpoint import checkpoint
|
||||||
|
#
|
||||||
import ldm.modules.attention
|
# import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
# import ldm.modules.diffusionmodules.openaimodel
|
||||||
|
#
|
||||||
|
#
|
||||||
def BasicTransformerBlock_forward(self, x, context=None):
|
# def BasicTransformerBlock_forward(self, x, context=None):
|
||||||
return checkpoint(self._forward, x, context)
|
# return checkpoint(self._forward, x, context)
|
||||||
|
#
|
||||||
|
#
|
||||||
def AttentionBlock_forward(self, x):
|
# def AttentionBlock_forward(self, x):
|
||||||
return checkpoint(self._forward, x)
|
# return checkpoint(self._forward, x)
|
||||||
|
#
|
||||||
|
#
|
||||||
def ResBlock_forward(self, x, emb):
|
# def ResBlock_forward(self, x, emb):
|
||||||
return checkpoint(self._forward, x, emb)
|
# return checkpoint(self._forward, x, emb)
|
||||||
|
#
|
||||||
|
#
|
||||||
stored = []
|
# stored = []
|
||||||
|
#
|
||||||
|
#
|
||||||
def add():
|
# def add():
|
||||||
if len(stored) != 0:
|
# if len(stored) != 0:
|
||||||
return
|
# return
|
||||||
|
#
|
||||||
stored.extend([
|
# stored.extend([
|
||||||
ldm.modules.attention.BasicTransformerBlock.forward,
|
# ldm.modules.attention.BasicTransformerBlock.forward,
|
||||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
|
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
|
||||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
|
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
|
||||||
])
|
# ])
|
||||||
|
#
|
||||||
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
|
# ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
|
||||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
|
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
|
||||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
||||||
|
#
|
||||||
|
#
|
||||||
def remove():
|
# def remove():
|
||||||
if len(stored) == 0:
|
# if len(stored) == 0:
|
||||||
return
|
# return
|
||||||
|
#
|
||||||
ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
|
# ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
|
||||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
|
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
|
||||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
|
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
|
||||||
|
#
|
||||||
stored.clear()
|
# stored.clear()
|
||||||
|
#
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,154 +1,154 @@
|
|||||||
import torch
|
# import torch
|
||||||
from packaging import version
|
# from packaging import version
|
||||||
from einops import repeat
|
# from einops import repeat
|
||||||
import math
|
# import math
|
||||||
|
#
|
||||||
from modules import devices
|
# from modules import devices
|
||||||
from modules.sd_hijack_utils import CondFunc
|
# from modules.sd_hijack_utils import CondFunc
|
||||||
|
#
|
||||||
|
#
|
||||||
class TorchHijackForUnet:
|
# class TorchHijackForUnet:
|
||||||
"""
|
# """
|
||||||
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||||
this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
def __getattr__(self, item):
|
# def __getattr__(self, item):
|
||||||
if item == 'cat':
|
# if item == 'cat':
|
||||||
return self.cat
|
# return self.cat
|
||||||
|
#
|
||||||
if hasattr(torch, item):
|
# if hasattr(torch, item):
|
||||||
return getattr(torch, item)
|
# return getattr(torch, item)
|
||||||
|
#
|
||||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||||
|
#
|
||||||
def cat(self, tensors, *args, **kwargs):
|
# def cat(self, tensors, *args, **kwargs):
|
||||||
if len(tensors) == 2:
|
# if len(tensors) == 2:
|
||||||
a, b = tensors
|
# a, b = tensors
|
||||||
if a.shape[-2:] != b.shape[-2:]:
|
# if a.shape[-2:] != b.shape[-2:]:
|
||||||
a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
||||||
|
#
|
||||||
tensors = (a, b)
|
# tensors = (a, b)
|
||||||
|
#
|
||||||
return torch.cat(tensors, *args, **kwargs)
|
# return torch.cat(tensors, *args, **kwargs)
|
||||||
|
#
|
||||||
|
#
|
||||||
th = TorchHijackForUnet()
|
# th = TorchHijackForUnet()
|
||||||
|
#
|
||||||
|
#
|
||||||
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||||
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||||
"""Always make sure inputs to unet are in correct dtype."""
|
# """Always make sure inputs to unet are in correct dtype."""
|
||||||
if isinstance(cond, dict):
|
# if isinstance(cond, dict):
|
||||||
for y in cond.keys():
|
# for y in cond.keys():
|
||||||
if isinstance(cond[y], list):
|
# if isinstance(cond[y], list):
|
||||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||||
else:
|
# else:
|
||||||
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||||
|
#
|
||||||
with devices.autocast():
|
# with devices.autocast():
|
||||||
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
||||||
if devices.unet_needs_upcast:
|
# if devices.unet_needs_upcast:
|
||||||
return result.float()
|
# return result.float()
|
||||||
else:
|
# else:
|
||||||
return result
|
# return result
|
||||||
|
#
|
||||||
|
#
|
||||||
# Monkey patch to create timestep embed tensor on device, avoiding a block.
|
# # Monkey patch to create timestep embed tensor on device, avoiding a block.
|
||||||
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
||||||
"""
|
# """
|
||||||
Create sinusoidal timestep embeddings.
|
# Create sinusoidal timestep embeddings.
|
||||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||||
These may be fractional.
|
# These may be fractional.
|
||||||
:param dim: the dimension of the output.
|
# :param dim: the dimension of the output.
|
||||||
:param max_period: controls the minimum frequency of the embeddings.
|
# :param max_period: controls the minimum frequency of the embeddings.
|
||||||
:return: an [N x dim] Tensor of positional embeddings.
|
# :return: an [N x dim] Tensor of positional embeddings.
|
||||||
"""
|
# """
|
||||||
if not repeat_only:
|
# if not repeat_only:
|
||||||
half = dim // 2
|
# half = dim // 2
|
||||||
freqs = torch.exp(
|
# freqs = torch.exp(
|
||||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
||||||
)
|
# )
|
||||||
args = timesteps[:, None].float() * freqs[None]
|
# args = timesteps[:, None].float() * freqs[None]
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
if dim % 2:
|
# if dim % 2:
|
||||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
else:
|
# else:
|
||||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
# embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||||
return embedding
|
# return embedding
|
||||||
|
#
|
||||||
|
#
|
||||||
# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
||||||
# Prevents a lot of unnecessary aten::copy_ calls
|
# # Prevents a lot of unnecessary aten::copy_ calls
|
||||||
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
||||||
# note: if no context is given, cross-attention defaults to self-attention
|
# # note: if no context is given, cross-attention defaults to self-attention
|
||||||
if not isinstance(context, list):
|
# if not isinstance(context, list):
|
||||||
context = [context]
|
# context = [context]
|
||||||
b, c, h, w = x.shape
|
# b, c, h, w = x.shape
|
||||||
x_in = x
|
# x_in = x
|
||||||
x = self.norm(x)
|
# x = self.norm(x)
|
||||||
if not self.use_linear:
|
# if not self.use_linear:
|
||||||
x = self.proj_in(x)
|
# x = self.proj_in(x)
|
||||||
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||||
if self.use_linear:
|
# if self.use_linear:
|
||||||
x = self.proj_in(x)
|
# x = self.proj_in(x)
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
# for i, block in enumerate(self.transformer_blocks):
|
||||||
x = block(x, context=context[i])
|
# x = block(x, context=context[i])
|
||||||
if self.use_linear:
|
# if self.use_linear:
|
||||||
x = self.proj_out(x)
|
# x = self.proj_out(x)
|
||||||
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
# x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
||||||
if not self.use_linear:
|
# if not self.use_linear:
|
||||||
x = self.proj_out(x)
|
# x = self.proj_out(x)
|
||||||
return x + x_in
|
# return x + x_in
|
||||||
|
#
|
||||||
|
#
|
||||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
# class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||||
def __init__(self, *args, **kwargs):
|
# def __init__(self, *args, **kwargs):
|
||||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
# torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||||
def forward(self, x):
|
# def forward(self, x):
|
||||||
if devices.unet_needs_upcast:
|
# if devices.unet_needs_upcast:
|
||||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||||
else:
|
# else:
|
||||||
return torch.nn.GELU.forward(self, x)
|
# return torch.nn.GELU.forward(self, x)
|
||||||
|
#
|
||||||
|
#
|
||||||
ddpm_edit_hijack = None
|
# ddpm_edit_hijack = None
|
||||||
def hijack_ddpm_edit():
|
# def hijack_ddpm_edit():
|
||||||
global ddpm_edit_hijack
|
# global ddpm_edit_hijack
|
||||||
if not ddpm_edit_hijack:
|
# if not ddpm_edit_hijack:
|
||||||
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
||||||
|
#
|
||||||
|
#
|
||||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
||||||
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||||
|
#
|
||||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||||
|
#
|
||||||
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||||
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||||
|
#
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
||||||
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
||||||
|
#
|
||||||
|
#
|
||||||
def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
||||||
if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
# if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
||||||
dtype = torch.float32
|
# dtype = torch.float32
|
||||||
else:
|
# else:
|
||||||
dtype = devices.dtype_unet
|
# dtype = devices.dtype_unet
|
||||||
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||||
|
#
|
||||||
|
#
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||||
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import re
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from omegaconf import OmegaConf, ListConfig
|
from omegaconf import OmegaConf, ListConfig
|
||||||
from urllib import request
|
from urllib import request
|
||||||
import ldm.modules.midas as midas
|
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||||
@@ -415,89 +414,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
|
|
||||||
|
|
||||||
def enable_midas_autodownload():
|
def enable_midas_autodownload():
|
||||||
"""
|
pass
|
||||||
Gives the ldm.modules.midas.api.load_model function automatic downloading.
|
|
||||||
|
|
||||||
When the 512-depth-ema model, and other future models like it, is loaded,
|
|
||||||
it calls midas.api.load_model to load the associated midas depth model.
|
|
||||||
This function applies a wrapper to download the model to the correct
|
|
||||||
location automatically.
|
|
||||||
"""
|
|
||||||
|
|
||||||
midas_path = os.path.join(paths.models_path, 'midas')
|
|
||||||
|
|
||||||
# stable-diffusion-stability-ai hard-codes the midas model path to
|
|
||||||
# a location that differs from where other scripts using this model look.
|
|
||||||
# HACK: Overriding the path here.
|
|
||||||
for k, v in midas.api.ISL_PATHS.items():
|
|
||||||
file_name = os.path.basename(v)
|
|
||||||
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
|
|
||||||
|
|
||||||
midas_urls = {
|
|
||||||
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
|
||||||
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
|
||||||
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
|
|
||||||
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
|
|
||||||
}
|
|
||||||
|
|
||||||
midas.api.load_model_inner = midas.api.load_model
|
|
||||||
|
|
||||||
def load_model_wrapper(model_type):
|
|
||||||
path = midas.api.ISL_PATHS[model_type]
|
|
||||||
if not os.path.exists(path):
|
|
||||||
if not os.path.exists(midas_path):
|
|
||||||
os.mkdir(midas_path)
|
|
||||||
|
|
||||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
|
||||||
request.urlretrieve(midas_urls[model_type], path)
|
|
||||||
print(f"{model_type} downloaded")
|
|
||||||
|
|
||||||
return midas.api.load_model_inner(model_type)
|
|
||||||
|
|
||||||
midas.api.load_model = load_model_wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def patch_given_betas():
|
def patch_given_betas():
|
||||||
import ldm.models.diffusion.ddpm
|
pass
|
||||||
|
|
||||||
def patched_register_schedule(*args, **kwargs):
|
|
||||||
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
|
|
||||||
|
|
||||||
if isinstance(args[1], ListConfig):
|
|
||||||
args = (args[0], np.array(args[1]), *args[2:])
|
|
||||||
|
|
||||||
original_register_schedule(*args, **kwargs)
|
|
||||||
|
|
||||||
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
|
||||||
|
|
||||||
|
|
||||||
def repair_config(sd_config, state_dict=None):
|
def repair_config(sd_config, state_dict=None):
|
||||||
if not hasattr(sd_config.model.params, "use_ema"):
|
pass
|
||||||
sd_config.model.params.use_ema = False
|
|
||||||
|
|
||||||
if hasattr(sd_config.model.params, 'unet_config'):
|
|
||||||
if shared.cmd_opts.no_half:
|
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
|
||||||
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
|
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
|
||||||
|
|
||||||
if hasattr(sd_config.model.params, 'first_stage_config'):
|
|
||||||
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
|
||||||
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
|
|
||||||
|
|
||||||
# For UnCLIP-L, override the hardcoded karlo directory
|
|
||||||
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
|
|
||||||
karlo_path = os.path.join(paths.models_path, 'karlo')
|
|
||||||
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
|
||||||
|
|
||||||
# Do not use checkpoint for inference.
|
|
||||||
# This helps prevent extra performance overhead on checking parameters.
|
|
||||||
# The perf overhead is about 100ms/it on 4090 for SDXL.
|
|
||||||
if hasattr(sd_config.model.params, "network_config"):
|
|
||||||
sd_config.model.params.network_config.params.use_checkpoint = False
|
|
||||||
if hasattr(sd_config.model.params, "unet_config"):
|
|
||||||
sd_config.model.params.unet_config.params.use_checkpoint = False
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||||
|
|||||||
@@ -1,137 +1,137 @@
|
|||||||
import os
|
# import os
|
||||||
|
#
|
||||||
import torch
|
# import torch
|
||||||
|
#
|
||||||
from modules import shared, paths, sd_disable_initialization, devices
|
# from modules import shared, paths, sd_disable_initialization, devices
|
||||||
|
#
|
||||||
sd_configs_path = shared.sd_configs_path
|
# sd_configs_path = shared.sd_configs_path
|
||||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
# # sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||||
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
# # sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||||
|
#
|
||||||
|
#
|
||||||
config_default = shared.sd_default_config
|
# config_default = shared.sd_default_config
|
||||||
# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||||
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||||
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||||
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
# config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||||
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||||
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||||
|
#
|
||||||
|
#
|
||||||
def is_using_v_parameterization_for_sd2(state_dict):
|
# def is_using_v_parameterization_for_sd2(state_dict):
|
||||||
"""
|
# """
|
||||||
Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||||
"""
|
# """
|
||||||
|
#
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
# import ldm.modules.diffusionmodules.openaimodel
|
||||||
|
#
|
||||||
device = devices.device
|
# device = devices.device
|
||||||
|
#
|
||||||
with sd_disable_initialization.DisableInitialization():
|
# with sd_disable_initialization.DisableInitialization():
|
||||||
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||||
use_checkpoint=False,
|
# use_checkpoint=False,
|
||||||
use_fp16=False,
|
# use_fp16=False,
|
||||||
image_size=32,
|
# image_size=32,
|
||||||
in_channels=4,
|
# in_channels=4,
|
||||||
out_channels=4,
|
# out_channels=4,
|
||||||
model_channels=320,
|
# model_channels=320,
|
||||||
attention_resolutions=[4, 2, 1],
|
# attention_resolutions=[4, 2, 1],
|
||||||
num_res_blocks=2,
|
# num_res_blocks=2,
|
||||||
channel_mult=[1, 2, 4, 4],
|
# channel_mult=[1, 2, 4, 4],
|
||||||
num_head_channels=64,
|
# num_head_channels=64,
|
||||||
use_spatial_transformer=True,
|
# use_spatial_transformer=True,
|
||||||
use_linear_in_transformer=True,
|
# use_linear_in_transformer=True,
|
||||||
transformer_depth=1,
|
# transformer_depth=1,
|
||||||
context_dim=1024,
|
# context_dim=1024,
|
||||||
legacy=False
|
# legacy=False
|
||||||
)
|
# )
|
||||||
unet.eval()
|
# unet.eval()
|
||||||
|
#
|
||||||
with torch.no_grad():
|
# with torch.no_grad():
|
||||||
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||||
unet.load_state_dict(unet_sd, strict=True)
|
# unet.load_state_dict(unet_sd, strict=True)
|
||||||
unet.to(device=device, dtype=devices.dtype_unet)
|
# unet.to(device=device, dtype=devices.dtype_unet)
|
||||||
|
#
|
||||||
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||||
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||||
|
#
|
||||||
with devices.autocast():
|
# with devices.autocast():
|
||||||
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
||||||
|
#
|
||||||
return out < -1
|
# return out < -1
|
||||||
|
#
|
||||||
|
#
|
||||||
def guess_model_config_from_state_dict(sd, filename):
|
# def guess_model_config_from_state_dict(sd, filename):
|
||||||
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||||
|
#
|
||||||
if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
# if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
||||||
return config_sd3
|
# return config_sd3
|
||||||
|
#
|
||||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
# if diffusion_model_input.shape[1] == 9:
|
||||||
return config_sdxl_inpainting
|
# return config_sdxl_inpainting
|
||||||
else:
|
# else:
|
||||||
return config_sdxl
|
# return config_sdxl
|
||||||
|
#
|
||||||
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||||
return config_sdxl_refiner
|
# return config_sdxl_refiner
|
||||||
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||||
return config_depth_model
|
# return config_depth_model
|
||||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||||
return config_unclip
|
# return config_unclip
|
||||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
||||||
return config_unopenclip
|
# return config_unopenclip
|
||||||
|
#
|
||||||
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
# if diffusion_model_input.shape[1] == 9:
|
||||||
return config_sd2_inpainting
|
# return config_sd2_inpainting
|
||||||
# elif is_using_v_parameterization_for_sd2(sd):
|
# # elif is_using_v_parameterization_for_sd2(sd):
|
||||||
# return config_sd2v
|
# # return config_sd2v
|
||||||
else:
|
# else:
|
||||||
return config_sd2v
|
# return config_sd2v
|
||||||
|
#
|
||||||
if diffusion_model_input is not None:
|
# if diffusion_model_input is not None:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
# if diffusion_model_input.shape[1] == 9:
|
||||||
return config_inpainting
|
# return config_inpainting
|
||||||
if diffusion_model_input.shape[1] == 8:
|
# if diffusion_model_input.shape[1] == 8:
|
||||||
return config_instruct_pix2pix
|
# return config_instruct_pix2pix
|
||||||
|
#
|
||||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
# if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||||
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||||
return config_alt_diffusion_m18
|
# return config_alt_diffusion_m18
|
||||||
return config_alt_diffusion
|
# return config_alt_diffusion
|
||||||
|
#
|
||||||
return config_default
|
# return config_default
|
||||||
|
#
|
||||||
|
#
|
||||||
def find_checkpoint_config(state_dict, info):
|
# def find_checkpoint_config(state_dict, info):
|
||||||
if info is None:
|
# if info is None:
|
||||||
return guess_model_config_from_state_dict(state_dict, "")
|
# return guess_model_config_from_state_dict(state_dict, "")
|
||||||
|
#
|
||||||
config = find_checkpoint_config_near_filename(info)
|
# config = find_checkpoint_config_near_filename(info)
|
||||||
if config is not None:
|
# if config is not None:
|
||||||
return config
|
# return config
|
||||||
|
#
|
||||||
return guess_model_config_from_state_dict(state_dict, info.filename)
|
# return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||||
|
#
|
||||||
|
#
|
||||||
def find_checkpoint_config_near_filename(info):
|
# def find_checkpoint_config_near_filename(info):
|
||||||
if info is None:
|
# if info is None:
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
# config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||||
if os.path.exists(config):
|
# if os.path.exists(config):
|
||||||
return config
|
# return config
|
||||||
|
#
|
||||||
return None
|
# return None
|
||||||
|
#
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
@@ -6,7 +5,7 @@ if TYPE_CHECKING:
|
|||||||
from modules.sd_models import CheckpointInfo
|
from modules.sd_models import CheckpointInfo
|
||||||
|
|
||||||
|
|
||||||
class WebuiSdModel(LatentDiffusion):
|
class WebuiSdModel:
|
||||||
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
|
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
|
||||||
|
|
||||||
lowvram: bool
|
lowvram: bool
|
||||||
|
|||||||
@@ -1,115 +1,115 @@
|
|||||||
from __future__ import annotations
|
# from __future__ import annotations
|
||||||
|
#
|
||||||
import torch
|
# import torch
|
||||||
|
#
|
||||||
import sgm.models.diffusion
|
# import sgm.models.diffusion
|
||||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
# import sgm.modules.diffusionmodules.denoiser_scaling
|
||||||
import sgm.modules.diffusionmodules.discretizer
|
# import sgm.modules.diffusionmodules.discretizer
|
||||||
from modules import devices, shared, prompt_parser
|
# from modules import devices, shared, prompt_parser
|
||||||
from modules import torch_utils
|
# from modules import torch_utils
|
||||||
|
#
|
||||||
from backend import memory_management
|
# from backend import memory_management
|
||||||
|
#
|
||||||
|
#
|
||||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||||
|
#
|
||||||
for embedder in self.conditioner.embedders:
|
# for embedder in self.conditioner.embedders:
|
||||||
embedder.ucg_rate = 0.0
|
# embedder.ucg_rate = 0.0
|
||||||
|
#
|
||||||
width = getattr(batch, 'width', 1024) or 1024
|
# width = getattr(batch, 'width', 1024) or 1024
|
||||||
height = getattr(batch, 'height', 1024) or 1024
|
# height = getattr(batch, 'height', 1024) or 1024
|
||||||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
# is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
# aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||||
|
#
|
||||||
devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
|
# devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
|
||||||
|
#
|
||||||
sdxl_conds = {
|
# sdxl_conds = {
|
||||||
"txt": batch,
|
# "txt": batch,
|
||||||
"original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
# "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||||
"crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
# "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
||||||
"target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
# "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||||
"aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
# "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
||||||
}
|
# }
|
||||||
|
#
|
||||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
# force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
||||||
c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
# c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
||||||
|
#
|
||||||
return c
|
# return c
|
||||||
|
#
|
||||||
|
#
|
||||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
|
# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
|
||||||
if self.model.diffusion_model.in_channels == 9:
|
# if self.model.diffusion_model.in_channels == 9:
|
||||||
x = torch.cat([x] + cond['c_concat'], dim=1)
|
# x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||||
|
#
|
||||||
return self.model(x, t, cond, *args, **kwargs)
|
# return self.model(x, t, cond, *args, **kwargs)
|
||||||
|
#
|
||||||
|
#
|
||||||
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
# def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||||
return x
|
# return x
|
||||||
|
#
|
||||||
|
#
|
||||||
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||||
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
||||||
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
||||||
|
#
|
||||||
|
#
|
||||||
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
||||||
res = []
|
# res = []
|
||||||
|
#
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
||||||
encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
# encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
||||||
res.append(encoded)
|
# res.append(encoded)
|
||||||
|
#
|
||||||
return torch.cat(res, dim=1)
|
# return torch.cat(res, dim=1)
|
||||||
|
#
|
||||||
|
#
|
||||||
def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
# def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
||||||
return embedder.tokenize(texts)
|
# return embedder.tokenize(texts)
|
||||||
|
#
|
||||||
raise AssertionError('no tokenizer available')
|
# raise AssertionError('no tokenizer available')
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
def process_texts(self, texts):
|
# def process_texts(self, texts):
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||||
return embedder.process_texts(texts)
|
# return embedder.process_texts(texts)
|
||||||
|
#
|
||||||
|
#
|
||||||
def get_target_prompt_token_count(self, token_count):
|
# def get_target_prompt_token_count(self, token_count):
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
||||||
return embedder.get_target_prompt_token_count(token_count)
|
# return embedder.get_target_prompt_token_count(token_count)
|
||||||
|
#
|
||||||
|
#
|
||||||
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
||||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||||
sgm.modules.GeneralConditioner.tokenize = tokenize
|
# sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
# sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||||
|
#
|
||||||
|
#
|
||||||
def extend_sdxl(model):
|
# def extend_sdxl(model):
|
||||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
# """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||||
|
#
|
||||||
dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
# dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
||||||
model.model.diffusion_model.dtype = dtype
|
# model.model.diffusion_model.dtype = dtype
|
||||||
model.model.conditioning_key = 'crossattn'
|
# model.model.conditioning_key = 'crossattn'
|
||||||
model.cond_stage_key = 'txt'
|
# model.cond_stage_key = 'txt'
|
||||||
# model.cond_stage_model will be set in sd_hijack
|
# # model.cond_stage_model will be set in sd_hijack
|
||||||
|
#
|
||||||
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
# model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||||
|
#
|
||||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
# discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
# model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
||||||
|
#
|
||||||
model.conditioner.wrapped = torch.nn.Module()
|
# model.conditioner.wrapped = torch.nn.Module()
|
||||||
|
#
|
||||||
|
#
|
||||||
sgm.modules.attention.print = shared.ldm_print
|
# sgm.modules.attention.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
# sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||||
sgm.modules.encoders.modules.print = shared.ldm_print
|
# sgm.modules.encoders.modules.print = shared.ldm_print
|
||||||
|
#
|
||||||
# this gets the code to load the vanilla attention that we override
|
# # this gets the code to load the vanilla attention that we override
|
||||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
# sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||||
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
||||||
|
|||||||
@@ -35,9 +35,7 @@ def refresh_vae_list():
|
|||||||
|
|
||||||
|
|
||||||
def cross_attention_optimizations():
|
def cross_attention_optimizations():
|
||||||
import modules.sd_hijack
|
return ["Automatic"]
|
||||||
|
|
||||||
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
|
||||||
|
|
||||||
|
|
||||||
def sd_unet_items():
|
def sd_unet_items():
|
||||||
|
|||||||
@@ -1,245 +1,243 @@
|
|||||||
import os
|
# import os
|
||||||
import numpy as np
|
# import numpy as np
|
||||||
import PIL
|
# import PIL
|
||||||
import torch
|
# import torch
|
||||||
from torch.utils.data import Dataset, DataLoader, Sampler
|
# from torch.utils.data import Dataset, DataLoader, Sampler
|
||||||
from torchvision import transforms
|
# from torchvision import transforms
|
||||||
from collections import defaultdict
|
# from collections import defaultdict
|
||||||
from random import shuffle, choices
|
# from random import shuffle, choices
|
||||||
|
#
|
||||||
import random
|
# import random
|
||||||
import tqdm
|
# import tqdm
|
||||||
from modules import devices, shared, images
|
# from modules import devices, shared, images
|
||||||
import re
|
# import re
|
||||||
|
#
|
||||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
# re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||||
|
#
|
||||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
#
|
||||||
|
# class DatasetEntry:
|
||||||
|
# def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
|
||||||
class DatasetEntry:
|
# self.filename = filename
|
||||||
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
|
# self.filename_text = filename_text
|
||||||
self.filename = filename
|
# self.weight = weight
|
||||||
self.filename_text = filename_text
|
# self.latent_dist = latent_dist
|
||||||
self.weight = weight
|
# self.latent_sample = latent_sample
|
||||||
self.latent_dist = latent_dist
|
# self.cond = cond
|
||||||
self.latent_sample = latent_sample
|
# self.cond_text = cond_text
|
||||||
self.cond = cond
|
# self.pixel_values = pixel_values
|
||||||
self.cond_text = cond_text
|
#
|
||||||
self.pixel_values = pixel_values
|
#
|
||||||
|
# class PersonalizedBase(Dataset):
|
||||||
|
# def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
|
||||||
class PersonalizedBase(Dataset):
|
# re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
|
||||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
|
#
|
||||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
|
# self.placeholder_token = placeholder_token
|
||||||
|
#
|
||||||
self.placeholder_token = placeholder_token
|
# self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
#
|
||||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
# self.dataset = []
|
||||||
|
#
|
||||||
self.dataset = []
|
# with open(template_file, "r") as file:
|
||||||
|
# lines = [x.strip() for x in file.readlines()]
|
||||||
with open(template_file, "r") as file:
|
#
|
||||||
lines = [x.strip() for x in file.readlines()]
|
# self.lines = lines
|
||||||
|
#
|
||||||
self.lines = lines
|
# assert data_root, 'dataset directory not specified'
|
||||||
|
# assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
assert data_root, 'dataset directory not specified'
|
# assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
#
|
||||||
assert os.listdir(data_root), "Dataset directory is empty"
|
# self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
#
|
||||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
# self.shuffle_tags = shuffle_tags
|
||||||
|
# self.tag_drop_out = tag_drop_out
|
||||||
self.shuffle_tags = shuffle_tags
|
# groups = defaultdict(list)
|
||||||
self.tag_drop_out = tag_drop_out
|
#
|
||||||
groups = defaultdict(list)
|
# print("Preparing dataset...")
|
||||||
|
# for path in tqdm.tqdm(self.image_paths):
|
||||||
print("Preparing dataset...")
|
# alpha_channel = None
|
||||||
for path in tqdm.tqdm(self.image_paths):
|
# if shared.state.interrupted:
|
||||||
alpha_channel = None
|
# raise Exception("interrupted")
|
||||||
if shared.state.interrupted:
|
# try:
|
||||||
raise Exception("interrupted")
|
# image = images.read(path)
|
||||||
try:
|
# #Currently does not work for single color transparency
|
||||||
image = images.read(path)
|
# #We would need to read image.info['transparency'] for that
|
||||||
#Currently does not work for single color transparency
|
# if use_weight and 'A' in image.getbands():
|
||||||
#We would need to read image.info['transparency'] for that
|
# alpha_channel = image.getchannel('A')
|
||||||
if use_weight and 'A' in image.getbands():
|
# image = image.convert('RGB')
|
||||||
alpha_channel = image.getchannel('A')
|
# if not varsize:
|
||||||
image = image.convert('RGB')
|
# image = image.resize((width, height), PIL.Image.BICUBIC)
|
||||||
if not varsize:
|
# except Exception:
|
||||||
image = image.resize((width, height), PIL.Image.BICUBIC)
|
# continue
|
||||||
except Exception:
|
#
|
||||||
continue
|
# text_filename = f"{os.path.splitext(path)[0]}.txt"
|
||||||
|
# filename = os.path.basename(path)
|
||||||
text_filename = f"{os.path.splitext(path)[0]}.txt"
|
#
|
||||||
filename = os.path.basename(path)
|
# if os.path.exists(text_filename):
|
||||||
|
# with open(text_filename, "r", encoding="utf8") as file:
|
||||||
if os.path.exists(text_filename):
|
# filename_text = file.read()
|
||||||
with open(text_filename, "r", encoding="utf8") as file:
|
# else:
|
||||||
filename_text = file.read()
|
# filename_text = os.path.splitext(filename)[0]
|
||||||
else:
|
# filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||||
filename_text = os.path.splitext(filename)[0]
|
# if re_word:
|
||||||
filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
# tokens = re_word.findall(filename_text)
|
||||||
if re_word:
|
# filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||||
tokens = re_word.findall(filename_text)
|
#
|
||||||
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
# npimage = np.array(image).astype(np.uint8)
|
||||||
|
# npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||||
npimage = np.array(image).astype(np.uint8)
|
#
|
||||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
# torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||||
|
# latent_sample = None
|
||||||
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
#
|
||||||
latent_sample = None
|
# with devices.autocast():
|
||||||
|
# latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||||
with devices.autocast():
|
#
|
||||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
# #Perform latent sampling, even for random sampling.
|
||||||
|
# #We need the sample dimensions for the weights
|
||||||
#Perform latent sampling, even for random sampling.
|
# if latent_sampling_method == "deterministic":
|
||||||
#We need the sample dimensions for the weights
|
# if isinstance(latent_dist, DiagonalGaussianDistribution):
|
||||||
if latent_sampling_method == "deterministic":
|
# # Works only for DiagonalGaussianDistribution
|
||||||
if isinstance(latent_dist, DiagonalGaussianDistribution):
|
# latent_dist.std = 0
|
||||||
# Works only for DiagonalGaussianDistribution
|
# else:
|
||||||
latent_dist.std = 0
|
# latent_sampling_method = "once"
|
||||||
else:
|
# latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||||
latent_sampling_method = "once"
|
#
|
||||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
# if use_weight and alpha_channel is not None:
|
||||||
|
# channels, *latent_size = latent_sample.shape
|
||||||
if use_weight and alpha_channel is not None:
|
# weight_img = alpha_channel.resize(latent_size)
|
||||||
channels, *latent_size = latent_sample.shape
|
# npweight = np.array(weight_img).astype(np.float32)
|
||||||
weight_img = alpha_channel.resize(latent_size)
|
# #Repeat for every channel in the latent sample
|
||||||
npweight = np.array(weight_img).astype(np.float32)
|
# weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
|
||||||
#Repeat for every channel in the latent sample
|
# #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
|
||||||
weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
|
# weight -= weight.min()
|
||||||
#Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
|
# weight /= weight.mean()
|
||||||
weight -= weight.min()
|
# elif use_weight:
|
||||||
weight /= weight.mean()
|
# #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
|
||||||
elif use_weight:
|
# weight = torch.ones(latent_sample.shape)
|
||||||
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
|
# else:
|
||||||
weight = torch.ones(latent_sample.shape)
|
# weight = None
|
||||||
else:
|
#
|
||||||
weight = None
|
# if latent_sampling_method == "random":
|
||||||
|
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
|
||||||
if latent_sampling_method == "random":
|
# else:
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
|
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
|
||||||
else:
|
#
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
|
# if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
|
# entry.cond_text = self.create_text(filename_text)
|
||||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
#
|
||||||
entry.cond_text = self.create_text(filename_text)
|
# if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
|
# with devices.autocast():
|
||||||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||||
with devices.autocast():
|
# groups[image.size].append(len(self.dataset))
|
||||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
# self.dataset.append(entry)
|
||||||
groups[image.size].append(len(self.dataset))
|
# del torchdata
|
||||||
self.dataset.append(entry)
|
# del latent_dist
|
||||||
del torchdata
|
# del latent_sample
|
||||||
del latent_dist
|
# del weight
|
||||||
del latent_sample
|
#
|
||||||
del weight
|
# self.length = len(self.dataset)
|
||||||
|
# self.groups = list(groups.values())
|
||||||
self.length = len(self.dataset)
|
# assert self.length > 0, "No images have been found in the dataset."
|
||||||
self.groups = list(groups.values())
|
# self.batch_size = min(batch_size, self.length)
|
||||||
assert self.length > 0, "No images have been found in the dataset."
|
# self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||||
self.batch_size = min(batch_size, self.length)
|
# self.latent_sampling_method = latent_sampling_method
|
||||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
#
|
||||||
self.latent_sampling_method = latent_sampling_method
|
# if len(groups) > 1:
|
||||||
|
# print("Buckets:")
|
||||||
if len(groups) > 1:
|
# for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
||||||
print("Buckets:")
|
# print(f" {w}x{h}: {len(ids)}")
|
||||||
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
# print()
|
||||||
print(f" {w}x{h}: {len(ids)}")
|
#
|
||||||
print()
|
def create_text(self, filename_text):
    """Build one prompt string from a randomly chosen template line.

    Args:
        filename_text: Comma-separated tag text associated with an image.

    Returns:
        The chosen template with ``[filewords]`` replaced by the (possibly
        thinned and/or shuffled) tags re-joined with commas, and ``[name]``
        replaced by ``self.placeholder_token``.
    """
    text = random.choice(self.lines)
    tags = filename_text.split(',')

    # Each tag survives only when a fresh random draw exceeds the drop-out
    # rate, so a rate of 0 keeps everything and the branch is skipped.
    if self.tag_drop_out != 0:
        tags = [t for t in tags if random.random() > self.tag_drop_out]

    if self.shuffle_tags:
        random.shuffle(tags)

    text = text.replace("[filewords]", ','.join(tags))
    text = text.replace("[name]", self.placeholder_token)
    return text
|
def __len__(self):
    """Return the number of entries, as stored in ``self.length``."""
    return self.length
|
def __getitem__(self, i):
    """Return dataset entry ``i``, refreshing its per-access fields.

    The stored entry object itself is updated in place and returned (no copy
    is made):

    * when tag drop-out or tag shuffling is enabled, ``cond_text`` is
      regenerated via ``self.create_text`` on every access;
    * when ``self.latent_sampling_method == "random"``, ``latent_sample`` is
      re-drawn from the entry's ``latent_dist`` through
      ``shared.sd_model.get_first_stage_encoding`` and moved to
      ``devices.cpu``.
    """
    entry = self.dataset[i]

    if self.tag_drop_out != 0 or self.shuffle_tags:
        entry.cond_text = self.create_text(entry.filename_text)

    if self.latent_sampling_method == "random":
        entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)

    return entry
|
#
|
||||||
|
class GroupedBatchSampler(Sampler):
    """Batch sampler that only ever mixes indices from a single group.

    ``data_source.groups`` is a list of index lists. Every yielded batch is
    drawn from exactly one of those lists. Each epoch yields ``len(self)``
    batches: a deterministic share per group (``self.base``) plus
    ``self.n_rand_batches`` extra batches sampled with replacement from a
    group chosen with probability ``self.probs``.
    """

    def __init__(self, data_source: "PersonalizedBase", batch_size: int):
        """Pre-compute how many batches each group contributes.

        Args:
            data_source: Object providing ``__len__`` and a ``groups``
                attribute (list of lists of dataset indices).
            batch_size: Number of indices per yielded batch.
        """
        super().__init__(data_source)

        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
        # Share of the n_batch * batch_size sampled slots each group should
        # receive, proportional to its size.
        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
        # Whole batches guaranteed to each group.
        self.base = [int(e) // batch_size for e in expected]
        # Batches left over once the guaranteed ones are assigned.
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        # Leftover fraction of each group, normalised over the extra batches.
        self.probs = [e % batch_size / nrb / batch_size if nrb > 0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches produced per epoch."""
        return self.len

    def __iter__(self):
        """Yield lists of ``batch_size`` indices, one group per batch."""
        b = self.batch_size

        # Shuffles the shared group lists in place.
        for g in self.groups:
            shuffle(g)

        batches = []
        # Take exactly the pre-computed number of whole batches per group so
        # the total number yielded always equals len(self). Slicing by
        # len(g) // b instead can overshoot it. base[k] * b never exceeds
        # len(groups[k]), so every slice is full.
        for g, n_base in zip(self.groups, self.base):
            batches.extend(g[i * b:(i + 1) * b] for i in range(n_base))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            batches.append(choices(rand_group, k=b))

        shuffle(batches)

        yield from batches
|
#
|
||||||
|
class PersonalizedDataLoader(DataLoader):
    """DataLoader that batches by group and collates into batch objects.

    Batches are produced by a ``GroupedBatchSampler`` over ``dataset``.
    The collate function is ``collate_wrapper_random`` when
    ``latent_sampling_method == "random"`` and ``collate_wrapper`` otherwise.
    """

    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
        # Pick the collate function up front and hand it to the base
        # constructor, rather than overwriting the attribute afterwards.
        if latent_sampling_method == "random":
            collate_fn = collate_wrapper_random
        else:
            collate_fn = collate_wrapper

        super().__init__(
            dataset,
            batch_sampler=GroupedBatchSampler(dataset, batch_size),
            collate_fn=collate_fn,
            pin_memory=pin_memory,
        )
|
#
|
||||||
|
class BatchLoader:
    """Collated batch built from a list of dataset entries.

    Attributes set in ``__init__``:
        cond_text: list of each entry's ``cond_text``.
        cond: list of each entry's ``cond``.
        latent_sample: entries' ``latent_sample`` tensors stacked along a new
            leading dimension, with dimension 1 squeezed out if it has size 1.
        weight: entries' ``weight`` tensors stacked the same way, or ``None``
            if any entry has no weight.
    """

    def __init__(self, data):
        self.cond_text = [entry.cond_text for entry in data]
        self.cond = [entry.cond for entry in data]
        self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
        # Weights are only usable as a batch when every entry supplies one.
        if all(entry.weight is not None for entry in data):
            self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
        else:
            self.weight = None

    def pin_memory(self):
        """Pin the batch tensors in page-locked memory and return ``self``."""
        self.latent_sample = self.latent_sample.pin_memory()
        # Pin the weights as well so both tensors are treated the same way.
        if self.weight is not None:
            self.weight = self.weight.pin_memory()
        return self
|
def collate_wrapper(batch):
    """Collate a list of dataset entries into a ``BatchLoader``."""
    return BatchLoader(batch)
|
class BatchLoaderRandom(BatchLoader):
    """``BatchLoader`` variant whose ``pin_memory`` is a no-op.

    Construction is inherited unchanged from ``BatchLoader``.
    """

    def pin_memory(self):
        """Return ``self`` without pinning any tensor."""
        return self
|
def collate_wrapper_random(batch):
    """Collate a list of dataset entries into a ``BatchLoaderRandom``."""
    return BatchLoaderRandom(batch)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user