Free WebUI from its Prison

Congratulations WebUI. Say Hello to freedom.
This commit is contained in:
layerdiffusion
2024-08-05 03:58:34 -07:00
parent aafe11b14c
commit bccf9fb23a
26 changed files with 2053 additions and 4392 deletions

View File

@@ -1,250 +0,0 @@
import os
import gc
import time
import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack, devices
cached_ldsr_model: torch.nn.Module = None
# Create LDSR Class
class LDSR:
    """Latent-diffusion super-resolution wrapper around a checkpoint and its YAML config.

    ``super_resolution`` is the entry point: it obtains the model, pads the
    input image, samples through ``make_convolutional_sample`` and returns the
    upscaled PIL image.
    """

    def __init__(self, model_path, yaml_path):
        """Store the checkpoint and config paths; nothing is loaded here."""
        self.modelPath = model_path
        self.yamlPath = yaml_path

    def load_model_from_config(self, half_attention):
        """Return ``{"model": model}``, taken from the cache or loaded from disk.

        Args:
            half_attention: If truthy, a freshly loaded model is converted
                with ``.half()``. A model served from the cache is returned
                exactly as it was stored.
        """
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print("Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            _, extension = os.path.splitext(self.modelPath)
            if extension.lower() == ".safetensors":
                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
            else:
                # NOTE(security): torch.load unpickles the file, so only
                # checkpoints from a trusted source should be used here.
                pl_sd = torch.load(self.modelPath, map_location="cpu")
            # Checkpoints may wrap the weights in a "state_dict" entry.
            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
            config = OmegaConf.load(self.yamlPath)
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model)  # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}

    @staticmethod
    def run(model, selected_path, custom_steps, eta):
        """Sample once from ``model`` conditioned on the image ``selected_path``.

        Args:
            model: The loaded diffusion model.
            selected_path: PIL image handed to ``get_cond``.
            custom_steps: Number of sampling steps.
            eta: Eta value forwarded to the sampler.

        Returns:
            The log dict produced by ``make_convolutional_sample``.
        """
        example = get_cond(selected_path)

        height, width = example["image"].shape[1:3]
        if height >= 128 and width >= 128:
            # Large inputs are processed in overlapping 128px patches.
            ks = 128
            stride = 64
            vqf = 4
            model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                        "vqf": vqf,
                                        "patch_distributed_vq": True,
                                        "tie_braker": False,
                                        "clip_max_weight": 0.5,
                                        "clip_min_weight": 0.01,
                                        "clip_max_tie_weight": 0.5,
                                        "clip_min_tie_weight": 0.01}
        elif hasattr(model, "split_input_params"):
            # Drop settings left over from an earlier, larger image.
            delattr(model, "split_input_params")

        # A single run with no custom shape, corrector or start noise: these
        # were fixed constants before, so they are passed literally.
        return make_convolutional_sample(example, model,
                                         custom_steps=custom_steps,
                                         eta=eta, quantize_x0=False,
                                         custom_shape=None,
                                         temperature=1., noise_dropout=0.,
                                         corrector=None, corrector_kwargs=None, x_T=None,
                                         ddim_use_x0_pred=False
                                         )

    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
        """Upscale a PIL ``image`` and return the result as a PIL image.

        The image is first resized by ``target_scale / 4``, padded to a
        multiple of 64 (at least 128) per side, sampled, and finally cropped
        to four times the resized dimensions.

        Args:
            image: Input PIL image.
            steps: Number of diffusion steps (converted with ``int``).
            target_scale: Desired overall scale factor.
            half_attention: Forwarded to ``load_model_from_config``.
        """
        model = self.load_model_from_config(half_attention)

        # Run settings
        diffusion_steps = int(steps)
        eta = 1.0

        gc.collect()
        devices.torch_gc()

        im_og = image
        width_og, height_og = im_og.size
        # If we can adjust the max upscale size, then the 4 below should be our variable
        down_sample_rate = target_scale / 4
        width_downsampled_pre = int(np.ceil(width_og * down_sample_rate))
        height_downsampled_pre = int(np.ceil(height_og * down_sample_rate))

        if down_sample_rate != 1:
            print(
                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
        else:
            print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")

        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
        pixels = np.array(im_og)
        if pixels.ndim != 3:
            # Images without a channel axis give a 2-D array, which the
            # three-axis pad spec below would reject. get_cond() converts the
            # padded image to RGB anyway, so converting here is harmless.
            pixels = np.array(im_og.convert('RGB'))
        im_padded = Image.fromarray(np.pad(pixels, ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))

        logs = self.run(model["model"], im_padded, diffusion_steps, eta)

        # Map the sample from [-1, 1] to uint8 and move channels last.
        sample = logs["sample"]
        sample = sample.detach().cpu()
        sample = torch.clamp(sample, -1., 1.)
        sample = (sample + 1.) / 2. * 255
        sample = sample.numpy().astype(np.uint8)
        sample = np.transpose(sample, (0, 2, 3, 1))
        a = Image.fromarray(sample[0])

        # remove padding
        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))

        del model
        gc.collect()
        devices.torch_gc()

        return a
def get_cond(selected_path):
    """Build the conditioning batch for a PIL image.

    Returns a dict with two channels-last tensors of batch size 1:
    ``"LR_image"`` is the RGB image rescaled with ``2 * c - 1`` and moved to
    ``shared.device``; ``"image"`` is the same image resized to four times its
    height and width, left unscaled.
    """
    upscale_factor = 4

    rgb_image = selected_path.convert('RGB')
    low_res = torch.unsqueeze(torchvision.transforms.ToTensor()(rgb_image), 0)

    target_size = [upscale_factor * low_res.shape[2], upscale_factor * low_res.shape[3]]
    upsampled = torchvision.transforms.functional.resize(low_res, size=target_size,
                                                         antialias=True)

    # Both tensors go from channels-first to channels-last.
    upsampled = rearrange(upsampled, '1 c h w -> 1 h w c')
    low_res = rearrange(low_res, '1 c h w -> 1 h w c')

    low_res = 2. * low_res - 1.
    low_res = low_res.to(shared.device)

    return {"LR_image": low_res, "image": upsampled}
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
                    corrector_kwargs=None, x_t=None
                    ):
    """Run ``DDIMSampler.sample`` for ``model`` and return ``(samples, intermediates)``.

    ``shape`` carries the batch size in its first entry; the remaining entries
    are passed to the sampler as the per-sample shape. All other arguments are
    forwarded to ``sample`` unchanged.
    """
    sampler = DDIMSampler(model)
    batch_size, sample_shape = shape[0], shape[1:]

    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=sample_shape,
        conditioning=cond,
        callback=callback,
        normals_sequence=normals_sequence,
        quantize_x0=quantize_x0,
        eta=eta,
        mask=mask,
        x0=x0,
        temperature=temperature,
        verbose=False,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_t=x_t,
    )
    return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
    """Sample from ``model`` for ``batch`` and collect the results in a dict.

    The returned dict always holds ``"input"``, ``"reconstruction"``,
    ``"original_conditioning"``, ``"sample"`` and ``"time"`` (seconds spent in
    the sampler). ``"sample_noquant"`` and ``"sample_diff"`` are added only
    when decoding with ``force_not_quantize=True`` succeeds.
    ``noise_dropout`` is accepted but not used.
    """
    # Conditioning is force-encoded unless the model runs in split-input mode
    # with bounding-box coordinates as its conditioning key.
    bbox_split_mode = hasattr(model, 'split_input_params') and model.cond_stage_key == 'coordinates_bbox'
    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not bbox_split_mode,
                                        return_original_cond=True)

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    results = {"input": x, "reconstruction": xrec}

    if ismap(xc):
        results["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            results[model.cond_stage_key] = model.to_rgb(xc)
    else:
        results["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            results[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == 'class_label':
                results[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        started = time.time()
        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, mask=None, x0=None,
                                                temperature=temperature, score_corrector=corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_t=x_T)
        finished = time.time()

        if ddim_use_x0_pred:
            sample = intermediates['pred_x0'][-1]

    decoded = model.decode_first_stage(sample)

    # Best effort: not every first stage accepts force_not_quantize.
    try:
        decoded_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        results["sample_noquant"] = decoded_noquant
        results["sample_diff"] = torch.abs(decoded_noquant - decoded)
    except Exception:
        pass

    results["sample"] = decoded
    results["time"] = finished - started
    return results

View File

@@ -1,6 +0,0 @@
import os
from modules import paths
def preload(parser):
    """Register the ``--ldsr-models-path`` option on ``parser``.

    The default is the ``LDSR`` directory under ``paths.models_path``.
    """
    default_dir = os.path.join(paths.models_path, 'LDSR')
    parser.add_argument(
        "--ldsr-models-path",
        type=str,
        help="Path to directory with LDSR model file(s).",
        default=default_dir,
    )

View File

@@ -1,70 +0,0 @@
import os
from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from modules_forge.utils import prepare_free_memory
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
import sd_hijack_autoencoder # noqa: F401
import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler):
    """Upscaler entry that runs images through the LDSR model."""

    def __init__(self, user_path):
        """Set the name and download URLs, then register a single "LDSR" scaler."""
        self.name = "LDSR"
        self.user_path = user_path
        self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
        self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
        super().__init__()
        self.scalers = [UpscalerData("LDSR", None, self)]

    def load_model(self, path: str):
        """Locate (or download) the checkpoint and YAML config and return an ``LDSR``.

        A local ``model.safetensors`` is preferred, then a local
        ``model.ckpt``, then a download. ``path`` is not used.
        """
        yaml_path = os.path.join(self.model_path, "project.yaml")
        old_model_path = os.path.join(self.model_path, "model.pth")
        new_model_path = os.path.join(self.model_path, "model.ckpt")

        local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])

        def first_ending_with(suffix):
            # First discovered path with the given file name, or None.
            return next((candidate for candidate in local_model_paths if candidate.endswith(suffix)), None)

        local_ckpt_path = first_ending_with("model.ckpt")
        local_safetensors_path = first_ending_with("model.safetensors")
        local_yaml_path = first_ending_with("project.yaml")

        # Remove incorrect project.yaml file if too big
        if os.path.exists(yaml_path) and os.stat(yaml_path).st_size >= 10485760:
            print("Removing invalid LDSR YAML file.")
            os.remove(yaml_path)

        if os.path.exists(old_model_path):
            print("Renaming model from model.pth to model.ckpt")
            os.rename(old_model_path, new_model_path)

        if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
            model_file = local_safetensors_path
        else:
            model_file = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")

        yaml_file = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")

        return LDSR(model_file, yaml_file)

    def do_upscale(self, img, path):
        """Upscale ``img`` with LDSR; return ``img`` unchanged if loading the model fails."""
        prepare_free_memory(aggressive=True)

        try:
            ldsr = self.load_model(path)
        except Exception:
            errors.report(f"Failed loading LDSR model {path}", exc_info=True)
            return img

        return ldsr.super_resolution(img, shared.opts.ldsr_steps, self.scale)
def on_ui_settings():
    """Add the ``ldsr_steps`` and ``ldsr_cached`` options to the "Upscaling" settings section."""
    import gradio as gr

    upscaling_section = ('upscaling', "Upscaling")

    steps_option = shared.OptionInfo(
        100,
        "LDSR processing steps. Lower = faster",
        gr.Slider,
        {"minimum": 1, "maximum": 200, "step": 1},
        section=upscaling_section,
    )
    shared.opts.add_option("ldsr_steps", steps_option)

    cache_option = shared.OptionInfo(
        False,
        "Cache LDSR model in memory",
        gr.Checkbox,
        {"interactive": True},
        section=upscaling_section,
    )
    shared.opts.add_option("ldsr_cached", cache_option)
script_callbacks.on_ui_settings(on_ui_settings)

View File

@@ -1,293 +0,0 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
from packaging import version
class VQModel(pl.LightningModule):
    """Vector-quantized autoencoder: encoder, quantizer and decoder with optional EMA weights."""

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        """Build the sub-modules and optionally restore weights from ``ckpt_path``.

        Args:
            ddconfig: Keyword arguments for ``Encoder`` and ``Decoder``; must
                contain ``"z_channels"``.
            lossconfig: Config passed to ``instantiate_from_config`` for the loss.
            n_embed: Codebook size handed to the quantizer.
            embed_dim: Codebook vector dimension.
            ckpt_path: If given, weights are loaded via ``init_from_ckpt``.
            ignore_keys: Key prefixes to drop from the checkpoint.
            image_key: Batch key read by the training/validation/logging steps.
            colorize_nlabels: If given, registers a random ``colorize`` buffer
                with that many input channels.
            monitor: Stored as ``self.monitor`` when not ``None``.
            batch_resize_range: ``(lower, upper)`` sizes for per-batch resizing
                in ``get_input``.
            scheduler_config: Config of the LR schedule used in
                ``configure_optimizers``.
            lr_g_factor: Multiplier on the learning rate of the autoencoder optimizer.
            remap: Forwarded to the quantizer.
            sane_index_shape: Forwarded to the quantizer.
            use_ema: Keep an EMA copy of the weights in ``self.model_ema``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Run the ``with`` body on EMA weights when ``use_ema`` is set.

        The current parameters are stored and the EMA weights copied in on
        entry; the stored parameters are restored on exit, even on error.
        ``context`` is only used as a prefix for the printed messages.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load ``state_dict`` from the checkpoint at ``path`` (non-strict).

        Args:
            path: Checkpoint file containing a ``"state_dict"`` entry.
            ignore_keys: Iterable of key prefixes; matching entries are
                removed before loading.
        """
        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be used here.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        prefixes = tuple(ignore_keys or ())
        for k in list(sd.keys()):
            # One tuple-startswith test deletes each key at most once; a
            # per-prefix `del` raised KeyError when two prefixes matched.
            if prefixes and k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if missing:
            print(f"Missing Keys: {missing}")
        if unexpected:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after each training batch when ``use_ema`` is set."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode ``x`` and quantize it; returns ``(quant, emb_loss, info)``."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode ``x`` up to, but not including, the quantizer."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode a quantized latent."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode codebook indices.

        NOTE(review): this calls ``self.quantize.embed_code``; confirm the
        quantizer class in use actually provides that method.
        """
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Autoencode ``input``.

        Returns ``(dec, diff)``, or ``(dec, diff, ind)`` when
        ``return_pred_indices`` is set, where ``diff`` is the quantizer loss
        and ``ind`` the third entry of the quantizer's info tuple.
        """
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a float tensor with the last axis moved to position 1.

        A 3-D entry first gets a trailing axis. When ``batch_resize_range`` is
        set, the tensor is resized bicubically to a size from that range and
        detached.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the autoencoder loss (optimizer 0) or the discriminator loss (optimizer 1).

        Any other ``optimizer_idx`` falls through and returns ``None``.
        """
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Validate once with the current weights and once inside ``ema_scope``.

        Returns the result of the first (non-EMA) ``_validation_step`` call.
        """
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Compute and log both losses under the ``val{suffix}`` split."""
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val" + suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val" + suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            # rec_loss was logged explicitly above; drop it before log_dict.
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # NOTE(review): this returns the bound method `self.log_dict`, not a
        # dict of metrics. Kept as is because callers may rely on it — confirm.
        return self.log_dict

    def configure_optimizers(self):
        """Create Adam optimizers for the autoencoder and the discriminator.

        Returns ``([opt_ae, opt_disc], schedulers)``; ``schedulers`` is empty
        unless ``scheduler_config`` is set.
        """
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return a dict with inputs and reconstructions for ``batch``.

        With ``only_inputs`` only ``"inputs"`` is returned. With ``plot_ema``
        a reconstruction made inside ``ema_scope`` is added as
        ``"reconstructions_ema"``.
        """
        log = {}
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project ``x`` to three channels with the ``colorize`` buffer and rescale to [-1, 1].

        Only valid when ``image_key`` is ``"segmentation"``; the buffer is
        created with random weights on first use if it does not exist yet.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
class VQModelInterface(VQModel):
    """``VQModel`` variant whose ``encode`` skips quantization and whose ``decode`` applies it."""

    def __init__(self, embed_dim, *args, **kwargs):
        """Forward all arguments to ``VQModel`` and keep ``embed_dim`` on the instance."""
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Return the pre-quantization latent for ``x``."""
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        """Decode ``h``, quantizing it first unless ``force_not_quantize`` is set."""
        if force_not_quantize:
            latent = h
        else:
            # also go through quantization layer
            latent, _emb_loss, _info = self.quantize(h)
        return self.decoder(self.post_quant_conv(latent))
ldm.models.autoencoder.VQModel = VQModel
ldm.models.autoencoder.VQModelInterface = VQModelInterface

File diff suppressed because it is too large Load Diff

View File

@@ -1,147 +0,0 @@
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
# where the license is as follows:
#
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        """Create the codebook and, if ``remap`` is given, the index remapping table.

        Args:
            n_e: Number of codebook entries.
            e_dim: Dimension of each codebook vector.
            beta: Weight of one of the two loss terms (which one depends on ``legacy``).
            remap: Path loaded with ``np.load``; its contents become the ``used`` buffer.
            unknown_index: ``"random"``, ``"extra"`` or an integer; how indices
                missing from ``used`` are remapped.
            sane_index_shape: Return indices reshaped to (batch, height, width).
            legacy: Keep the historical placement of ``beta`` in the loss.
        """
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                # "extra" adds one index past the used ones for unknowns.
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """Map full-codebook indices to their positions in ``used``.

        Indices not present in ``used`` become a random index in
        ``[0, re_embed)`` or ``unknown_index``. Returns a tensor shaped like ``inds``.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Map positions in ``used`` back to full-codebook indices.

        Positions at or beyond ``len(used)`` (the extra token) are treated as
        position 0. ``inds`` itself is left untouched. Returns a tensor shaped
        like ``inds``.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # masked_fill returns a new tensor; an in-place write here would
            # go through the reshape view and modify the caller's tensor.
            inds = inds.masked_fill(inds >= self.used.shape[0], 0)  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantize ``z`` to its nearest codebook vectors.

        Returns ``(z_q, loss, (None, None, indices))``. ``temp``,
        ``rescale_logits`` and ``return_logits`` exist only for interface
        compatibility and must keep their default values (``temp`` may also be 1.0).
        """
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """Look up codebook vectors for ``indices``.

        ``shape`` specifies (batch, height, width, channel). When it is not
        ``None`` the result is viewed to that shape and returned
        channels-first. With ``remap`` set, ``shape`` is required because its
        first entry is read to restore the batch axis.
        """
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

View File

@@ -11,9 +11,12 @@ from transformers.models.clip.modeling_clip import CLIPVisionModelOutput
from annotator.util import HWC3 from annotator.util import HWC3
from typing import Callable, Tuple, Union from typing import Callable, Tuple, Union
from modules.safe import Extra
from modules import devices from modules import devices
import contextlib
Extra = lambda x: contextlib.nullcontext()
def torch_handler(module: str, name: str): def torch_handler(module: str, name: str):
""" Allow all torch access. Bypass A1111 safety whitelist. """ """ Allow all torch access. Bypass A1111 safety whitelist. """

View File

@@ -19,7 +19,6 @@ import cv2
import logging import logging
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List
from modules.safe import unsafe_torch_load
from lib_controlnet.logging import logger from lib_controlnet.logging import logger
@@ -28,7 +27,7 @@ def load_state_dict(ckpt_path, location="cpu"):
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
state_dict = safetensors.torch.load_file(ckpt_path, device=location) state_dict = safetensors.torch.load_file(ckpt_path, device=location)
else: else:
state_dict = unsafe_torch_load(ckpt_path, map_location=torch.device(location)) state_dict = torch.load(ckpt_path, map_location=torch.device(location))
state_dict = get_state_dict(state_dict) state_dict = get_state_dict(state_dict)
logger.info(f"Loaded state_dict from [{ckpt_path}]") logger.info(f"Loaded state_dict from [{ckpt_path}]")
return state_dict return state_dict

View File

@@ -24,7 +24,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin from PIL import PngImagePlugin
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices from modules import devices
from typing import Any from typing import Any
@@ -725,7 +724,7 @@ class Api:
def get_sd_models(self): def get_sd_models(self):
import modules.sd_models as sd_models import modules.sd_models as sd_models
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()] return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()]
def get_sd_vaes(self): def get_sd_vaes(self):
import modules.sd_vae as sd_vae import modules.sd_vae as sd_vae

View File

@@ -9,7 +9,7 @@ import modules.textual_inversion.dataset
import torch import torch
import tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import default from backend.nn.unet import default
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, saving_settings from modules.textual_inversion import textual_inversion, saving_settings
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler

View File

@@ -1,25 +1,12 @@
import importlib import importlib
import logging import logging
import os
import sys import sys
import warnings import warnings
import os import os
from threading import Thread
from modules.timer import startup_timer from modules.timer import startup_timer
class HiddenPrints:
    """Context manager that silences ``print`` by pointing ``sys.stdout`` at ``os.devnull``."""

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the file opened here,
        # even if the body replaced sys.stdout in the meantime.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        self._devnull.close()
def imports(): def imports():
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -35,16 +22,8 @@ def imports():
import gradio # noqa: F401 import gradio # noqa: F401
startup_timer.record("import gradio") startup_timer.record("import gradio")
with HiddenPrints(): from modules import paths, timer, import_hook, errors # noqa: F401
from modules import paths, timer, import_hook, errors # noqa: F401 startup_timer.record("setup paths")
startup_timer.record("setup paths")
import ldm.modules.encoders.modules # noqa: F401
import ldm.modules.diffusionmodules.model
startup_timer.record("import ldm")
import sgm.modules.encoders.modules # noqa: F401
startup_timer.record("import sgm")
from modules import shared_init from modules import shared_init
shared_init.initialize() shared_init.initialize()
@@ -137,15 +116,6 @@ def initialize_rest(*, reload_script_modules=False):
sd_vae.refresh_vae_list() sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE") startup_timer.record("refresh VAE")
from modules import textual_inversion
textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
from modules import sd_unet from modules import sd_unet
sd_unet.list_unets() sd_unet.list_unets()
startup_timer.record("scripts list_unets") startup_timer.record("scripts list_unets")

View File

@@ -391,15 +391,15 @@ def prepare_environment():
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") # stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") # stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git') huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4") huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -456,8 +456,8 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) # git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) # git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash) git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

View File

@@ -2,45 +2,15 @@ import os
import sys import sys
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401 from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
import modules.safe # noqa: F401
def mute_sdxl_imports():
"""create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
class Dummy:
pass
module = Dummy()
module.LPIPS = None
sys.modules['taming.modules.losses.lpips'] = module
module = Dummy()
module.StableDataModuleFromConfig = None
sys.modules['sgm.data'] = module
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path) sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places sd_path = os.path.dirname(__file__)
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path)
break
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
mute_sdxl_imports()
path_dirs = [ path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), (os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
(os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
] ]
paths = {} paths = {}
@@ -53,13 +23,6 @@ for d, must_exist, what, options in path_dirs:
d = os.path.abspath(d) d = os.path.abspath(d)
if "atstart" in options: if "atstart" in options:
sys.path.insert(0, d) sys.path.insert(0, d)
elif "sgm" in options:
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
sys.path.insert(0, d)
import sgm # noqa: F401
sys.path.pop(0)
else: else:
sys.path.append(d) sys.path.append(d)
paths[what] = d paths[what] = d

View File

@@ -28,8 +28,6 @@ import modules.images as images
import modules.styles import modules.styles
import modules.sd_models as sd_models import modules.sd_models as sd_models
import modules.sd_vae as sd_vae import modules.sd_vae as sd_vae
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType from blendmodes.blend import blendLayers, BlendType
@@ -295,23 +293,7 @@ class StableDiffusionProcessing:
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height) return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image): def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model raise NotImplementedError('NotImplementedError: depth2img_image_conditioning')
transformer = AddMiDaS(model_type="dpt_hybrid")
transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
mode="bicubic",
align_corners=False,
)
(depth_min, depth_max) = torch.aminmax(conditioning)
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
return conditioning
def edit_image_conditioning(self, source_image): def edit_image_conditioning(self, source_image):
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode() conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
@@ -368,11 +350,6 @@ class StableDiffusionProcessing:
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image) source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image)
if self.sd_model.cond_stage_key == "edit": if self.sd_model.cond_stage_key == "edit":
return self.edit_image_conditioning(source_image) return self.edit_image_conditioning(source_image)

View File

@@ -1,195 +1,195 @@
# this code is adapted from the script contributed by anon from /h/ # # this code is adapted from the script contributed by anon from /h/
#
import pickle # import pickle
import collections # import collections
#
import torch # import torch
import numpy # import numpy
import _codecs # import _codecs
import zipfile # import zipfile
import re # import re
#
#
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage # # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors # from modules import errors
#
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage # TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
#
def encode(*args): # def encode(*args):
out = _codecs.encode(*args) # out = _codecs.encode(*args)
return out # return out
#
#
class RestrictedUnpickler(pickle.Unpickler): # class RestrictedUnpickler(pickle.Unpickler):
extra_handler = None # extra_handler = None
#
def persistent_load(self, saved_id): # def persistent_load(self, saved_id):
assert saved_id[0] == 'storage' # assert saved_id[0] == 'storage'
#
try: # try:
return TypedStorage(_internal=True) # return TypedStorage(_internal=True)
except TypeError: # except TypeError:
return TypedStorage() # PyTorch before 2.0 does not have the _internal argument # return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
#
def find_class(self, module, name): # def find_class(self, module, name):
if self.extra_handler is not None: # if self.extra_handler is not None:
res = self.extra_handler(module, name) # res = self.extra_handler(module, name)
if res is not None: # if res is not None:
return res # return res
#
if module == 'collections' and name == 'OrderedDict': # if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name) # return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: # if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name) # return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']: # if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
return getattr(torch, name) # return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']: # if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name) # return getattr(torch.nn.modules.container, name)
if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']: # if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
return getattr(numpy.core.multiarray, name) # return getattr(numpy.core.multiarray, name)
if module == 'numpy' and name in ['dtype', 'ndarray']: # if module == 'numpy' and name in ['dtype', 'ndarray']:
return getattr(numpy, name) # return getattr(numpy, name)
if module == '_codecs' and name == 'encode': # if module == '_codecs' and name == 'encode':
return encode # return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': # if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
import pytorch_lightning.callbacks # import pytorch_lightning.callbacks
return pytorch_lightning.callbacks.model_checkpoint # return pytorch_lightning.callbacks.model_checkpoint
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': # if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks.model_checkpoint # import pytorch_lightning.callbacks.model_checkpoint
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint # return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
if module == "__builtin__" and name == 'set': # if module == "__builtin__" and name == 'set':
return set # return set
#
# Forbid everything else. # # Forbid everything else.
raise Exception(f"global '{module}/{name}' is forbidden") # raise Exception(f"global '{module}/{name}' is forbidden")
#
#
# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>' # # Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$") # allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") # data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
#
def check_zip_filenames(filename, names): # def check_zip_filenames(filename, names):
for name in names: # for name in names:
if allowed_zip_names_re.match(name): # if allowed_zip_names_re.match(name):
continue # continue
#
raise Exception(f"bad file inside {filename}: {name}") # raise Exception(f"bad file inside {filename}: {name}")
#
#
def check_pt(filename, extra_handler): # def check_pt(filename, extra_handler):
try: # try:
#
# new pytorch format is a zip file # # new pytorch format is a zip file
with zipfile.ZipFile(filename) as z: # with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist()) # check_zip_filenames(filename, z.namelist())
#
# find filename of data.pkl in zip file: '<directory name>/data.pkl' # # find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] # data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0: # if len(data_pkl_filenames) == 0:
raise Exception(f"data.pkl not found in {filename}") # raise Exception(f"data.pkl not found in {filename}")
if len(data_pkl_filenames) > 1: # if len(data_pkl_filenames) > 1:
raise Exception(f"Multiple data.pkl found in {filename}") # raise Exception(f"Multiple data.pkl found in {filename}")
with z.open(data_pkl_filenames[0]) as file: # with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file) # unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler # unpickler.extra_handler = extra_handler
unpickler.load() # unpickler.load()
#
except zipfile.BadZipfile: # except zipfile.BadZipfile:
#
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle # # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file: # with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file) # unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler # unpickler.extra_handler = extra_handler
for _ in range(5): # for _ in range(5):
unpickler.load() # unpickler.load()
#
#
def load(filename, *args, **kwargs): # def load(filename, *args, **kwargs):
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs) # return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
#
#
def load_with_extra(filename, extra_handler=None, *args, **kwargs): # def load_with_extra(filename, extra_handler=None, *args, **kwargs):
""" # """
this function is intended to be used by extensions that want to load models with # this function is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious. # some extra classes in them that the usual unpickler would find suspicious.
#
Use the extra_handler argument to specify a function that takes module and field name as text, # Use the extra_handler argument to specify a function that takes module and field name as text,
and returns that field's value: # and returns that field's value:
#
```python # ```python
def extra(module, name): # def extra(module, name):
if module == 'collections' and name == 'OrderedDict': # if module == 'collections' and name == 'OrderedDict':
return collections.OrderedDict # return collections.OrderedDict
#
return None # return None
#
safe.load_with_extra('model.pt', extra_handler=extra) # safe.load_with_extra('model.pt', extra_handler=extra)
``` # ```
#
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is # The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
definitely unsafe. # definitely unsafe.
""" # """
#
from modules import shared # from modules import shared
#
try: # try:
if not shared.cmd_opts.disable_safe_unpickle: # if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename, extra_handler) # check_pt(filename, extra_handler)
#
except pickle.UnpicklingError: # except pickle.UnpicklingError:
errors.report( # errors.report(
f"Error verifying pickled file from {filename}\n" # f"Error verifying pickled file from {filename}\n"
"-----> !!!! The file is most likely corrupted !!!! <-----\n" # "-----> !!!! The file is most likely corrupted !!!! <-----\n"
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", # "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
exc_info=True, # exc_info=True,
) # )
return None # return None
except Exception: # except Exception:
errors.report( # errors.report(
f"Error verifying pickled file from {filename}\n" # f"Error verifying pickled file from {filename}\n"
f"The file may be malicious, so the program is not going to read it.\n" # f"The file may be malicious, so the program is not going to read it.\n"
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", # f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
exc_info=True, # exc_info=True,
) # )
return None # return None
#
return unsafe_torch_load(filename, *args, **kwargs) # return unsafe_torch_load(filename, *args, **kwargs)
#
#
class Extra: # class Extra:
""" # """
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra # A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
(because it's not your code making the torch.load call). The intended use is like this: # (because it's not your code making the torch.load call). The intended use is like this:
#
``` # ```
import torch # import torch
from modules import safe # from modules import safe
#
def handler(module, name): # def handler(module, name):
if module == 'torch' and name in ['float64', 'float16']: # if module == 'torch' and name in ['float64', 'float16']:
return getattr(torch, name) # return getattr(torch, name)
#
return None # return None
#
with safe.Extra(handler): # with safe.Extra(handler):
x = torch.load('model.pt') # x = torch.load('model.pt')
``` # ```
""" # """
#
def __init__(self, handler): # def __init__(self, handler):
self.handler = handler # self.handler = handler
#
def __enter__(self): # def __enter__(self):
global global_extra_handler # global global_extra_handler
#
assert global_extra_handler is None, 'already inside an Extra() block' # assert global_extra_handler is None, 'already inside an Extra() block'
global_extra_handler = self.handler # global_extra_handler = self.handler
#
def __exit__(self, exc_type, exc_val, exc_tb): # def __exit__(self, exc_type, exc_val, exc_tb):
global global_extra_handler # global global_extra_handler
#
global_extra_handler = None # global_extra_handler = None
#
#
unsafe_torch_load = torch.load # unsafe_torch_load = torch.load
global_extra_handler = None # global_extra_handler = None

View File

@@ -1,232 +1,232 @@
import ldm.modules.encoders.modules # import ldm.modules.encoders.modules
import open_clip # import open_clip
import torch # import torch
import transformers.utils.hub # import transformers.utils.hub
#
from modules import shared # from modules import shared
#
#
class ReplaceHelper: # class ReplaceHelper:
def __init__(self): # def __init__(self):
self.replaced = [] # self.replaced = []
#
def replace(self, obj, field, func): # def replace(self, obj, field, func):
original = getattr(obj, field, None) # original = getattr(obj, field, None)
if original is None: # if original is None:
return None # return None
#
self.replaced.append((obj, field, original)) # self.replaced.append((obj, field, original))
setattr(obj, field, func) # setattr(obj, field, func)
#
return original # return original
#
def restore(self): # def restore(self):
for obj, field, original in self.replaced: # for obj, field, original in self.replaced:
setattr(obj, field, original) # setattr(obj, field, original)
#
self.replaced.clear() # self.replaced.clear()
#
#
class DisableInitialization(ReplaceHelper): # class DisableInitialization(ReplaceHelper):
""" # """
When an object of this class enters a `with` block, it starts: # When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working # - preventing torch's layer initialization functions from working
- changes CLIP and OpenCLIP to not download model weights # - changes CLIP and OpenCLIP to not download model weights
- changes CLIP to not make requests to check if there is a new version of a file you already have # - changes CLIP to not make requests to check if there is a new version of a file you already have
#
When it leaves the block, it reverts everything to how it was before. # When it leaves the block, it reverts everything to how it was before.
#
Use it like this: # Use it like this:
``` # ```
with DisableInitialization(): # with DisableInitialization():
do_things() # do_things()
``` # ```
""" # """
#
def __init__(self, disable_clip=True): # def __init__(self, disable_clip=True):
super().__init__() # super().__init__()
self.disable_clip = disable_clip # self.disable_clip = disable_clip
#
def replace(self, obj, field, func): # def replace(self, obj, field, func):
original = getattr(obj, field, None) # original = getattr(obj, field, None)
if original is None: # if original is None:
return None # return None
#
self.replaced.append((obj, field, original)) # self.replaced.append((obj, field, original))
setattr(obj, field, func) # setattr(obj, field, func)
#
return original # return original
#
def __enter__(self): # def __enter__(self):
def do_nothing(*args, **kwargs): # def do_nothing(*args, **kwargs):
pass # pass
#
def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): # def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
return self.create_model_and_transforms(*args, pretrained=None, **kwargs) # return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
#
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): # def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) # res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
res.name_or_path = pretrained_model_name_or_path # res.name_or_path = pretrained_model_name_or_path
return res # return res
#
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): # def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug # args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs) # return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
#
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): # def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
#
# this file is always 404, prevent making request # # this file is always 404, prevent making request
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': # if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
return None # return None
#
try: # try:
res = original(url, *args, local_files_only=True, **kwargs) # res = original(url, *args, local_files_only=True, **kwargs)
if res is None: # if res is None:
res = original(url, *args, local_files_only=False, **kwargs) # res = original(url, *args, local_files_only=False, **kwargs)
return res # return res
except Exception: # except Exception:
return original(url, *args, local_files_only=False, **kwargs) # return original(url, *args, local_files_only=False, **kwargs)
#
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): # def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs) # return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
#
def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): # def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs) # return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
#
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): # def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) # return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
#
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) # self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) # self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) # self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
#
if self.disable_clip: # if self.disable_clip:
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) # self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) # self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) # self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) # self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) # self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) # self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
#
def __exit__(self, exc_type, exc_val, exc_tb): # def __exit__(self, exc_type, exc_val, exc_tb):
self.restore() # self.restore()
#
#
class InitializeOnMeta(ReplaceHelper):
    """
    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
    which results in those parameters having no values and taking no memory. While it is active,
    torch.nn.Module.to is replaced by a stub that returns None, so model.to() will be broken and
    will need to be repaired by using LoadStateDictOnMeta when loading params from state dict.

    __enter__ installs nothing when shared.cmd_opts.disable_model_loading_ram_optimization is set.

    Usage:
    ```
    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)
    ```
    """

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        def force_meta(ctor_kwargs):
            # Overwrites any device the caller passed by keyword.
            ctor_kwargs["device"] = "meta"
            return ctor_kwargs

        def patch_constructor(layer_cls):
            # The value returned by self.replace() is what the wrapper delegates to as the previous __init__.
            previous_init = self.replace(
                layer_cls,
                '__init__',
                lambda *args, **kwargs: previous_init(*args, **force_meta(kwargs)),
            )

        for layer_cls in (torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention):
            patch_constructor(layer_cls)

        # Module.to() becomes a no-op returning None for the duration of the context.
        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
#
#
class LoadStateDictOnMeta(ReplaceHelper):
    """
    Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
    As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
    Meant to be used together with InitializeOnMeta.

    Usage:
    ```
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device):
        model.load_state_dict(state_dict, strict=False)
    ```
    """

    def __init__(self, state_dict, device, weight_dtype_conversion=None):
        """
        state_dict: dict of weights; entries are popped from it as modules consume them.
        device: device on which replacements for meta parameters are allocated.
        weight_dtype_conversion: optional mapping from the first dotted term of a key to a dtype;
            the '' entry is the default for keys whose first term is not listed.
        """
        super().__init__()
        self.state_dict = state_dict
        self.device = device
        self.weight_dtype_conversion = weight_dtype_conversion or {}
        self.default_dtype = self.weight_dtype_conversion.get('')

    def get_weight_dtype(self, key):
        """Return the dtype configured for `key`'s first dotted term, or the default dtype."""
        # partition() rather than unpacking split('.', 1): a key without any dot
        # (a parameter registered directly on the root module) used to raise ValueError.
        key_first_term = key.partition('.')[0]
        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
            used_param_keys = []

            for name, param in module._parameters.items():
                if param is None:
                    continue

                key = prefix + name
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    # Hand torch the real weight only for the duration of this module's load.
                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                    used_param_keys.append(key)

                if param.is_meta:
                    # Materialize the meta parameter on the target device so the original loader can copy into it.
                    dtype = sd_param.dtype if sd_param is not None else param.dtype
                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

            for name in module._buffers:
                key = prefix + name

                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param
                    used_param_keys.append(key)

            original(module, state_dict, prefix, *args, **kwargs)

            # Drop the real weights again so no copy of the dict keeps them alive.
            for key in used_param_keys:
                state_dict.pop(key, None)

        def load_state_dict(original, module, state_dict, strict=True):
            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.

            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).

            The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
            the function and does not call the original) the state dict will just fail to load because weights
            would be on the meta device.
            """

            if state_dict is sd:
                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

            # Propagate the wrapped call's result; previously it was dropped and callers always got None.
            return original(module, state_dict, strict=strict)

        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))

        def patch_load_from_state_dict(cls):
            # One closure per class so each wrapper delegates to the value self.replace() returned for that class.
            previous = self.replace(cls, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(previous, *args, **kwargs))

        # Same classes, same order as before: Module first, then the specific layer types.
        for cls in (
            torch.nn.Module,
            torch.nn.Linear,
            torch.nn.Conv2d,
            torch.nn.MultiheadAttention,
            torch.nn.LayerNorm,
            torch.nn.GroupNorm,
        ):
            patch_load_from_state_dict(cls)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()

View File

@@ -1,124 +1,3 @@
import torch
from torch.nn.functional import silu
from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules
# Module-level references to the ldm implementations as they are at import time.
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
# (binds the name `print` inside each listed ldm module to shared.ldm_print)
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print
# Registry filled in place by list_optimizers(); current_optimizer starts unset.
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
# Wrap UNetModel.forward of both ldm and sgm with sd_unet.create_unet_forward and install the wrappers.
# NOTE(review): the *_original_forward names suggest patches.patch returns the attribute it replaced -- confirm.
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
def list_optimizers():
    """Refresh the module-level ``optimizers`` list in place.

    Collects candidates from script_callbacks.list_optimizers_callback(), keeps those whose
    is_available() is truthy, and orders them by descending ``priority``.
    """
    available = [opt for opt in script_callbacks.list_optimizers_callback() if opt.is_available()]
    available.sort(key=lambda opt: opt.priority, reverse=True)

    # Slice assignment keeps the same list object, like clear() followed by extend().
    optimizers[:] = available
def apply_optimizations(option=None):
    """No-op stub: accepts an optional ``option`` and returns None without doing anything."""
    return None
def undo_optimizations():
    """No-op stub: takes no arguments and returns None."""
    return None
def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""
    # Nothing left to do here.
    return None
def weighted_loss(sd_model, pred, target, mean=True):
    """Compute the model's loss, scaled by an optional weight stored on the model.

    sd_model: object providing ``_old_get_loss(pred, target, mean=False)`` and, optionally,
        a ``_custom_loss_weight`` attribute.
    pred, target: forwarded unchanged to ``_old_get_loss``.
    mean: when true, return ``loss.mean()``; otherwise return the unreduced loss.
    """
    # Calculate the loss normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    # Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        # Out-of-place multiply: `loss *= weight` mutated the tensor returned by the wrapped loss
        # and failed whenever the weight broadcast to a shape larger than the loss.
        loss = loss * weight

    # Return the loss, as mean if specified
    return loss.mean() if mean else loss
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    """Run ``sd_model.forward(x, c, *args, **kwargs)`` with ``get_loss`` temporarily made weight-aware.

    ``w`` is stored on the model as ``_custom_loss_weight`` for the duration of the call, and
    ``get_loss`` is swapped for ``weighted_loss``. Both changes are reverted in ``finally``,
    whether or not forward raises.
    """
    try:
        # Park the weights where the loss computation can reach them.
        sd_model._custom_loss_weight = w

        # Swapping get_loss avoids reimplementing forward. An already-saved
        # _old_get_loss is left alone so the true original is never overwritten.
        backup_exists = hasattr(sd_model, '_old_get_loss')
        if not backup_exists:
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Standard forward, now using the patched get_loss.
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        # Remove the temporary weights if they were attached.
        try:
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # Reinstate the saved loss function and discard the backup.
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss
def apply_weighted_forward(sd_model):
    """Attach ``weighted_forward`` to ``sd_model`` as a bound method of the same name."""
    bound_forward = MethodType(weighted_forward, sd_model)
    sd_model.weighted_forward = bound_forward
def undo_weighted_forward(sd_model):
    """Remove the ``weighted_forward`` attribute from ``sd_model``; do nothing if it cannot be deleted."""
    try:
        delattr(sd_model, 'weighted_forward')
    except AttributeError:
        # Nothing was attached to this instance.
        pass
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
fixes = None fixes = None
layers = None layers = None
@@ -156,74 +35,234 @@ class StableDiffusionModelHijack:
pass pass
class EmbeddingsWithFixes(torch.nn.Module):
    """Wrap an embedding layer and splice extra embedding vectors into its output.

    ``embeddings.fixes`` is read and reset to None on every forward call. When set, it holds,
    per batch item, a sequence of ``(offset, embedding)`` pairs; each embedding's ``vec``
    (or ``vec[textual_inversion_key]`` when ``vec`` is a dict) overwrites the rows that
    follow position ``offset``.
    """

    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key
        self.weight = self.wrapped.weight

    def forward(self, input_ids):
        # Consume the pending fixes so they apply to this call only.
        pending = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        nothing_to_apply = pending is None or len(pending) == 0 or not any(len(item) for item in pending)
        if nothing_to_apply:
            return inputs_embeds

        patched_rows = []
        for item_fixes, row in zip(pending, inputs_embeds):
            for offset, embedding in item_fixes:
                if isinstance(embedding.vec, dict):
                    vec = embedding.vec[self.textual_inversion_key]
                else:
                    vec = embedding.vec
                emb = devices.cond_cast_unet(vec)
                # Clip the inserted vectors so the sequence length stays unchanged.
                emb_len = min(row.shape[0] - offset - 1, emb.shape[0])
                head = row[0:offset + 1]
                tail = row[offset + 1 + emb_len:]
                row = torch.cat([head, emb[0:emb_len], tail]).to(dtype=inputs_embeds.dtype)

            patched_rows.append(row)

        return torch.stack(patched_rows)
class TextualInversionEmbeddings(torch.nn.Embedding):
    """torch.nn.Embedding whose forward goes through EmbeddingsWithFixes.forward.

    The module-level ``model_hijack`` object is used as the holder of pending fixes.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)

        self.textual_inversion_key = textual_inversion_key
        self.embeddings = model_hijack

    @property
    def wrapped(self):
        # The plain torch.nn.Embedding lookup, exposed under the name EmbeddingsWithFixes.forward expects.
        return super().forward

    def forward(self, input_ids):
        return EmbeddingsWithFixes.forward(self, input_ids)
def add_circular_option_to_conv_2d():
    """Patch torch.nn.Conv2d.__init__ so every Conv2d built afterwards uses padding_mode='circular'.

    The patch is global and is not undone here; each call wraps whatever __init__ is current.
    """
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        # Set the key in kwargs instead of passing padding_mode= next to **kwargs: the latter raised
        # TypeError (multiple values for keyword argument) when the caller named padding_mode.
        kwargs['padding_mode'] = 'circular'
        return conv2d_constructor(self, *args, **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
model_hijack = StableDiffusionModelHijack() model_hijack = StableDiffusionModelHijack()
# import torch
def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.

    Stores ``attr`` on ``self`` under ``name``. A value whose type is exactly torch.Tensor and which
    is not on devices.device is first moved there, and cast to float32 when that device is 'mps'.
    """
    needs_move = type(attr) == torch.Tensor and attr.device != devices.device
    if needs_move:
        target_dtype = torch.float32 if devices.device.type == 'mps' else None
        attr = attr.to(device=devices.device, dtype=target_dtype)

    setattr(self, name, attr)
# import ldm.modules.diffusionmodules.openaimodel
# import ldm.models.diffusion.ddpm
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer # import ldm.models.diffusion.ddim
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer # import ldm.models.diffusion.plms
# import ldm.modules.encoders.modules
#
# import sgm.modules.attention
# import sgm.modules.diffusionmodules.model
# import sgm.modules.diffusionmodules.openaimodel
# import sgm.modules.encoders.modules
#
# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
#
# # new memory efficient cross attention blocks do not support hypernets and we already
# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
#
# # silence new console spam from SD2
# ldm.modules.attention.print = shared.ldm_print
# ldm.modules.diffusionmodules.model.print = shared.ldm_print
# ldm.util.print = shared.ldm_print
# ldm.models.diffusion.ddpm.print = shared.ldm_print
#
# optimizers = []
# current_optimizer: sd_hijack_optimizations.SdOptimization = None
#
# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
#
# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
#
#
# def list_optimizers():
# new_optimizers = script_callbacks.list_optimizers_callback()
#
# new_optimizers = [x for x in new_optimizers if x.is_available()]
#
# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
#
# optimizers.clear()
# optimizers.extend(new_optimizers)
#
#
# def apply_optimizations(option=None):
# return
#
#
# def undo_optimizations():
# return
#
#
# def fix_checkpoint():
# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
# checkpoints to be added when not training (there's a warning)"""
#
# pass
#
#
# def weighted_loss(sd_model, pred, target, mean=True):
# #Calculate the weight normally, but ignore the mean
# loss = sd_model._old_get_loss(pred, target, mean=False)
#
# #Check if we have weights available
# weight = getattr(sd_model, '_custom_loss_weight', None)
# if weight is not None:
# loss *= weight
#
# #Return the loss, as mean if specified
# return loss.mean() if mean else loss
#
# def weighted_forward(sd_model, x, c, w, *args, **kwargs):
# try:
# #Temporarily append weights to a place accessible during loss calc
# sd_model._custom_loss_weight = w
#
# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
# if not hasattr(sd_model, '_old_get_loss'):
# sd_model._old_get_loss = sd_model.get_loss
# sd_model.get_loss = MethodType(weighted_loss, sd_model)
#
# #Run the standard forward function, but with the patched 'get_loss'
# return sd_model.forward(x, c, *args, **kwargs)
# finally:
# try:
# #Delete temporary weights if appended
# del sd_model._custom_loss_weight
# except AttributeError:
# pass
#
# #If we have an old loss function, reset the loss function to the original one
# if hasattr(sd_model, '_old_get_loss'):
# sd_model.get_loss = sd_model._old_get_loss
# del sd_model._old_get_loss
#
# def apply_weighted_forward(sd_model):
# #Add new function 'weighted_forward' that can be called to calc weighted loss
# sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
#
# def undo_weighted_forward(sd_model):
# try:
# del sd_model.weighted_forward
# except AttributeError:
# pass
#
#
# class StableDiffusionModelHijack:
# fixes = None
# layers = None
# circular_enabled = False
# clip = None
# optimization_method = None
#
# def __init__(self):
# self.extra_generation_params = {}
# self.comments = []
#
# def apply_optimizations(self, option=None):
# pass
#
# def convert_sdxl_to_ssd(self, m):
# pass
#
# def hijack(self, m):
# pass
#
# def undo_hijack(self, m):
# pass
#
# def apply_circular(self, enable):
# pass
#
# def clear_comments(self):
# self.comments = []
# self.extra_generation_params = {}
#
# def get_prompt_lengths(self, text, cond_stage_model):
# pass
#
# def redo_hijack(self, m):
# pass
#
#
# class EmbeddingsWithFixes(torch.nn.Module):
# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
# super().__init__()
# self.wrapped = wrapped
# self.embeddings = embeddings
# self.textual_inversion_key = textual_inversion_key
# self.weight = self.wrapped.weight
#
# def forward(self, input_ids):
# batch_fixes = self.embeddings.fixes
# self.embeddings.fixes = None
#
# inputs_embeds = self.wrapped(input_ids)
#
# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
# return inputs_embeds
#
# vecs = []
# for fixes, tensor in zip(batch_fixes, inputs_embeds):
# for offset, embedding in fixes:
# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
# emb = devices.cond_cast_unet(vec)
# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
#
# vecs.append(tensor)
#
# return torch.stack(vecs)
#
#
# class TextualInversionEmbeddings(torch.nn.Embedding):
# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
# super().__init__(num_embeddings, embedding_dim, **kwargs)
#
# self.embeddings = model_hijack
# self.textual_inversion_key = textual_inversion_key
#
# @property
# def wrapped(self):
# return super().forward
#
# def forward(self, input_ids):
# return EmbeddingsWithFixes.forward(self, input_ids)
#
#
# def add_circular_option_to_conv_2d():
# conv2d_constructor = torch.nn.Conv2d.__init__
#
# def conv2d_constructor_circular(self, *args, **kwargs):
# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
#
# torch.nn.Conv2d.__init__ = conv2d_constructor_circular
#
#
#
#
#
# def register_buffer(self, name, attr):
# """
# Fix register buffer bug for Mac OS.
# """
#
# if type(attr) == torch.Tensor:
# if attr.device != devices.device:
# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
#
# setattr(self, name, attr)
#
#
# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer

View File

@@ -1,46 +1,46 @@
from torch.utils.checkpoint import checkpoint # from torch.utils.checkpoint import checkpoint
#
import ldm.modules.attention # import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel # import ldm.modules.diffusionmodules.openaimodel
#
#
def BasicTransformerBlock_forward(self, x, context=None):
    """Call ``self._forward(x, context)`` through torch.utils.checkpoint.checkpoint and return the result."""
    checkpointed = checkpoint(self._forward, x, context)
    return checkpointed
#
#
def AttentionBlock_forward(self, x):
    """Call ``self._forward(x)`` through torch.utils.checkpoint.checkpoint and return the result."""
    checkpointed = checkpoint(self._forward, x)
    return checkpointed
#
#
def ResBlock_forward(self, x, emb):
    """Call ``self._forward(x, emb)`` through torch.utils.checkpoint.checkpoint and return the result."""
    checkpointed = checkpoint(self._forward, x, emb)
    return checkpointed
#
#
stored = [] # stored = []
#
#
def add():
    """Install the checkpointing forwards on the ldm blocks, saving the previous ones in ``stored``.

    Does nothing when ``stored`` already holds saved forwards.
    """
    if stored:
        return

    transformer_block = ldm.modules.attention.BasicTransformerBlock
    openaimodel = ldm.modules.diffusionmodules.openaimodel

    # The order here fixes the indices that remove() reads back.
    stored.extend([
        transformer_block.forward,
        openaimodel.ResBlock.forward,
        openaimodel.AttentionBlock.forward,
    ])

    transformer_block.forward = BasicTransformerBlock_forward
    openaimodel.ResBlock.forward = ResBlock_forward
    openaimodel.AttentionBlock.forward = AttentionBlock_forward
#
#
def remove():
    """Put back the forwards saved in ``stored`` and empty it; does nothing when ``stored`` is empty."""
    if not stored:
        return

    openaimodel = ldm.modules.diffusionmodules.openaimodel

    # Indices follow the order in which the forwards were saved.
    ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
    openaimodel.ResBlock.forward = stored[1]
    openaimodel.AttentionBlock.forward = stored[2]

    stored.clear()
#

File diff suppressed because it is too large Load Diff

View File

@@ -1,154 +1,154 @@
import torch # import torch
from packaging import version # from packaging import version
from einops import repeat # from einops import repeat
import math # import math
#
from modules import devices # from modules import devices
from modules.sd_hijack_utils import CondFunc # from modules.sd_hijack_utils import CondFunc
#
#
class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        """Forward any attribute not defined on this class to the ``torch`` module."""
        # __getattr__ only runs after normal lookup fails, so `cat` (a real method of this
        # class) never reaches this point; the old special case for it was unreachable.
        try:
            return getattr(torch, item)
        except AttributeError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from None

    def cat(self, tensors, *args, **kwargs):
        """Like torch.cat, but for exactly two tensors the first is resized to the second's last two dims.

        The resize uses nearest-neighbour interpolation and happens only when those dims differ.
        Extra positional and keyword arguments go to torch.cat unchanged.
        """
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)
#
#
th = TorchHijackForUnet() # th = TorchHijackForUnet()
#
#
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling # # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
    """Always make sure inputs to unet are in correct dtype.

    Tensors in ``cond`` (directly, or inside list values) are cast to devices.dtype_unet in place,
    ``x_noisy`` and ``t`` are cast for the call, and ``orig_func`` runs under devices.autocast().
    The result is returned as float32 when devices.unet_needs_upcast is set.
    """
    def to_unet_dtype(value):
        return value.to(devices.dtype_unet) if isinstance(value, torch.Tensor) else value

    if isinstance(cond, dict):
        # Only existing keys are reassigned, so the dict can be updated while iterating.
        for key, value in cond.items():
            if isinstance(value, list):
                cond[key] = [to_unet_dtype(entry) for entry in value]
            else:
                cond[key] = to_unet_dtype(value)

    with devices.autocast():
        result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
        return result.float() if devices.unet_needs_upcast else result
#
#
# Monkey patch to create timestep embed tensor on device, avoiding a block. # # Monkey patch to create timestep embed tensor on device, avoiding a block.
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    The frequency table is built directly on ``timesteps.device``.

    :param _: unused first positional argument.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: when true, skip the sinusoids and repeat each timestep ``dim`` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * exponents / half)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad one zero column so the width is exactly dim.
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
#
#
# Monkey patch to SpatialTransformer removing unnecessary contiguous calls. # # Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
# Prevents a lot of unnecessary aten::copy_ calls # # Prevents a lot of unnecessary aten::copy_ calls
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
    """
    Forward pass for a SpatialTransformer that uses permute/reshape/view only.

    :param _: unused placeholder for the first positional argument.
    :param self: the module; this function reads its norm, use_linear,
                 proj_in, transformer_blocks and proj_out attributes.
    :param x: a 4-D tensor, unpacked as (b, c, h, w).
    :param context: a list with one entry per transformer block, or a single
                    value that is wrapped into a one-element list.
    :return: the transformed tensor added to the original input.
    """
    # note: if no context is given, cross-attention defaults to self-attention
    contexts = context if isinstance(context, list) else [context]
    batch, channels, height, width = x.shape
    residual = x
    linear_projection = self.use_linear

    hidden = self.norm(x)
    if linear_projection:
        # Linear projection is applied after flattening to (b, h*w, c).
        hidden = hidden.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        hidden = self.proj_in(hidden)
    else:
        # Non-linear projection is applied while still in (b, c, h, w).
        hidden = self.proj_in(hidden)
        hidden = hidden.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

    for index, block in enumerate(self.transformer_blocks):
        hidden = block(hidden, context=contexts[index])

    if linear_projection:
        hidden = self.proj_out(hidden)
        hidden = hidden.view(batch, height, width, channels).permute(0, 3, 1, 2)
    else:
        hidden = hidden.view(batch, height, width, channels).permute(0, 3, 1, 2)
        hidden = self.proj_out(hidden)
    return hidden + residual
#
#
class GELUHijack(torch.nn.GELU, torch.nn.Module):
    """GELU that runs in float32 and casts back when devices.unet_needs_upcast is set."""

    def __init__(self, *args, **kwargs):
        torch.nn.GELU.__init__(self, *args, **kwargs)

    def forward(self, x):
        if not devices.unet_needs_upcast:
            return torch.nn.GELU.forward(self, x)
        # Upcast path: compute on float32 input, then cast to the UNet dtype.
        upcast_result = torch.nn.GELU.forward(self.float(), x.float())
        return upcast_result.to(devices.dtype_unet)
#
#
ddpm_edit_hijack = None # ddpm_edit_hijack = None
def hijack_ddpm_edit():
    """
    Register the ddpm_edit LatentDiffusion substitutes.

    Does nothing when the module-level `ddpm_edit_hijack` is already truthy;
    otherwise registers the first-stage encode/decode substitutes and stores
    the apply_model registration in `ddpm_edit_hijack`.
    """
    global ddpm_edit_hijack
    if ddpm_edit_hijack:
        return

    CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
    CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
    ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
#
#
# Module-level registrations. CondFunc is defined outside this section; every
# call here passes a dotted target path, a substitute callable and, optionally,
# a predicate. Presumably the substitute is only used when the predicate is
# truthy -- confirm against CondFunc.

# Predicate that ignores its arguments and reports devices.unet_needs_upcast.
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
# Cast the embedding: float32 when the timesteps are int64, otherwise devices.dtype_unet.
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)

# NOTE(review): indentation was lost in this listing; the three registrations
# below are assumed to be guarded by this version/CUDA check -- confirm.
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
    # kwargs.update() returns None, so `None and False or orig_func(...)` always
    # calls orig_func, now with act_layer replaced by GELUHijack. The predicate
    # matches when act_layer is absent/None or is torch.nn.GELU.
    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)

# Predicate: upcasting is needed and the diffusion model's dtype is float16.
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
# Substitute: call the original with x cast to devices.dtype_vae.
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
# Return the first-stage encoding as float32 under the same predicate.
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)

# Registered without a predicate, for both the ldm and sgm wrappers.
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
#
#
def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
    """
    Call `orig_func` and cast its result to the dtype the UNet expects.

    The result is cast to float32 when devices.unet_needs_upcast is set and
    `timesteps` is int64; in every other case it is cast to devices.dtype_unet.
    """
    keep_float32 = devices.unet_needs_upcast and timesteps.dtype == torch.int64
    target_dtype = torch.float32 if keep_float32 else devices.dtype_unet
    embedding = orig_func(timesteps, *args, **kwargs)
    return embedding.to(dtype=target_dtype)
#
#
# Register timestep_embedding_cast_result for timestep_embedding in both the
# ldm and the sgm openaimodel modules (no predicate is passed).
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)

View File

@@ -10,7 +10,6 @@ import re
import safetensors.torch import safetensors.torch
from omegaconf import OmegaConf, ListConfig from omegaconf import OmegaConf, ListConfig
from urllib import request from urllib import request
import ldm.modules.midas as midas
import gc import gc
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
@@ -415,89 +414,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
def enable_midas_autodownload(): def enable_midas_autodownload():
""" pass
Gives the ldm.modules.midas.api.load_model function automatic downloading.
When the 512-depth-ema model, and other future models like it, is loaded,
it calls midas.api.load_model to load the associated midas depth model.
This function applies a wrapper to download the model to the correct
location automatically.
"""
midas_path = os.path.join(paths.models_path, 'midas')
# stable-diffusion-stability-ai hard-codes the midas model path to
# a location that differs from where other scripts using this model look.
# HACK: Overriding the path here.
for k, v in midas.api.ISL_PATHS.items():
file_name = os.path.basename(v)
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
midas_urls = {
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
}
midas.api.load_model_inner = midas.api.load_model
def load_model_wrapper(model_type):
path = midas.api.ISL_PATHS[model_type]
if not os.path.exists(path):
if not os.path.exists(midas_path):
os.mkdir(midas_path)
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
return midas.api.load_model_inner(model_type)
midas.api.load_model = load_model_wrapper
def patch_given_betas(): def patch_given_betas():
import ldm.models.diffusion.ddpm pass
def patched_register_schedule(*args, **kwargs):
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
if isinstance(args[1], ListConfig):
args = (args[0], np.array(args[1]), *args[2:])
original_register_schedule(*args, **kwargs)
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
def repair_config(sd_config, state_dict=None): def repair_config(sd_config, state_dict=None):
if not hasattr(sd_config.model.params, "use_ema"): pass
sd_config.model.params.use_ema = False
if hasattr(sd_config.model.params, 'unet_config'):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True
if hasattr(sd_config.model.params, 'first_stage_config'):
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
# Do not use checkpoint for inference.
# This helps prevent extra performance overhead on checking parameters.
# The perf overhead is about 100ms/it on 4090 for SDXL.
if hasattr(sd_config.model.params, "network_config"):
sd_config.model.params.network_config.params.use_checkpoint = False
if hasattr(sd_config.model.params, "unet_config"):
sd_config.model.params.unet_config.params.use_checkpoint = False
def rescale_zero_terminal_snr_abar(alphas_cumprod): def rescale_zero_terminal_snr_abar(alphas_cumprod):

View File

@@ -1,137 +1,137 @@
import os # import os
#
import torch # import torch
#
from modules import shared, paths, sd_disable_initialization, devices # from modules import shared, paths, sd_disable_initialization, devices
#
# Directories that hold model config YAML files: the webui's own config
# directory plus the config folders inside the two repository paths.
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")


# Full paths of the config files that guess_model_config_from_state_dict can return.
config_default = shared.sd_default_config
# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
#
#
def is_using_v_parameterization_for_sd2(state_dict):
    """
    Detect whether the unet in state_dict uses v-parameterization.

    Builds a UNetModel with a fixed configuration, loads every
    "model.diffusion_model." entry of state_dict into it (strictly), runs one
    forward pass on constant 0.5 inputs at timestep 999 and returns True when
    the mean of (output - input) is below -1.
    """

    import ldm.modules.diffusionmodules.openaimodel

    probe_device = devices.device
    unet_kwargs = dict(
        use_checkpoint=False,
        use_fp16=False,
        image_size=32,
        in_channels=4,
        out_channels=4,
        model_channels=320,
        attention_resolutions=[4, 2, 1],
        num_res_blocks=2,
        channel_mult=[1, 2, 4, 4],
        num_head_channels=64,
        use_spatial_transformer=True,
        use_linear_in_transformer=True,
        transformer_depth=1,
        context_dim=1024,
        legacy=False,
    )

    with sd_disable_initialization.DisableInitialization():
        probe_unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(**unet_kwargs)
        probe_unet.eval()

    with torch.no_grad():
        # Keep only the unet entries and drop their prefix.
        unet_weights = {}
        for key, tensor in state_dict.items():
            if "model.diffusion_model." in key:
                unet_weights[key.replace("model.diffusion_model.", "")] = tensor
        probe_unet.load_state_dict(unet_weights, strict=True)
        probe_unet.to(device=probe_device, dtype=devices.dtype_unet)

        probe_cond = torch.ones((1, 2, 1024), device=probe_device) * 0.5
        probe_latent = torch.ones((1, 4, 8, 8), device=probe_device) * 0.5

        with devices.autocast():
            prediction = probe_unet(probe_latent, torch.asarray([999], device=probe_device), context=probe_cond)
            mean_shift = (prediction - probe_latent).mean().cpu().item()

    return mean_shift < -1
#
#
def guess_model_config_from_state_dict(sd, filename):
    """
    Pick a config file path by looking at which keys a state dict contains.

    The checks run in a fixed order and the first match wins; when nothing
    matches, config_default is returned.

    A missing 'model.diffusion_model.input_blocks.0.0.weight' or
    'cond_stage_model.transformation.weight' entry no longer raises
    AttributeError: the corresponding branch falls back to its
    non-inpainting / base config instead.

    :param sd: mapping of parameter names to tensors.
    :param filename: unused by the checks; kept for interface compatibility.
    :return: one of the module-level config_* paths.
    """
    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
    sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

    if "model.diffusion_model.x_embedder.proj.weight" in sd:
        return config_sd3

    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
        # Guard against a missing input-block weight before reading its shape.
        if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
            return config_sdxl_inpainting
        return config_sdxl

    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
        return config_sdxl_refiner
    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
        return config_depth_model
    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
        return config_unclip
    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
        return config_unopenclip

    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
        if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
            return config_sd2_inpainting
        # elif is_using_v_parameterization_for_sd2(sd):
        #     return config_sd2v
        return config_sd2v

    if diffusion_model_input is not None:
        if diffusion_model_input.shape[1] == 9:
            return config_inpainting
        if diffusion_model_input.shape[1] == 8:
            return config_instruct_pix2pix

    if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
        transformation_weight = sd.get('cond_stage_model.transformation.weight')
        # Guard against a missing transformation weight before reading its size.
        if transformation_weight is not None and transformation_weight.size()[0] == 1024:
            return config_alt_diffusion_m18
        return config_alt_diffusion

    return config_default
#
#
def find_checkpoint_config(state_dict, info):
    """
    Resolve the config for a checkpoint.

    A config file found by find_checkpoint_config_near_filename(info) takes
    priority; otherwise the config is guessed from state_dict. When info is
    None, the guess is made with an empty filename.
    """
    if info is not None:
        nearby_config = find_checkpoint_config_near_filename(info)
        if nearby_config is not None:
            return nearby_config
        filename = info.filename
    else:
        filename = ""

    return guess_model_config_from_state_dict(state_dict, filename)
#
#
def find_checkpoint_config_near_filename(info):
    """
    Look for a .yaml file that shares the checkpoint's path and base name.

    Returns that path when the file exists, and None when info is None or no
    such file exists.
    """
    if info is None:
        return None

    stem, _extension = os.path.splitext(info.filename)
    candidate = f"{stem}.yaml"
    return candidate if os.path.exists(candidate) else None
#

View File

@@ -1,4 +1,3 @@
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -6,7 +5,7 @@ if TYPE_CHECKING:
from modules.sd_models import CheckpointInfo from modules.sd_models import CheckpointInfo
class WebuiSdModel(LatentDiffusion): class WebuiSdModel:
"""This class is not actually instantiated, but its fields are created and filled by webui""" """This class is not actually instantiated, but its fields are created and filled by webui"""
lowvram: bool lowvram: bool

View File

@@ -1,115 +1,115 @@
from __future__ import annotations # from __future__ import annotations
#
import torch # import torch
#
import sgm.models.diffusion # import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling # import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer # import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser # from modules import devices, shared, prompt_parser
from modules import torch_utils # from modules import torch_utils
#
from backend import memory_management # from backend import memory_management
#
#
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
    """
    Run the conditioner on a batch of prompts plus size/crop/aesthetic tensors.

    width, height and is_negative_prompt are read from `batch` when it has
    those attributes (defaults: 1024, 1024, False). When the batch is a
    negative prompt made only of empty strings, the 'txt' embeddings are
    forced to zero.
    """
    for embedder in self.conditioner.embedders:
        embedder.ucg_rate = 0.0

    width = getattr(batch, 'width', 1024) or 1024
    height = getattr(batch, 'height', 1024) or 1024
    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
    if is_negative_prompt:
        aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score
    else:
        aesthetic_score = shared.opts.sdxl_refiner_high_aesthetic_score

    devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())

    def one_row_per_prompt(values):
        # Build the tensor once, then repeat it for every prompt in the batch.
        return torch.tensor(values, **devices_args).repeat(len(batch), 1)

    sdxl_conds = {
        "txt": batch,
        "original_size_as_tuple": one_row_per_prompt([height, width]),
        "crop_coords_top_left": one_row_per_prompt([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]),
        "target_size_as_tuple": one_row_per_prompt([height, width]),
        "aesthetic_score": one_row_per_prompt([aesthetic_score]),
    }

    force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
    zeroed_keys = ['txt'] if force_zero_negative_prompt else []
    return self.conditioner(sdxl_conds, force_zero_embeddings=zeroed_keys)
#
#
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
    """
    Call self.model on (x, t, cond), forwarding any extra arguments.

    When the diffusion model reports 9 input channels, the tensors in
    cond['c_concat'] are first concatenated onto x along dim 1.
    """
    wrapped_model = self.model
    if wrapped_model.diffusion_model.in_channels == 9:
        x = torch.cat([x] + cond['c_concat'], dim=1)

    return wrapped_model(x, t, cond, *args, **kwargs)
#
#
def get_first_stage_encoding(self, x):
    """
    Return x unchanged.

    SDXL's encode_first_stage does everything, so this method is only there
    for compatibility.
    """
    return x
#
#
# Attach the three module-level functions to DiffusionEngine as methods.
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
#
#
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
    """
    Encode init_text with every embedder that supports it and join the results.

    Embedders without an encode_embedding_init_text attribute are skipped; the
    remaining outputs are concatenated along dim 1 in embedder order.
    """
    capable = [candidate for candidate in self.embedders if hasattr(candidate, 'encode_embedding_init_text')]
    pieces = [candidate.encode_embedding_init_text(init_text, nvpt) for candidate in capable]
    return torch.cat(pieces, dim=1)
#
#
def tokenize(self: sgm.modules.GeneralConditioner, texts):
    """
    Tokenize texts with the first embedder that has a tokenize attribute.

    Raises AssertionError when no embedder has one.
    """
    tokenizers = [candidate for candidate in self.embedders if hasattr(candidate, 'tokenize')]
    if tokenizers:
        return tokenizers[0].tokenize(texts)

    raise AssertionError('no tokenizer available')
#
#
#
def process_texts(self, texts):
    """
    Delegate to the first embedder that has a process_texts attribute.

    Returns None when no embedder has one.
    """
    capable = [candidate for candidate in self.embedders if hasattr(candidate, 'process_texts')]
    if not capable:
        return None
    return capable[0].process_texts(texts)
#
#
def get_target_prompt_token_count(self, token_count):
    """
    Delegate to the first embedder that has a get_target_prompt_token_count attribute.

    Returns None when no embedder has one.
    """
    capable = [candidate for candidate in self.embedders if hasattr(candidate, 'get_target_prompt_token_count')]
    if not capable:
        return None
    return capable[0].get_target_prompt_token_count(token_count)
#
#
# These additions to GeneralConditioner let existing code use it as
# model.cond_stage_model, the way it does with SD1.5.
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
sgm.modules.GeneralConditioner.tokenize = tokenize
sgm.modules.GeneralConditioner.process_texts = process_texts
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
#
#
def extend_sdxl(model):
    """Set extra attributes on an SDXL model so it looks a bit more like SD1.5 to the rest of the codebase."""

    unet = model.model.diffusion_model
    unet.dtype = torch_utils.get_param(unet).dtype
    model.model.conditioning_key = 'crossattn'
    model.cond_stage_key = 'txt'
    # model.cond_stage_model will be set in sd_hijack

    uses_v_scaling = isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling)
    model.parameterization = "v" if uses_v_scaling else "eps"

    # alphas_cumprod is taken from a fresh LegacyDDPMDiscretization instance.
    legacy_discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
    model.alphas_cumprod = torch.asarray(legacy_discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)

    model.conditioner.wrapped = torch.nn.Module()
#
#
# Bind the name `print` inside these sgm modules to shared.ldm_print.
sgm.modules.attention.print = shared.ldm_print
sgm.modules.diffusionmodules.model.print = shared.ldm_print
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
sgm.modules.encoders.modules.print = shared.ldm_print

# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False

View File

@@ -35,9 +35,7 @@ def refresh_vae_list():
def cross_attention_optimizations(): def cross_attention_optimizations():
import modules.sd_hijack return ["Automatic"]
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
def sd_unet_items(): def sd_unet_items():

View File

@@ -1,245 +1,243 @@
import os # import os
import numpy as np # import numpy as np
import PIL # import PIL
import torch # import torch
from torch.utils.data import Dataset, DataLoader, Sampler # from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms # from torchvision import transforms
from collections import defaultdict # from collections import defaultdict
from random import shuffle, choices # from random import shuffle, choices
#
import random # import random
import tqdm # import tqdm
from modules import devices, shared, images # from modules import devices, shared, images
import re # import re
#
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution # re_numbers_at_start = re.compile(r"^[-\d]+\s*")
#
re_numbers_at_start = re.compile(r"^[-\d]+\s*") #
# class DatasetEntry:
# def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
class DatasetEntry: # self.filename = filename
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None): # self.filename_text = filename_text
self.filename = filename # self.weight = weight
self.filename_text = filename_text # self.latent_dist = latent_dist
self.weight = weight # self.latent_sample = latent_sample
self.latent_dist = latent_dist # self.cond = cond
self.latent_sample = latent_sample # self.cond_text = cond_text
self.cond = cond # self.pixel_values = pixel_values
self.cond_text = cond_text #
self.pixel_values = pixel_values #
# class PersonalizedBase(Dataset):
# def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
class PersonalizedBase(Dataset): # re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): #
re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None # self.placeholder_token = placeholder_token
#
self.placeholder_token = placeholder_token # self.flip = transforms.RandomHorizontalFlip(p=flip_p)
#
self.flip = transforms.RandomHorizontalFlip(p=flip_p) # self.dataset = []
#
self.dataset = [] # with open(template_file, "r") as file:
# lines = [x.strip() for x in file.readlines()]
with open(template_file, "r") as file: #
lines = [x.strip() for x in file.readlines()] # self.lines = lines
#
self.lines = lines # assert data_root, 'dataset directory not specified'
# assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert data_root, 'dataset directory not specified' # assert os.listdir(data_root), "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist" #
assert os.listdir(data_root), "Dataset directory is empty" # self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
#
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] # self.shuffle_tags = shuffle_tags
# self.tag_drop_out = tag_drop_out
self.shuffle_tags = shuffle_tags # groups = defaultdict(list)
self.tag_drop_out = tag_drop_out #
groups = defaultdict(list) # print("Preparing dataset...")
# for path in tqdm.tqdm(self.image_paths):
print("Preparing dataset...") # alpha_channel = None
for path in tqdm.tqdm(self.image_paths): # if shared.state.interrupted:
alpha_channel = None # raise Exception("interrupted")
if shared.state.interrupted: # try:
raise Exception("interrupted") # image = images.read(path)
try: # #Currently does not work for single color transparency
image = images.read(path) # #We would need to read image.info['transparency'] for that
#Currently does not work for single color transparency # if use_weight and 'A' in image.getbands():
#We would need to read image.info['transparency'] for that # alpha_channel = image.getchannel('A')
if use_weight and 'A' in image.getbands(): # image = image.convert('RGB')
alpha_channel = image.getchannel('A') # if not varsize:
image = image.convert('RGB') # image = image.resize((width, height), PIL.Image.BICUBIC)
if not varsize: # except Exception:
image = image.resize((width, height), PIL.Image.BICUBIC) # continue
except Exception: #
continue # text_filename = f"{os.path.splitext(path)[0]}.txt"
# filename = os.path.basename(path)
text_filename = f"{os.path.splitext(path)[0]}.txt" #
filename = os.path.basename(path) # if os.path.exists(text_filename):
# with open(text_filename, "r", encoding="utf8") as file:
if os.path.exists(text_filename): # filename_text = file.read()
with open(text_filename, "r", encoding="utf8") as file: # else:
filename_text = file.read() # filename_text = os.path.splitext(filename)[0]
else: # filename_text = re.sub(re_numbers_at_start, '', filename_text)
filename_text = os.path.splitext(filename)[0] # if re_word:
filename_text = re.sub(re_numbers_at_start, '', filename_text) # tokens = re_word.findall(filename_text)
if re_word: # filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
tokens = re_word.findall(filename_text) #
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) # npimage = np.array(image).astype(np.uint8)
# npimage = (npimage / 127.5 - 1.0).astype(np.float32)
npimage = np.array(image).astype(np.uint8) #
npimage = (npimage / 127.5 - 1.0).astype(np.float32) # torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
# latent_sample = None
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) #
latent_sample = None # with devices.autocast():
# latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
with devices.autocast(): #
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) # #Perform latent sampling, even for random sampling.
# #We need the sample dimensions for the weights
#Perform latent sampling, even for random sampling. # if latent_sampling_method == "deterministic":
#We need the sample dimensions for the weights # if isinstance(latent_dist, DiagonalGaussianDistribution):
if latent_sampling_method == "deterministic": # # Works only for DiagonalGaussianDistribution
if isinstance(latent_dist, DiagonalGaussianDistribution): # latent_dist.std = 0
# Works only for DiagonalGaussianDistribution # else:
latent_dist.std = 0 # latent_sampling_method = "once"
else: # latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
latent_sampling_method = "once" #
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) # if use_weight and alpha_channel is not None:
# channels, *latent_size = latent_sample.shape
if use_weight and alpha_channel is not None: # weight_img = alpha_channel.resize(latent_size)
channels, *latent_size = latent_sample.shape # npweight = np.array(weight_img).astype(np.float32)
weight_img = alpha_channel.resize(latent_size) # #Repeat for every channel in the latent sample
npweight = np.array(weight_img).astype(np.float32) # weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
#Repeat for every channel in the latent sample # #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size) # weight -= weight.min()
#Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default. # weight /= weight.mean()
weight -= weight.min() # elif use_weight:
weight /= weight.mean() # #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
elif use_weight: # weight = torch.ones(latent_sample.shape)
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later # else:
weight = torch.ones(latent_sample.shape) # weight = None
else: #
weight = None # if latent_sampling_method == "random":
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
if latent_sampling_method == "random": # else:
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) # entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
else: #
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight) # if not (self.tag_drop_out != 0 or self.shuffle_tags):
# entry.cond_text = self.create_text(filename_text)
if not (self.tag_drop_out != 0 or self.shuffle_tags): #
entry.cond_text = self.create_text(filename_text) # if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
# with devices.autocast():
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): # entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
with devices.autocast(): # groups[image.size].append(len(self.dataset))
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) # self.dataset.append(entry)
groups[image.size].append(len(self.dataset)) # del torchdata
self.dataset.append(entry) # del latent_dist
del torchdata # del latent_sample
del latent_dist # del weight
del latent_sample #
del weight # self.length = len(self.dataset)
# self.groups = list(groups.values())
self.length = len(self.dataset) # assert self.length > 0, "No images have been found in the dataset."
self.groups = list(groups.values()) # self.batch_size = min(batch_size, self.length)
assert self.length > 0, "No images have been found in the dataset." # self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.batch_size = min(batch_size, self.length) # self.latent_sampling_method = latent_sampling_method
self.gradient_step = min(gradient_step, self.length // self.batch_size) #
self.latent_sampling_method = latent_sampling_method # if len(groups) > 1:
# print("Buckets:")
if len(groups) > 1: # for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
print("Buckets:") # print(f" {w}x{h}: {len(ids)}")
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): # print()
print(f" {w}x{h}: {len(ids)}") #
print() # def create_text(self, filename_text):
# text = random.choice(self.lines)
def create_text(self, filename_text): # tags = filename_text.split(',')
text = random.choice(self.lines) # if self.tag_drop_out != 0:
tags = filename_text.split(',') # tags = [t for t in tags if random.random() > self.tag_drop_out]
if self.tag_drop_out != 0: # if self.shuffle_tags:
tags = [t for t in tags if random.random() > self.tag_drop_out] # random.shuffle(tags)
if self.shuffle_tags: # text = text.replace("[filewords]", ','.join(tags))
random.shuffle(tags) # text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ','.join(tags)) # return text
text = text.replace("[name]", self.placeholder_token) #
return text # def __len__(self):
# return self.length
def __len__(self): #
return self.length # def __getitem__(self, i):
# entry = self.dataset[i]
def __getitem__(self, i): # if self.tag_drop_out != 0 or self.shuffle_tags:
entry = self.dataset[i] # entry.cond_text = self.create_text(entry.filename_text)
if self.tag_drop_out != 0 or self.shuffle_tags: # if self.latent_sampling_method == "random":
entry.cond_text = self.create_text(entry.filename_text) # entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
if self.latent_sampling_method == "random": # return entry
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) #
return entry #
# class GroupedBatchSampler(Sampler):
# def __init__(self, data_source: PersonalizedBase, batch_size: int):
class GroupedBatchSampler(Sampler): # super().__init__(data_source)
def __init__(self, data_source: PersonalizedBase, batch_size: int): #
super().__init__(data_source) # n = len(data_source)
# self.groups = data_source.groups
n = len(data_source) # self.len = n_batch = n // batch_size
self.groups = data_source.groups # expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
self.len = n_batch = n // batch_size # self.base = [int(e) // batch_size for e in expected]
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] # self.n_rand_batches = nrb = n_batch - sum(self.base)
self.base = [int(e) // batch_size for e in expected] # self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.n_rand_batches = nrb = n_batch - sum(self.base) # self.batch_size = batch_size
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] #
self.batch_size = batch_size # def __len__(self):
# return self.len
def __len__(self): #
return self.len # def __iter__(self):
# b = self.batch_size
def __iter__(self): #
b = self.batch_size # for g in self.groups:
# shuffle(g)
for g in self.groups: #
shuffle(g) # batches = []
# for g in self.groups:
batches = [] # batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for g in self.groups: # for _ in range(self.n_rand_batches):
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) # rand_group = choices(self.groups, self.probs)[0]
for _ in range(self.n_rand_batches): # batches.append(choices(rand_group, k=b))
rand_group = choices(self.groups, self.probs)[0] #
batches.append(choices(rand_group, k=b)) # shuffle(batches)
#
shuffle(batches) # yield from batches
#
yield from batches #
# class PersonalizedDataLoader(DataLoader):
# def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
class PersonalizedDataLoader(DataLoader): # super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): # if latent_sampling_method == "random":
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) # self.collate_fn = collate_wrapper_random
if latent_sampling_method == "random": # else:
self.collate_fn = collate_wrapper_random # self.collate_fn = collate_wrapper
else: #
self.collate_fn = collate_wrapper #
# class BatchLoader:
# def __init__(self, data):
class BatchLoader: # self.cond_text = [entry.cond_text for entry in data]
def __init__(self, data): # self.cond = [entry.cond for entry in data]
self.cond_text = [entry.cond_text for entry in data] # self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
self.cond = [entry.cond for entry in data] # if all(entry.weight is not None for entry in data):
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) # self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
if all(entry.weight is not None for entry in data): # else:
self.weight = torch.stack([entry.weight for entry in data]).squeeze(1) # self.weight = None
else: # #self.emb_index = [entry.emb_index for entry in data]
self.weight = None # #print(self.latent_sample.device)
#self.emb_index = [entry.emb_index for entry in data] #
#print(self.latent_sample.device) # def pin_memory(self):
# self.latent_sample = self.latent_sample.pin_memory()
def pin_memory(self): # return self
self.latent_sample = self.latent_sample.pin_memory() #
return self # def collate_wrapper(batch):
# return BatchLoader(batch)
def collate_wrapper(batch): #
return BatchLoader(batch) # class BatchLoaderRandom(BatchLoader):
# def __init__(self, data):
class BatchLoaderRandom(BatchLoader): # super().__init__(data)
def __init__(self, data): #
super().__init__(data) # def pin_memory(self):
# return self
def pin_memory(self): #
return self # def collate_wrapper_random(batch):
# return BatchLoaderRandom(batch)
def collate_wrapper_random(batch):
return BatchLoaderRandom(batch)