Free WebUI from its Prison

Congratulations WebUI. Say Hello to freedom.
commit bccf9fb23a (parent aafe11b14c)
layerdiffusion · 2024-08-05 03:58:34 -07:00
26 changed files with 2053 additions and 4392 deletions

@@ -1,250 +0,0 @@
import os
import gc
import time
import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack, devices
cached_ldsr_model: torch.nn.Module = None
# Create LDSR Class
class LDSR:
def load_model_from_config(self, half_attention):
global cached_ldsr_model
if shared.opts.ldsr_cached and cached_ldsr_model is not None:
print("Loading model from cache")
model: torch.nn.Module = cached_ldsr_model
else:
print(f"Loading model from {self.modelPath}")
_, extension = os.path.splitext(self.modelPath)
if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
else:
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath)
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model)
model.load_state_dict(sd, strict=False)
model = model.to(shared.device)
if half_attention:
model = model.half()
if shared.cmd_opts.opt_channelslast:
model = model.to(memory_format=torch.channels_last)
sd_hijack.model_hijack.hijack(model) # apply optimization
model.eval()
if shared.opts.ldsr_cached:
cached_ldsr_model = model
return {"model": model}
def __init__(self, model_path, yaml_path):
self.modelPath = model_path
self.yamlPath = yaml_path
@staticmethod
def run(model, selected_path, custom_steps, eta):
example = get_cond(selected_path)
n_runs = 1
guider = None
ckwargs = None
ddim_use_x0_pred = False
temperature = 1.
custom_shape = None
height, width = example["image"].shape[1:3]
split_input = height >= 128 and width >= 128
if split_input:
ks = 128
stride = 64
vqf = 4  # VQ first-stage downsampling factor
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01}
else:
if hasattr(model, "split_input_params"):
delattr(model, "split_input_params")
x_t = None
logs = None
for _ in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
logs = make_convolutional_sample(example, model,
custom_steps=custom_steps,
eta=eta, quantize_x0=False,
custom_shape=custom_shape,
temperature=temperature, noise_dropout=0.,
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
ddim_use_x0_pred=ddim_use_x0_pred
)
return logs
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
model = self.load_model_from_config(half_attention)
# Run settings
diffusion_steps = int(steps)
eta = 1.0
gc.collect()
devices.torch_gc()
im_og = image
width_og, height_og = im_og.size
# If the maximum upscale factor ever becomes adjustable, the hard-coded 4 below should become that variable
down_sample_rate = target_scale / 4
wd = width_og * down_sample_rate
hd = height_og * down_sample_rate
width_downsampled_pre = int(np.ceil(wd))
height_downsampled_pre = int(np.ceil(hd))
if down_sample_rate != 1:
print(
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"]
sample = sample.detach().cpu()
sample = torch.clamp(sample, -1., 1.)
sample = (sample + 1.) / 2. * 255
sample = sample.numpy().astype(np.uint8)
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
# remove padding
a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
del model
gc.collect()
devices.torch_gc()
return a
def get_cond(selected_path):
example = {}
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
antialias=True)
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c = c.to(shared.device)
example["LR_image"] = c
example["image"] = c_up
return example
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
corrector_kwargs=None, x_t=None
):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
print(f"Sampling with eta = {eta}; steps: {steps}")
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
mask=mask, x0=x0, temperature=temperature, verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, x_t=x_t)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = {}
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, 'split_input_params')
and model.cond_stage_key == 'coordinates_bbox'),
return_original_cond=True)
if custom_shape is not None:
z = torch.randn(custom_shape)
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
z0 = None
log["input"] = x
log["reconstruction"] = xrec
if ismap(xc):
log["original_conditioning"] = model.to_rgb(xc)
if hasattr(model, 'cond_stage_key'):
log[model.cond_stage_key] = model.to_rgb(xc)
else:
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_key == 'class_label':
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
t0 = time.time()
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
eta=eta,
quantize_x0=quantize_x0, mask=None, x0=z0,
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
x_t=x_T)
t1 = time.time()
if ddim_use_x0_pred:
sample = intermediates['pred_x0'][-1]
x_sample = model.decode_first_stage(sample)
try:
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except Exception:
pass
log["sample"] = x_sample
log["time"] = t1 - t0
return log
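
For orientation, a minimal usage sketch of the deleted class above; the file paths and names here are placeholders, not values from the repo:

```python
# Hypothetical driver for the LDSR class defined above; paths are placeholders.
from PIL import Image

ldsr = LDSR("models/LDSR/model.ckpt", "models/LDSR/project.yaml")
img = Image.open("input.png")
# 100 DDIM steps, 2x target scale, full-precision attention:
result = ldsr.super_resolution(img, steps=100, target_scale=2, half_attention=False)
result.save("output.png")
```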

@@ -1,6 +0,0 @@
import os
from modules import paths
def preload(parser):
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))

@@ -1,70 +0,0 @@
import os
from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from modules_forge.utils import prepare_free_memory
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
import sd_hijack_autoencoder # noqa: F401
import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler):
def __init__(self, user_path):
self.name = "LDSR"
self.user_path = user_path
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
super().__init__()
scaler_data = UpscalerData("LDSR", None, self)
self.scalers = [scaler_data]
def load_model(self, path: str):
# Remove incorrect project.yaml file if too big
yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt")
local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None)
local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None)
local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None)
if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760:  # 10 MiB
print("Removing invalid LDSR YAML file.")
os.remove(yaml_path)
if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path)
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
return LDSR(model, yaml)
def do_upscale(self, img, path):
prepare_free_memory(aggressive=True)
try:
ldsr = self.load_model(path)
except Exception:
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
def on_ui_settings():
import gradio as gr
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
script_callbacks.on_ui_settings(on_ui_settings)
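
The wrapper above is what WebUI actually instantiated; a hedged sketch of driving it directly, assuming the `scale` attribute that the Upscaler base class normally supplies:

```python
# Hypothetical direct use of the deleted upscaler wrapper; paths are placeholders.
from PIL import Image

upscaler = UpscalerLDSR(user_path="models/LDSR")
upscaler.scale = 4  # assumption: normally set by the Upscaler base class
out = upscaler.do_upscale(Image.open("input.png"), path="models/LDSR")
```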

@@ -1,293 +0,0 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, this hijack puts them back into ldm.models.autoencoder
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
from packaging import version
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if missing:
print(f"Missing Keys: {missing}")
if unexpected:
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = {}
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
ldm.models.autoencoder.VQModel = VQModel
ldm.models.autoencoder.VQModelInterface = VQModelInterface
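
The final two assignments are the hijack itself. A self-contained toy of the same module-patching pattern, with stand-in names rather than the real ldm modules:

```python
# Toy version of the pattern above: restore a class that an upstream module
# dropped by assigning a local replacement into its namespace.
import types

upstream = types.ModuleType("upstream_autoencoder")  # stand-in for ldm.models.autoencoder

class RestoredModel:
    """Local reimplementation of a class upstream removed."""

upstream.RestoredModel = RestoredModel
assert upstream.RestoredModel is RestoredModel  # old call sites resolve again
```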

File diff suppressed because it is too large.

@@ -1,147 +0,0 @@
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
# where the license is as follows:
#
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
class VectorQuantizer2(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug, the beta term was applied to the wrong term. For
# backwards compatibility the buggy version is used by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
sane_index_shape=False, legacy=True):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices.")
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
assert rescale_logits is False, "Only for interface compatible with Gumbel"
assert return_logits is False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, 'b c h w -> b h w c').contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
torch.mean((z_q - z.detach()) ** 2)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
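
A quick hedged round trip through the vendored quantizer, with toy sizes chosen only for illustration:

```python
# Hypothetical smoke test of VectorQuantizer2 as defined above.
import torch

vq = VectorQuantizer2(n_e=512, e_dim=4, beta=0.25)
z = torch.randn(1, 4, 8, 8)          # (batch, channels, height, width)
z_q, loss, (_, _, indices) = vq(z)
assert z_q.shape == z.shape          # straight-through estimator keeps the shape
assert indices.numel() == 1 * 8 * 8  # one codebook index per spatial position
```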

@@ -11,9 +11,12 @@ from transformers.models.clip.modeling_clip import CLIPVisionModelOutput
from annotator.util import HWC3
from typing import Callable, Tuple, Union
from modules.safe import Extra
from modules import devices
import contextlib
Extra = lambda x: contextlib.nullcontext()
def torch_handler(module: str, name: str):
""" Allow all torch access. Bypass A1111 safety whitelist. """

@@ -19,7 +19,6 @@ import cv2
import logging
from typing import Any, Callable, Dict, List
from modules.safe import unsafe_torch_load
from lib_controlnet.logging import logger
@@ -28,7 +27,7 @@ def load_state_dict(ckpt_path, location="cpu"):
if extension.lower() == ".safetensors":
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
else:
state_dict = unsafe_torch_load(ckpt_path, map_location=torch.device(location))
state_dict = torch.load(ckpt_path, map_location=torch.device(location))
state_dict = get_state_dict(state_dict)
logger.info(f"Loaded state_dict from [{ckpt_path}]")
return state_dict
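
The hunk above drops WebUI's wrapped loader in favor of plain `torch.load`; a self-contained sketch of the resulting dispatch (the function name here is mine, not the repo's):

```python
# Hypothetical distillation of the loader dispatch shown above.
import os
import torch
import safetensors.torch

def load_checkpoint_state_dict(ckpt_path: str, location: str = "cpu") -> dict:
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        # zero-copy, pickle-free format
        return safetensors.torch.load_file(ckpt_path, device=location)
    # plain torch.load: pickle-based, no safety whitelist after this commit
    return torch.load(ckpt_path, map_location=torch.device(location))
```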

@@ -24,7 +24,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import Any
@@ -725,7 +724,7 @@ class Api:
def get_sd_models(self):
import modules.sd_models as sd_models
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()]
def get_sd_vaes(self):
import modules.sd_vae as sd_vae

@@ -9,7 +9,7 @@ import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from backend.nn.unet import default
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, saving_settings
from modules.textual_inversion.learn_schedule import LearnRateScheduler

@@ -1,25 +1,12 @@
import importlib
import logging
import os
import sys
import warnings
import os
from threading import Thread
from modules.timer import startup_timer
class HiddenPrints:
def __enter__(self):
self._original_stdout = sys.stdout
sys.stdout = open(os.devnull, 'w')
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.close()
sys.stdout = self._original_stdout
def imports():
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -35,16 +22,8 @@ def imports():
import gradio # noqa: F401
startup_timer.record("import gradio")
with HiddenPrints():
from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer.record("setup paths")
import ldm.modules.encoders.modules # noqa: F401
import ldm.modules.diffusionmodules.model
startup_timer.record("import ldm")
import sgm.modules.encoders.modules # noqa: F401
startup_timer.record("import sgm")
from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer.record("setup paths")
from modules import shared_init
shared_init.initialize()
@@ -137,15 +116,6 @@ def initialize_rest(*, reload_script_modules=False):
sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
from modules import textual_inversion
textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
from modules import sd_unet
sd_unet.list_unets()
startup_timer.record("scripts list_unets")
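
The removed HiddenPrints helper silenced stdout during noisy imports; for reference, its behavior in isolation:

```python
# How the removed HiddenPrints context manager behaved.
with HiddenPrints():
    print("swallowed: stdout points at os.devnull here")
print("visible again: stdout is restored on exit")
```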

@@ -391,15 +391,15 @@ def prepare_environment():
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
# stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
# stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -456,8 +456,8 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
# git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
# git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
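
Every repo URL and commit pin in the hunk above follows one pattern: an environment-variable override with a hard-coded default. Distilled with placeholder names:

```python
# The override pattern used throughout prepare_environment(); names and values
# here are placeholders, not the repo's real pins.
import os

example_repo = os.environ.get('EXAMPLE_REPO', 'https://github.com/example/example.git')
example_commit_hash = os.environ.get('EXAMPLE_COMMIT_HASH', 'deadbeefcafe')  # pinned default
```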

@@ -2,45 +2,15 @@ import os
import sys
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
import modules.safe # noqa: F401
def mute_sdxl_imports():
"""create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
class Dummy:
pass
module = Dummy()
module.LPIPS = None
sys.modules['taming.modules.losses.lpips'] = module
module = Dummy()
module.StableDataModuleFromConfig = None
sys.modules['sgm.data'] = module
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path)
break
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
mute_sdxl_imports()
sd_path = os.path.dirname(__file__)
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
(os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
(os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
(os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
]
paths = {}
@@ -53,13 +23,6 @@ for d, must_exist, what, options in path_dirs:
d = os.path.abspath(d)
if "atstart" in options:
sys.path.insert(0, d)
elif "sgm" in options:
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.
sys.path.insert(0, d)
import sgm # noqa: F401
sys.path.pop(0)
else:
sys.path.append(d)
paths[what] = d
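
The removed `sgm` branch relied on an import-then-unshadow trick; a generic sketch, with a standard-library stand-in for the vendored package:

```python
# Generic form of the removed trick: put a directory first on sys.path,
# import the package (caching it in sys.modules), then pop the directory
# so its other top-level modules (e.g. a stray scripts/) shadow nothing.
import sys

sys.path.insert(0, "/path/to/vendored/repo")  # placeholder path
try:
    import json  # stand-in for the vendored package (sgm in the original)
finally:
    sys.path.pop(0)
```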

@@ -28,8 +28,6 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
@@ -295,23 +293,7 @@ class StableDiffusionProcessing:
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
transformer = AddMiDaS(model_type="dpt_hybrid")
transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
mode="bicubic",
align_corners=False,
)
(depth_min, depth_max) = torch.aminmax(conditioning)
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
return conditioning
raise NotImplementedError('depth2img_image_conditioning')
def edit_image_conditioning(self, source_image):
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
@@ -368,11 +350,6 @@ class StableDiffusionProcessing:
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image)
if self.sd_model.cond_stage_key == "edit":
return self.edit_image_conditioning(source_image)
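
The removed depth2img path normalized the MiDaS depth map into [-1, 1]; that step in isolation, on a hypothetical depth map:

```python
# The removed rescaling step by itself; the input is a stand-in for
# sd_model.depth_model(midas_in).
import torch

depth = torch.rand(1, 1, 64, 64)
depth_min, depth_max = torch.aminmax(depth)
conditioning = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
assert conditioning.min() >= -1.0 and conditioning.max() <= 1.0
```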

@@ -1,195 +1,195 @@
# this code is adapted from the script contributed by anon from /h/
import pickle
import collections
import torch
import numpy
import _codecs
import zipfile
import re
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
out = _codecs.encode(*args)
return out
class RestrictedUnpickler(pickle.Unpickler):
extra_handler = None
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
try:
return TypedStorage(_internal=True)
except TypeError:
return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
def find_class(self, module, name):
if self.extra_handler is not None:
res = self.extra_handler(module, name)
if res is not None:
return res
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
return getattr(numpy.core.multiarray, name)
if module == 'numpy' and name in ['dtype', 'ndarray']:
return getattr(numpy, name)
if module == '_codecs' and name == 'encode':
return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
import pytorch_lightning.callbacks
return pytorch_lightning.callbacks.model_checkpoint
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks.model_checkpoint
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
if module == "__builtin__" and name == 'set':
return set
# Forbid everything else.
raise Exception(f"global '{module}/{name}' is forbidden")
# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
for name in names:
if allowed_zip_names_re.match(name):
continue
raise Exception(f"bad file inside {filename}: {name}")
def check_pt(filename, extra_handler):
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
raise Exception(f"data.pkl not found in {filename}")
if len(data_pkl_filenames) > 1:
raise Exception(f"Multiple data.pkl found in {filename}")
with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
for _ in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
"""
this function is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious.
Use the extra_handler argument to specify a function that takes module and field name as text,
and returns that field's value:
```python
def extra(module, name):
if module == 'collections' and name == 'OrderedDict':
return collections.OrderedDict
return None
safe.load_with_extra('model.pt', extra_handler=extra)
```
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
definitely unsafe.
"""
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
errors.report(
f"Error verifying pickled file from {filename}\n"
"-----> !!!! The file is most likely corrupted !!!! <-----\n"
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
exc_info=True,
)
return None
except Exception:
errors.report(
f"Error verifying pickled file from {filename}\n"
f"The file may be malicious, so the program is not going to read it.\n"
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
exc_info=True,
)
return None
return unsafe_torch_load(filename, *args, **kwargs)
class Extra:
"""
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
(because it's not your code making the torch.load call). The intended use is like this:
```
import torch
from modules import safe
def handler(module, name):
if module == 'torch' and name in ['float64', 'float16']:
return getattr(torch, name)
return None
with safe.Extra(handler):
x = torch.load('model.pt')
```
"""
def __init__(self, handler):
self.handler = handler
def __enter__(self):
global global_extra_handler
assert global_extra_handler is None, 'already inside an Extra() block'
global_extra_handler = self.handler
def __exit__(self, exc_type, exc_val, exc_tb):
global global_extra_handler
global_extra_handler = None
unsafe_torch_load = torch.load
global_extra_handler = None
# # this code is adapted from the script contributed by anon from /h/
#
# import pickle
# import collections
#
# import torch
# import numpy
# import _codecs
# import zipfile
# import re
#
#
# # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
# from modules import errors
#
# TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
#
# def encode(*args):
# out = _codecs.encode(*args)
# return out
#
#
# class RestrictedUnpickler(pickle.Unpickler):
# extra_handler = None
#
# def persistent_load(self, saved_id):
# assert saved_id[0] == 'storage'
#
# try:
# return TypedStorage(_internal=True)
# except TypeError:
# return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
#
# def find_class(self, module, name):
# if self.extra_handler is not None:
# res = self.extra_handler(module, name)
# if res is not None:
# return res
#
# if module == 'collections' and name == 'OrderedDict':
# return getattr(collections, name)
# if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
# return getattr(torch._utils, name)
# if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
# return getattr(torch, name)
# if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
# return getattr(torch.nn.modules.container, name)
# if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
# return getattr(numpy.core.multiarray, name)
# if module == 'numpy' and name in ['dtype', 'ndarray']:
# return getattr(numpy, name)
# if module == '_codecs' and name == 'encode':
# return encode
# if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
# import pytorch_lightning.callbacks
# return pytorch_lightning.callbacks.model_checkpoint
# if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
# import pytorch_lightning.callbacks.model_checkpoint
# return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
# if module == "__builtin__" and name == 'set':
# return set
#
# # Forbid everything else.
# raise Exception(f"global '{module}/{name}' is forbidden")
#
#
# # Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
# allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
# data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
#
# def check_zip_filenames(filename, names):
# for name in names:
# if allowed_zip_names_re.match(name):
# continue
#
# raise Exception(f"bad file inside {filename}: {name}")
#
#
# def check_pt(filename, extra_handler):
# try:
#
# # new pytorch format is a zip file
# with zipfile.ZipFile(filename) as z:
# check_zip_filenames(filename, z.namelist())
#
# # find filename of data.pkl in zip file: '<directory name>/data.pkl'
# data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
# if len(data_pkl_filenames) == 0:
# raise Exception(f"data.pkl not found in {filename}")
# if len(data_pkl_filenames) > 1:
# raise Exception(f"Multiple data.pkl found in {filename}")
# with z.open(data_pkl_filenames[0]) as file:
# unpickler = RestrictedUnpickler(file)
# unpickler.extra_handler = extra_handler
# unpickler.load()
#
# except zipfile.BadZipfile:
#
# # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
# with open(filename, "rb") as file:
# unpickler = RestrictedUnpickler(file)
# unpickler.extra_handler = extra_handler
# for _ in range(5):
# unpickler.load()
#
#
# def load(filename, *args, **kwargs):
# return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
#
#
# def load_with_extra(filename, extra_handler=None, *args, **kwargs):
# """
# this function is intended to be used by extensions that want to load models with
# some extra classes in them that the usual unpickler would find suspicious.
#
# Use the extra_handler argument to specify a function that takes module and field name as text,
# and returns that field's value:
#
# ```python
# def extra(module, name):
# if module == 'collections' and name == 'OrderedDict':
# return collections.OrderedDict
#
# return None
#
# safe.load_with_extra('model.pt', extra_handler=extra)
# ```
#
# The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
# definitely unsafe.
# """
#
# from modules import shared
#
# try:
# if not shared.cmd_opts.disable_safe_unpickle:
# check_pt(filename, extra_handler)
#
# except pickle.UnpicklingError:
# errors.report(
# f"Error verifying pickled file from {filename}\n"
# "-----> !!!! The file is most likely corrupted !!!! <-----\n"
# "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
# exc_info=True,
# )
# return None
# except Exception:
# errors.report(
# f"Error verifying pickled file from {filename}\n"
# f"The file may be malicious, so the program is not going to read it.\n"
# f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
# exc_info=True,
# )
# return None
#
# return unsafe_torch_load(filename, *args, **kwargs)
#
#
# class Extra:
# """
# A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
# (because it's not your code making the torch.load call). The intended use is like this:
#
# ```
# import torch
# from modules import safe
#
# def handler(module, name):
# if module == 'torch' and name in ['float64', 'float16']:
# return getattr(torch, name)
#
# return None
#
# with safe.Extra(handler):
# x = torch.load('model.pt')
# ```
# """
#
# def __init__(self, handler):
# self.handler = handler
#
# def __enter__(self):
# global global_extra_handler
#
# assert global_extra_handler is None, 'already inside an Extra() block'
# global_extra_handler = self.handler
#
# def __exit__(self, exc_type, exc_val, exc_tb):
# global global_extra_handler
#
# global_extra_handler = None
#
#
# unsafe_torch_load = torch.load
# global_extra_handler = None
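
Before this commit commented it out, the allowlist flow above ran roughly like this from a caller's point of view; the file name is a placeholder:

```python
# Hedged usage sketch of the vetting flow defined above.
try:
    check_pt("model.ckpt", extra_handler=None)  # raises on disallowed globals
    state = unsafe_torch_load("model.ckpt", map_location="cpu")
except Exception as e:
    print(f"refusing to load checkpoint: {e}")
```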

@@ -1,232 +1,232 @@
import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub
from modules import shared
class ReplaceHelper:
def __init__(self):
self.replaced = []
def replace(self, obj, field, func):
original = getattr(obj, field, None)
if original is None:
return None
self.replaced.append((obj, field, original))
setattr(obj, field, func)
return original
def restore(self):
for obj, field, original in self.replaced:
setattr(obj, field, original)
self.replaced.clear()
class DisableInitialization(ReplaceHelper):
"""
When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working
- changes CLIP and OpenCLIP to not download model weights
- changes CLIP to not make requests to check if there is a new version of a file you already have
When it leaves the block, it reverts everything to how it was before.
Use it like this:
```
with DisableInitialization():
do_things()
```
"""
def __init__(self, disable_clip=True):
super().__init__()
self.disable_clip = disable_clip
def replace(self, obj, field, func):
original = getattr(obj, field, None)
if original is None:
return None
self.replaced.append((obj, field, original))
setattr(obj, field, func)
return original
def __enter__(self):
def do_nothing(*args, **kwargs):
pass
def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
res.name_or_path = pretrained_model_name_or_path
return res
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
# this file is always 404, prevent making request
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
return None
try:
res = original(url, *args, local_files_only=True, **kwargs)
if res is None:
res = original(url, *args, local_files_only=False, **kwargs)
return res
except Exception:
return original(url, *args, local_files_only=False, **kwargs)
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
if self.disable_clip:
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
class InitializeOnMeta(ReplaceHelper):
"""
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
which results in those parameters having no values and taking no memory. model.to() will be broken and
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
Usage:
```
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
```
"""
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
return
def set_device(x):
x["device"] = "meta"
return x
linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
class LoadStateDictOnMeta(ReplaceHelper):
"""
Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
Meant to be used together with InitializeOnMeta above.
Usage:
```
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
model.load_state_dict(state_dict, strict=False)
```
"""
def __init__(self, state_dict, device, weight_dtype_conversion=None):
super().__init__()
self.state_dict = state_dict
self.device = device
self.weight_dtype_conversion = weight_dtype_conversion or {}
self.default_dtype = self.weight_dtype_conversion.get('')
def get_weight_dtype(self, key):
key_first_term, _ = key.split('.', 1)
return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
return
sd = self.state_dict
device = self.device
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
used_param_keys = []
for name, param in module._parameters.items():
if param is None:
continue
key = prefix + name
sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
used_param_keys.append(key)
if param.is_meta:
dtype = sd_param.dtype if sd_param is not None else param.dtype
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
for name in module._buffers:
key = prefix + name
sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param
used_param_keys.append(key)
original(module, state_dict, prefix, *args, **kwargs)
for key in used_param_keys:
state_dict.pop(key, None)
def load_state_dict(original, module, state_dict, strict=True):
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
The dangerous thing about this is that if _load_from_state_dict is not called (if some exotic module overloads
the function and does not call the original), the state dict will simply fail to load, because the weights
would still be on the meta device.
"""
if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
original(module, state_dict, strict=strict)
module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
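A hedged end-to-end sketch of how the two context managers above compose during model loading; sd_config, state_dict and model_target_device are hypothetical names standing in for values the surrounding loader provides, and the dtype map is an assumption:
```
with sd_disable_initialization.InitializeOnMeta():
    sd_model = instantiate_from_config(sd_config.model)   # parameters land on meta: no RAM spent yet

dtype_conversion = {'first_stage_model': None, '': torch.float16}  # per-prefix target dtypes (assumption)
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device, weight_dtype_conversion=dtype_conversion):
    sd_model.load_state_dict(state_dict, strict=False)    # weights materialize straight onto the device
```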
# import ldm.modules.encoders.modules
# import open_clip
# import torch
# import transformers.utils.hub
#
# from modules import shared
#
#
# class ReplaceHelper:
# def __init__(self):
# self.replaced = []
#
# def replace(self, obj, field, func):
# original = getattr(obj, field, None)
# if original is None:
# return None
#
# self.replaced.append((obj, field, original))
# setattr(obj, field, func)
#
# return original
#
# def restore(self):
# for obj, field, original in self.replaced:
# setattr(obj, field, original)
#
# self.replaced.clear()
#
#
# class DisableInitialization(ReplaceHelper):
# """
# When an object of this class enters a `with` block, it starts:
# - preventing torch's layer initialization functions from working
# - changes CLIP and OpenCLIP to not download model weights
# - changes CLIP to not make requests to check if there is a new version of a file you already have
#
# When it leaves the block, it reverts everything to how it was before.
#
# Use it like this:
# ```
# with DisableInitialization():
# do_things()
# ```
# """
#
# def __init__(self, disable_clip=True):
# super().__init__()
# self.disable_clip = disable_clip
#
# def replace(self, obj, field, func):
# original = getattr(obj, field, None)
# if original is None:
# return None
#
# self.replaced.append((obj, field, original))
# setattr(obj, field, func)
#
# return original
#
# def __enter__(self):
# def do_nothing(*args, **kwargs):
# pass
#
# def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
# return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
#
# def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
# res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
# res.name_or_path = pretrained_model_name_or_path
# return res
#
# def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
# args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
# return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
#
# def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
#
# # this file is always 404, prevent making request
# if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
# return None
#
# try:
# res = original(url, *args, local_files_only=True, **kwargs)
# if res is None:
# res = original(url, *args, local_files_only=False, **kwargs)
# return res
# except Exception:
# return original(url, *args, local_files_only=False, **kwargs)
#
# def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
# return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
#
# def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
# return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
#
# def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
# return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
#
# self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
# self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
# self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
#
# if self.disable_clip:
# self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
# self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
# self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
# self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
# self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
# self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
#
# def __exit__(self, exc_type, exc_val, exc_tb):
# self.restore()
#
#
# class InitializeOnMeta(ReplaceHelper):
# """
# Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
# which results in those parameters having no values and taking no memory. model.to() will be broken and
# will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
#
# Usage:
# ```
# with sd_disable_initialization.InitializeOnMeta():
# sd_model = instantiate_from_config(sd_config.model)
# ```
# """
#
# def __enter__(self):
# if shared.cmd_opts.disable_model_loading_ram_optimization:
# return
#
# def set_device(x):
# x["device"] = "meta"
# return x
#
# linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
# conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
# mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
# self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
#
# def __exit__(self, exc_type, exc_val, exc_tb):
# self.restore()
#
#
# class LoadStateDictOnMeta(ReplaceHelper):
# """
# Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
# As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
# Meant to be used together with InitializeOnMeta above.
#
# Usage:
# ```
# with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
# model.load_state_dict(state_dict, strict=False)
# ```
# """
#
# def __init__(self, state_dict, device, weight_dtype_conversion=None):
# super().__init__()
# self.state_dict = state_dict
# self.device = device
# self.weight_dtype_conversion = weight_dtype_conversion or {}
# self.default_dtype = self.weight_dtype_conversion.get('')
#
# def get_weight_dtype(self, key):
# key_first_term, _ = key.split('.', 1)
# return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
#
# def __enter__(self):
# if shared.cmd_opts.disable_model_loading_ram_optimization:
# return
#
# sd = self.state_dict
# device = self.device
#
# def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
# used_param_keys = []
#
# for name, param in module._parameters.items():
# if param is None:
# continue
#
# key = prefix + name
# sd_param = sd.pop(key, None)
# if sd_param is not None:
# state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
# used_param_keys.append(key)
#
# if param.is_meta:
# dtype = sd_param.dtype if sd_param is not None else param.dtype
# module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
#
# for name in module._buffers:
# key = prefix + name
#
# sd_param = sd.pop(key, None)
# if sd_param is not None:
# state_dict[key] = sd_param
# used_param_keys.append(key)
#
# original(module, state_dict, prefix, *args, **kwargs)
#
# for key in used_param_keys:
# state_dict.pop(key, None)
#
# def load_state_dict(original, module, state_dict, strict=True):
# """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
# because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
# all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
#
# In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
#
# The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
# the function and does not call the original) the state dict will just fail to load because weights
# would be on the meta device.
# """
#
# if state_dict is sd:
# state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
#
# original(module, state_dict, strict=strict)
#
# module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
# module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
# linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
# conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
# mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
# layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
# group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
#
# def __exit__(self, exc_type, exc_val, exc_tb):
# self.restore()

View File

@@ -1,124 +1,3 @@
import torch
from torch.nn.functional import silu
from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
new_optimizers = [x for x in new_optimizers if x.is_available()]
new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
optimizers.clear()
optimizers.extend(new_optimizers)
def apply_optimizations(option=None):
return
def undo_optimizations():
return
def fix_checkpoint():
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
checkpoints to be added when not training (there's a warning)"""
pass
def weighted_loss(sd_model, pred, target, mean=True):
#Calculate the loss normally, but without taking the mean
loss = sd_model._old_get_loss(pred, target, mean=False)
#Check if we have weights available
weight = getattr(sd_model, '_custom_loss_weight', None)
if weight is not None:
loss *= weight
#Return the loss, as mean if specified
return loss.mean() if mean else loss
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try:
#Temporarily append weights to a place accessible during loss calc
sd_model._custom_loss_weight = w
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if not hasattr(sd_model, '_old_get_loss'):
sd_model._old_get_loss = sd_model.get_loss
sd_model.get_loss = MethodType(weighted_loss, sd_model)
#Run the standard forward function, but with the patched 'get_loss'
return sd_model.forward(x, c, *args, **kwargs)
finally:
try:
#Delete temporary weights if appended
del sd_model._custom_loss_weight
except AttributeError:
pass
#If we have an old loss function, reset the loss function to the original one
if hasattr(sd_model, '_old_get_loss'):
sd_model.get_loss = sd_model._old_get_loss
del sd_model._old_get_loss
def apply_weighted_forward(sd_model):
#Add new function 'weighted_forward' that can be called to calc weighted loss
sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
def undo_weighted_forward(sd_model):
try:
del sd_model.weighted_forward
except AttributeError:
pass
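A hedged usage sketch for the three helpers above; sd_model, latents x, conditioning c and a per-element weight map w are assumed to exist in the caller's training loop:
```
apply_weighted_forward(sd_model)            # attaches sd_model.weighted_forward
out = sd_model.weighted_forward(x, c, w)    # same return value as sd_model.forward, with the loss scaled by w
undo_weighted_forward(sd_model)             # detach the helper when finished
```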
class StableDiffusionModelHijack:
fixes = None
layers = None
@@ -156,74 +35,234 @@ class StableDiffusionModelHijack:
pass
class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
self.textual_inversion_key = textual_inversion_key
self.weight = self.wrapped.weight
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
self.embeddings.fixes = None
inputs_embeds = self.wrapped(input_ids)
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
return inputs_embeds
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
vecs.append(tensor)
return torch.stack(vecs)
class TextualInversionEmbeddings(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
self.embeddings = model_hijack
self.textual_inversion_key = textual_inversion_key
@property
def wrapped(self):
return super().forward
def forward(self, input_ids):
return EmbeddingsWithFixes.forward(self, input_ids)
def add_circular_option_to_conv_2d():
conv2d_constructor = torch.nn.Conv2d.__init__
def conv2d_constructor_circular(self, *args, **kwargs):
return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
torch.nn.Conv2d.__init__ = conv2d_constructor_circular
model_hijack = StableDiffusionModelHijack()
def register_buffer(self, name, attr):
"""
Fix register buffer bug for Mac OS.
"""
if type(attr) == torch.Tensor:
if attr.device != devices.device:
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
# import torch
# from torch.nn.functional import silu
# from types import MethodType
#
# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
# from modules.hypernetworks import hypernetwork
# from modules.shared import cmd_opts
# from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
#
# import ldm.modules.attention
# import ldm.modules.diffusionmodules.model
# import ldm.modules.diffusionmodules.openaimodel
# import ldm.models.diffusion.ddpm
# import ldm.models.diffusion.ddim
# import ldm.models.diffusion.plms
# import ldm.modules.encoders.modules
#
# import sgm.modules.attention
# import sgm.modules.diffusionmodules.model
# import sgm.modules.diffusionmodules.openaimodel
# import sgm.modules.encoders.modules
#
# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
#
# # new memory efficient cross attention blocks do not support hypernets and we already
# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
#
# # silence new console spam from SD2
# ldm.modules.attention.print = shared.ldm_print
# ldm.modules.diffusionmodules.model.print = shared.ldm_print
# ldm.util.print = shared.ldm_print
# ldm.models.diffusion.ddpm.print = shared.ldm_print
#
# optimizers = []
# current_optimizer: sd_hijack_optimizations.SdOptimization = None
#
# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
#
# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
#
#
# def list_optimizers():
# new_optimizers = script_callbacks.list_optimizers_callback()
#
# new_optimizers = [x for x in new_optimizers if x.is_available()]
#
# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
#
# optimizers.clear()
# optimizers.extend(new_optimizers)
#
#
# def apply_optimizations(option=None):
# return
#
#
# def undo_optimizations():
# return
#
#
# def fix_checkpoint():
# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
# checkpoints to be added when not training (there's a warning)"""
#
# pass
#
#
# def weighted_loss(sd_model, pred, target, mean=True):
# #Calculate the weight normally, but ignore the mean
# loss = sd_model._old_get_loss(pred, target, mean=False)
#
# #Check if we have weights available
# weight = getattr(sd_model, '_custom_loss_weight', None)
# if weight is not None:
# loss *= weight
#
# #Return the loss, as mean if specified
# return loss.mean() if mean else loss
#
# def weighted_forward(sd_model, x, c, w, *args, **kwargs):
# try:
# #Temporarily append weights to a place accessible during loss calc
# sd_model._custom_loss_weight = w
#
# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
# if not hasattr(sd_model, '_old_get_loss'):
# sd_model._old_get_loss = sd_model.get_loss
# sd_model.get_loss = MethodType(weighted_loss, sd_model)
#
# #Run the standard forward function, but with the patched 'get_loss'
# return sd_model.forward(x, c, *args, **kwargs)
# finally:
# try:
# #Delete temporary weights if appended
# del sd_model._custom_loss_weight
# except AttributeError:
# pass
#
# #If we have an old loss function, reset the loss function to the original one
# if hasattr(sd_model, '_old_get_loss'):
# sd_model.get_loss = sd_model._old_get_loss
# del sd_model._old_get_loss
#
# def apply_weighted_forward(sd_model):
# #Add new function 'weighted_forward' that can be called to calc weighted loss
# sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
#
# def undo_weighted_forward(sd_model):
# try:
# del sd_model.weighted_forward
# except AttributeError:
# pass
#
#
# class StableDiffusionModelHijack:
# fixes = None
# layers = None
# circular_enabled = False
# clip = None
# optimization_method = None
#
# def __init__(self):
# self.extra_generation_params = {}
# self.comments = []
#
# def apply_optimizations(self, option=None):
# pass
#
# def convert_sdxl_to_ssd(self, m):
# pass
#
# def hijack(self, m):
# pass
#
# def undo_hijack(self, m):
# pass
#
# def apply_circular(self, enable):
# pass
#
# def clear_comments(self):
# self.comments = []
# self.extra_generation_params = {}
#
# def get_prompt_lengths(self, text, cond_stage_model):
# pass
#
# def redo_hijack(self, m):
# pass
#
#
# class EmbeddingsWithFixes(torch.nn.Module):
# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
# super().__init__()
# self.wrapped = wrapped
# self.embeddings = embeddings
# self.textual_inversion_key = textual_inversion_key
# self.weight = self.wrapped.weight
#
# def forward(self, input_ids):
# batch_fixes = self.embeddings.fixes
# self.embeddings.fixes = None
#
# inputs_embeds = self.wrapped(input_ids)
#
# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
# return inputs_embeds
#
# vecs = []
# for fixes, tensor in zip(batch_fixes, inputs_embeds):
# for offset, embedding in fixes:
# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
# emb = devices.cond_cast_unet(vec)
# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
#
# vecs.append(tensor)
#
# return torch.stack(vecs)
#
#
# class TextualInversionEmbeddings(torch.nn.Embedding):
# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
# super().__init__(num_embeddings, embedding_dim, **kwargs)
#
# self.embeddings = model_hijack
# self.textual_inversion_key = textual_inversion_key
#
# @property
# def wrapped(self):
# return super().forward
#
# def forward(self, input_ids):
# return EmbeddingsWithFixes.forward(self, input_ids)
#
#
# def add_circular_option_to_conv_2d():
# conv2d_constructor = torch.nn.Conv2d.__init__
#
# def conv2d_constructor_circular(self, *args, **kwargs):
# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
#
# torch.nn.Conv2d.__init__ = conv2d_constructor_circular
#
#
#
#
#
# def register_buffer(self, name, attr):
# """
# Fix register buffer bug for Mac OS.
# """
#
# if type(attr) == torch.Tensor:
# if attr.device != devices.device:
# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
#
# setattr(self, name, attr)
#
#
# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer

View File

@@ -1,46 +1,46 @@
from torch.utils.checkpoint import checkpoint
import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)
stored = []
def add():
if len(stored) != 0:
return
stored.extend([
ldm.modules.attention.BasicTransformerBlock.forward,
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
])
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
def remove():
if len(stored) == 0:
return
ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
stored.clear()
# from torch.utils.checkpoint import checkpoint
#
# import ldm.modules.attention
# import ldm.modules.diffusionmodules.openaimodel
#
#
# def BasicTransformerBlock_forward(self, x, context=None):
# return checkpoint(self._forward, x, context)
#
#
# def AttentionBlock_forward(self, x):
# return checkpoint(self._forward, x)
#
#
# def ResBlock_forward(self, x, emb):
# return checkpoint(self._forward, x, emb)
#
#
# stored = []
#
#
# def add():
# if len(stored) != 0:
# return
#
# stored.extend([
# ldm.modules.attention.BasicTransformerBlock.forward,
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
# ])
#
# ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
#
#
# def remove():
# if len(stored) == 0:
# return
#
# ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
#
# stored.clear()
#

File diff suppressed because it is too large

View File

@@ -1,154 +1,154 @@
import torch
from packaging import version
from einops import repeat
import math
from modules import devices
from modules.sd_hijack_utils import CondFunc
class TorchHijackForUnet:
"""
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
"""
def __getattr__(self, item):
if item == 'cat':
return self.cat
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def cat(self, tensors, *args, **kwargs):
if len(tensors) == 2:
a, b = tensors
if a.shape[-2:] != b.shape[-2:]:
a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
tensors = (a, b)
return torch.cat(tensors, *args, **kwargs)
th = TorchHijackForUnet()
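A hedged illustration of the resizing cat (plain PyTorch only): when two feature maps disagree in their last two dimensions, the first is interpolated to match the second before concatenation.
```
import torch

a = torch.ones(1, 4, 16, 16)
b = torch.ones(1, 4, 15, 15)
print(th.cat((a, b), dim=1).shape)   # torch.Size([1, 8, 15, 15]); a was resized to b's spatial size
```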
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
"""Always make sure inputs to unet are in correct dtype."""
if isinstance(cond, dict):
for y in cond.keys():
if isinstance(cond[y], list):
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
else:
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
with devices.autocast():
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
if devices.unet_needs_upcast:
return result.float()
else:
return result
# Monkey patch to create the timestep embedding tensor directly on the target device, avoiding a blocking host-to-device copy.
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
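A hedged usage sketch for the patched embedding above; the leading `_` slot is where CondFunc passes the original function, so None stands in for it here:
```
ts = torch.tensor([0.0, 999.5])             # fractional timesteps are allowed
emb = timestep_embedding(None, ts, dim=8)   # built directly on ts.device, no host-side tensor creation
print(emb.shape)                            # torch.Size([2, 8])
```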
# Monkey patch for SpatialTransformer that removes unnecessary contiguous calls.
# Prevents a lot of unnecessary aten::copy_ calls.
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
class GELUHijack(torch.nn.GELU, torch.nn.Module):
def __init__(self, *args, **kwargs):
torch.nn.GELU.__init__(self, *args, **kwargs)
def forward(self, x):
if devices.unet_needs_upcast:
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
else:
return torch.nn.GELU.forward(self, x)
ddpm_edit_hijack = None
def hijack_ddpm_edit():
global ddpm_edit_hijack
if not ddpm_edit_hijack:
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
dtype = torch.float32
else:
dtype = devices.dtype_unet
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
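For readers unfamiliar with CondFunc (modules.sd_hijack_utils), a hedged restatement of the pattern used throughout this file: the first argument is a dotted path to patch in place, the second is the replacement (which receives the original callable as its first parameter), and the optional third is a predicate deciding per call whether the replacement runs.
```
CondFunc(
    'ldm.models.diffusion.ddpm.LatentDiffusion.apply_model',  # dotted path of the target
    apply_model,         # invoked as apply_model(orig_func, self, x_noisy, t, cond, ...)
    unet_needs_upcast,   # gate: only take over when a fp16 unet is sampled in fp32
)
```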
# import torch
# from packaging import version
# from einops import repeat
# import math
#
# from modules import devices
# from modules.sd_hijack_utils import CondFunc
#
#
# class TorchHijackForUnet:
# """
# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
# """
#
# def __getattr__(self, item):
# if item == 'cat':
# return self.cat
#
# if hasattr(torch, item):
# return getattr(torch, item)
#
# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
#
# def cat(self, tensors, *args, **kwargs):
# if len(tensors) == 2:
# a, b = tensors
# if a.shape[-2:] != b.shape[-2:]:
# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
#
# tensors = (a, b)
#
# return torch.cat(tensors, *args, **kwargs)
#
#
# th = TorchHijackForUnet()
#
#
# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
# """Always make sure inputs to unet are in correct dtype."""
# if isinstance(cond, dict):
# for y in cond.keys():
# if isinstance(cond[y], list):
# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
# else:
# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
#
# with devices.autocast():
# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
# if devices.unet_needs_upcast:
# return result.float()
# else:
# return result
#
#
# # Monkey patch to create timestep embed tensor on device, avoiding a block.
# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
# """
# Create sinusoidal timestep embeddings.
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# if not repeat_only:
# half = dim // 2
# freqs = torch.exp(
# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
# )
# args = timesteps[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# else:
# embedding = repeat(timesteps, 'b -> b d', d=dim)
# return embedding
#
#
# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
# # Prevents a lot of unnecessary aten::copy_ calls
# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
# # note: if no context is given, cross-attention defaults to self-attention
# if not isinstance(context, list):
# context = [context]
# b, c, h, w = x.shape
# x_in = x
# x = self.norm(x)
# if not self.use_linear:
# x = self.proj_in(x)
# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
# if self.use_linear:
# x = self.proj_in(x)
# for i, block in enumerate(self.transformer_blocks):
# x = block(x, context=context[i])
# if self.use_linear:
# x = self.proj_out(x)
# x = x.view(b, h, w, c).permute(0, 3, 1, 2)
# if not self.use_linear:
# x = self.proj_out(x)
# return x + x_in
#
#
# class GELUHijack(torch.nn.GELU, torch.nn.Module):
# def __init__(self, *args, **kwargs):
# torch.nn.GELU.__init__(self, *args, **kwargs)
# def forward(self, x):
# if devices.unet_needs_upcast:
# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
# else:
# return torch.nn.GELU.forward(self, x)
#
#
# ddpm_edit_hijack = None
# def hijack_ddpm_edit():
# global ddpm_edit_hijack
# if not ddpm_edit_hijack:
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
#
#
# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
#
# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
#
# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
#
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
#
#
# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
# if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
# dtype = torch.float32
# else:
# dtype = devices.dtype_unet
# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
#
#
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)

View File

@@ -10,7 +10,6 @@ import re
import safetensors.torch
from omegaconf import OmegaConf, ListConfig
from urllib import request
import ldm.modules.midas as midas
import gc
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
@@ -415,89 +414,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
def enable_midas_autodownload():
"""
Adds automatic model downloading to the ldm.modules.midas.api.load_model function.
When the 512-depth-ema model (or a future model like it) is loaded,
it calls midas.api.load_model to load the associated midas depth model.
This function applies a wrapper that downloads the weights to the correct
location automatically.
"""
midas_path = os.path.join(paths.models_path, 'midas')
# stable-diffusion-stability-ai hard-codes the midas model path to
# a location that differs from where other scripts using this model look.
# HACK: Overriding the path here.
for k, v in midas.api.ISL_PATHS.items():
file_name = os.path.basename(v)
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
midas_urls = {
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
}
midas.api.load_model_inner = midas.api.load_model
def load_model_wrapper(model_type):
path = midas.api.ISL_PATHS[model_type]
if not os.path.exists(path):
if not os.path.exists(midas_path):
os.mkdir(midas_path)
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
return midas.api.load_model_inner(model_type)
midas.api.load_model = load_model_wrapper
pass
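A hedged usage sketch; network access and the "dpt_hybrid" key are assumptions. Once the wrapper is installed, the first load of a given midas model type downloads its weights into models/midas, and later loads hit the local file:
```
enable_midas_autodownload()
midas.api.load_model("dpt_hybrid")   # first call downloads the weights, later calls load them locally
```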
def patch_given_betas():
import ldm.models.diffusion.ddpm
def patched_register_schedule(*args, **kwargs):
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
if isinstance(args[1], ListConfig):
args = (args[0], np.array(args[1]), *args[2:])
original_register_schedule(*args, **kwargs)
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
pass
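A hedged illustration of the problem the patch above solves: OmegaConf materialises YAML lists as ListConfig rather than numpy arrays, and register_schedule does array arithmetic on the given betas.
```
from omegaconf import OmegaConf
import numpy as np

cfg = OmegaConf.create({"given_betas": [0.0001, 0.0002, 0.0003]})
betas = np.array(cfg.given_betas)   # the same conversion the patch applies
print(betas * 2)                    # ndarray arithmetic works; a raw ListConfig would not support it
```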
def repair_config(sd_config, state_dict=None):
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
if hasattr(sd_config.model.params, 'unet_config'):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True
if hasattr(sd_config.model.params, 'first_stage_config'):
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
# Do not use checkpoint for inference.
# This helps prevent extra performance overhead on checking parameters.
# The perf overhead is about 100ms/it on 4090 for SDXL.
if hasattr(sd_config.model.params, "network_config"):
sd_config.model.params.network_config.params.use_checkpoint = False
if hasattr(sd_config.model.params, "unet_config"):
sd_config.model.params.unet_config.params.use_checkpoint = False
pass
def rescale_zero_terminal_snr_abar(alphas_cumprod):

View File

@@ -1,137 +1,137 @@
import os
import torch
from modules import shared, paths, sd_disable_initialization, devices
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
config_default = shared.sd_default_config
# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
"""
Detects whether the unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
"""
import ldm.modules.diffusionmodules.openaimodel
device = devices.device
with sd_disable_initialization.DisableInitialization():
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
use_checkpoint=False,
use_fp16=False,
image_size=32,
in_channels=4,
out_channels=4,
model_channels=320,
attention_resolutions=[4, 2, 1],
num_res_blocks=2,
channel_mult=[1, 2, 4, 4],
num_head_channels=64,
use_spatial_transformer=True,
use_linear_in_transformer=True,
transformer_depth=1,
context_dim=1024,
legacy=False
)
unet.eval()
with torch.no_grad():
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
unet.load_state_dict(unet_sd, strict=True)
unet.to(device=device, dtype=devices.dtype_unet)
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
with devices.autocast():
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
return out < -1
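A hedged usage sketch; the checkpoint path is hypothetical. The heuristic above denoises a constant input at t=999 and treats a strongly negative output mean as the signature of a v-prediction model:
```
import safetensors.torch

sd = safetensors.torch.load_file("models/sd2-checkpoint.safetensors", device="cpu")  # hypothetical path
if is_using_v_parameterization_for_sd2(sd):
    print("v-prediction checkpoint, using", config_sd2v)
```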
def guess_model_config_from_state_dict(sd, filename):
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if "model.diffusion_model.x_embedder.proj.weight" in sd:
return config_sd3
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
if diffusion_model_input.shape[1] == 9:
return config_sdxl_inpainting
else:
return config_sdxl
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
return config_depth_model
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
return config_unclip
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
return config_unopenclip
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
if diffusion_model_input.shape[1] == 9:
return config_sd2_inpainting
# elif is_using_v_parameterization_for_sd2(sd):
# return config_sd2v
else:
return config_sd2v
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
return config_inpainting
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
return config_alt_diffusion_m18
return config_alt_diffusion
return config_default
def find_checkpoint_config(state_dict, info):
if info is None:
return guess_model_config_from_state_dict(state_dict, "")
config = find_checkpoint_config_near_filename(info)
if config is not None:
return config
return guess_model_config_from_state_dict(state_dict, info.filename)
def find_checkpoint_config_near_filename(info):
if info is None:
return None
config = f"{os.path.splitext(info.filename)[0]}.yaml"
if os.path.exists(config):
return config
return None
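A hedged sketch tying the three functions above together; the filename is hypothetical. A sidecar .yaml next to the checkpoint wins, otherwise the config is guessed from the weights:
```
import safetensors.torch

sd = safetensors.torch.load_file("models/Stable-diffusion/model.safetensors", device="cpu")  # hypothetical
print(find_checkpoint_config(sd, info=None))   # with no CheckpointInfo, guessed purely from the state_dict
```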
# import os
#
# import torch
#
# from modules import shared, paths, sd_disable_initialization, devices
#
# sd_configs_path = shared.sd_configs_path
# # sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
# # sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
#
#
# config_default = shared.sd_default_config
# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
# config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
#
#
# def is_using_v_parameterization_for_sd2(state_dict):
# """
# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
# """
#
# import ldm.modules.diffusionmodules.openaimodel
#
# device = devices.device
#
# with sd_disable_initialization.DisableInitialization():
# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
# use_checkpoint=False,
# use_fp16=False,
# image_size=32,
# in_channels=4,
# out_channels=4,
# model_channels=320,
# attention_resolutions=[4, 2, 1],
# num_res_blocks=2,
# channel_mult=[1, 2, 4, 4],
# num_head_channels=64,
# use_spatial_transformer=True,
# use_linear_in_transformer=True,
# transformer_depth=1,
# context_dim=1024,
# legacy=False
# )
# unet.eval()
#
# with torch.no_grad():
# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
# unet.load_state_dict(unet_sd, strict=True)
# unet.to(device=device, dtype=devices.dtype_unet)
#
# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
#
# with devices.autocast():
# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
#
# return out < -1
#
#
# def guess_model_config_from_state_dict(sd, filename):
# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
#
# if "model.diffusion_model.x_embedder.proj.weight" in sd:
# return config_sd3
#
# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
# if diffusion_model_input.shape[1] == 9:
# return config_sdxl_inpainting
# else:
# return config_sdxl
#
# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
# return config_sdxl_refiner
# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
# return config_depth_model
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
# return config_unclip
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
# return config_unopenclip
#
# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
# if diffusion_model_input.shape[1] == 9:
# return config_sd2_inpainting
# # elif is_using_v_parameterization_for_sd2(sd):
# # return config_sd2v
# else:
# return config_sd2v
#
# if diffusion_model_input is not None:
# if diffusion_model_input.shape[1] == 9:
# return config_inpainting
# if diffusion_model_input.shape[1] == 8:
# return config_instruct_pix2pix
#
# if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
# return config_alt_diffusion_m18
# return config_alt_diffusion
#
# return config_default
#
#
# def find_checkpoint_config(state_dict, info):
# if info is None:
# return guess_model_config_from_state_dict(state_dict, "")
#
# config = find_checkpoint_config_near_filename(info)
# if config is not None:
# return config
#
# return guess_model_config_from_state_dict(state_dict, info.filename)
#
#
# def find_checkpoint_config_near_filename(info):
# if info is None:
# return None
#
# config = f"{os.path.splitext(info.filename)[0]}.yaml"
# if os.path.exists(config):
# return config
#
# return None
#

View File

@@ -1,4 +1,3 @@
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import TYPE_CHECKING
@@ -6,7 +5,7 @@ if TYPE_CHECKING:
from modules.sd_models import CheckpointInfo
class WebuiSdModel(LatentDiffusion):
class WebuiSdModel:
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
lowvram: bool

View File

@@ -1,115 +1,115 @@
from __future__ import annotations

import torch

import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
from modules import torch_utils

from backend import memory_management


def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):

    for embedder in self.conditioner.embedders:
        embedder.ucg_rate = 0.0

    width = getattr(batch, 'width', 1024) or 1024
    height = getattr(batch, 'height', 1024) or 1024
    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
    aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score

    devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())

    sdxl_conds = {
        "txt": batch,
        "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
        "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
        "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
    }

    force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
    c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])

    return c


def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
    if self.model.diffusion_model.in_channels == 9:
        x = torch.cat([x] + cond['c_concat'], dim=1)

    return self.model(x, t, cond, *args, **kwargs)


def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
    return x


sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding


def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
    res = []

    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
        encoded = embedder.encode_embedding_init_text(init_text, nvpt)
        res.append(encoded)

    return torch.cat(res, dim=1)


def tokenize(self: sgm.modules.GeneralConditioner, texts):
    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
        return embedder.tokenize(texts)

    raise AssertionError('no tokenizer available')


def process_texts(self, texts):
    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
        return embedder.process_texts(texts)


def get_target_prompt_token_count(self, token_count):
    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
        return embedder.get_target_prompt_token_count(token_count)


# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
sgm.modules.GeneralConditioner.tokenize = tokenize
sgm.modules.GeneralConditioner.process_texts = process_texts
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count


def extend_sdxl(model):
    """this adds a bunch of parameters to make the SDXL model look a bit more like SD1.5 to the rest of the codebase."""

    dtype = torch_utils.get_param(model.model.diffusion_model).dtype
    model.model.diffusion_model.dtype = dtype
    model.model.conditioning_key = 'crossattn'
    model.cond_stage_key = 'txt'
    # model.cond_stage_model will be set in sd_hijack

    model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"

    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)

    model.conditioner.wrapped = torch.nn.Module()


sgm.modules.attention.print = shared.ldm_print
sgm.modules.diffusionmodules.model.print = shared.ldm_print
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
sgm.modules.encoders.modules.print = shared.ldm_print

# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
# from __future__ import annotations
#
# import torch
#
# import sgm.models.diffusion
# import sgm.modules.diffusionmodules.denoiser_scaling
# import sgm.modules.diffusionmodules.discretizer
# from modules import devices, shared, prompt_parser
# from modules import torch_utils
#
# from backend import memory_management
#
#
# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
#
# for embedder in self.conditioner.embedders:
# embedder.ucg_rate = 0.0
#
# width = getattr(batch, 'width', 1024) or 1024
# height = getattr(batch, 'height', 1024) or 1024
# is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
# aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
#
# devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
#
# sdxl_conds = {
# "txt": batch,
# "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
# "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
# "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
# "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
# }
#
# force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
# c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
#
# return c
#
#
# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
# if self.model.diffusion_model.in_channels == 9:
# x = torch.cat([x] + cond['c_concat'], dim=1)
#
# return self.model(x, t, cond, *args, **kwargs)
#
#
# def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
# return x
#
#
# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
#
#
# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
# res = []
#
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
# encoded = embedder.encode_embedding_init_text(init_text, nvpt)
# res.append(encoded)
#
# return torch.cat(res, dim=1)
#
#
# def tokenize(self: sgm.modules.GeneralConditioner, texts):
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
# return embedder.tokenize(texts)
#
# raise AssertionError('no tokenizer available')
#
#
#
# def process_texts(self, texts):
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
# return embedder.process_texts(texts)
#
#
# def get_target_prompt_token_count(self, token_count):
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
# return embedder.get_target_prompt_token_count(token_count)
#
#
# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
# sgm.modules.GeneralConditioner.tokenize = tokenize
# sgm.modules.GeneralConditioner.process_texts = process_texts
# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
#
#
# def extend_sdxl(model):
#     """this adds a bunch of parameters to make the SDXL model look a bit more like SD1.5 to the rest of the codebase."""
#
# dtype = torch_utils.get_param(model.model.diffusion_model).dtype
# model.model.diffusion_model.dtype = dtype
# model.model.conditioning_key = 'crossattn'
# model.cond_stage_key = 'txt'
# # model.cond_stage_model will be set in sd_hijack
#
# model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
#
# discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
# model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
#
# model.conditioner.wrapped = torch.nn.Module()
#
#
# sgm.modules.attention.print = shared.ldm_print
# sgm.modules.diffusionmodules.model.print = shared.ldm_print
# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
# sgm.modules.encoders.modules.print = shared.ldm_print
#
# # this gets the code to load the vanilla attention that we override
# sgm.modules.attention.SDP_IS_AVAILABLE = True
# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
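#
# A minimal usage sketch of the patched interface above (illustrative only;
# assumes `model` is an sgm DiffusionEngine with these patches applied):
#
#     def sketch_denoise_step(model, latent, timestep, prompts):
#         # the patched get_learned_conditioning builds the SDXL cond dict
#         # (txt, original_size_as_tuple, crop_coords_top_left, ...) internally
#         cond = model.get_learned_conditioning(prompts)
#         # the patched apply_model concatenates cond['c_concat'] for 9-channel
#         # (inpainting) UNets before calling the wrapped diffusion model
#         return model.apply_model(latent, timestep, cond)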

View File

@@ -35,9 +35,7 @@ def refresh_vae_list():
def cross_attention_optimizations():
    import modules.sd_hijack
    return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
    return ["Automatic"]


def sd_unet_items():

View File

@@ -1,245 +1,243 @@
import os
import numpy as np
import PIL
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from collections import defaultdict
from random import shuffle, choices

import random
import tqdm
from modules import devices, shared, images
import re

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

re_numbers_at_start = re.compile(r"^[-\d]+\s*")


class DatasetEntry:
    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
        self.filename = filename
        self.filename_text = filename_text
        self.weight = weight
        self.latent_dist = latent_dist
        self.latent_sample = latent_sample
        self.cond = cond
        self.cond_text = cond_text
        self.pixel_values = pixel_values


class PersonalizedBase(Dataset):
    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
        re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None

        self.placeholder_token = placeholder_token

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'
        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
        assert os.listdir(data_root), "Dataset directory is empty"

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]

        self.shuffle_tags = shuffle_tags
        self.tag_drop_out = tag_drop_out
        groups = defaultdict(list)

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            alpha_channel = None
            if shared.state.interrupted:
                raise Exception("interrupted")
            try:
                image = images.read(path)
                #Currently does not work for single color transparency
                #We would need to read image.info['transparency'] for that
                if use_weight and 'A' in image.getbands():
                    alpha_channel = image.getchannel('A')
                image = image.convert('RGB')
                if not varsize:
                    image = image.resize((width, height), PIL.Image.BICUBIC)
            except Exception:
                continue

            text_filename = f"{os.path.splitext(path)[0]}.txt"
            filename = os.path.basename(path)

            if os.path.exists(text_filename):
                with open(text_filename, "r", encoding="utf8") as file:
                    filename_text = file.read()
            else:
                filename_text = os.path.splitext(filename)[0]
                filename_text = re.sub(re_numbers_at_start, '', filename_text)
                if re_word:
                    tokens = re_word.findall(filename_text)
                    filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)

            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
            latent_sample = None

            with devices.autocast():
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            #Perform latent sampling, even for random sampling.
            #We need the sample dimensions for the weights
            if latent_sampling_method == "deterministic":
                if isinstance(latent_dist, DiagonalGaussianDistribution):
                    # Works only for DiagonalGaussianDistribution
                    latent_dist.std = 0
                else:
                    latent_sampling_method = "once"
            latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)

            if use_weight and alpha_channel is not None:
                channels, *latent_size = latent_sample.shape
                weight_img = alpha_channel.resize(latent_size)
                npweight = np.array(weight_img).astype(np.float32)
                #Repeat for every channel in the latent sample
                weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
                #Normalize the weight to a minimum of 0 and a mean of 1 so the loss stays comparable to the default.
                weight -= weight.min()
                weight /= weight.mean()
            elif use_weight:
                #If an image does not have an alpha channel, add a ones weight map anyway so we can stack it later
                weight = torch.ones(latent_sample.shape)
            else:
                weight = None

            if latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
            else:
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)

            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                entry.cond_text = self.create_text(filename_text)

            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with devices.autocast():
                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
            groups[image.size].append(len(self.dataset))
            self.dataset.append(entry)
            del torchdata
            del latent_dist
            del latent_sample
            del weight

        self.length = len(self.dataset)
        self.groups = list(groups.values())
        assert self.length > 0, "No images have been found in the dataset."
        self.batch_size = min(batch_size, self.length)
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
        self.latent_sampling_method = latent_sampling_method

        if len(groups) > 1:
            print("Buckets:")
            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
                print(f" {w}x{h}: {len(ids)}")
            print()

    def create_text(self, filename_text):
        text = random.choice(self.lines)
        tags = filename_text.split(',')
        if self.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > self.tag_drop_out]
        if self.shuffle_tags:
            random.shuffle(tags)
        text = text.replace("[filewords]", ','.join(tags))
        text = text.replace("[name]", self.placeholder_token)
        return text

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        entry = self.dataset[i]
        if self.tag_drop_out != 0 or self.shuffle_tags:
            entry.cond_text = self.create_text(entry.filename_text)
        if self.latent_sampling_method == "random":
            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
        return entry


class GroupedBatchSampler(Sampler):
    def __init__(self, data_source: PersonalizedBase, batch_size: int):
        super().__init__(data_source)

        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
        self.base = [int(e) // batch_size for e in expected]
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        return self.len

    def __iter__(self):
        b = self.batch_size

        for g in self.groups:
            shuffle(g)

        batches = []
        for g in self.groups:
            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            batches.append(choices(rand_group, k=b))

        shuffle(batches)

        yield from batches
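
# Worked example (illustrative numbers): 10 images in two buckets of sizes 7
# and 3, with batch_size=2, give n_batch = 10 // 2 = 5.
#   expected = [7/10 * 5 * 2, 3/10 * 5 * 2] = [7.0, 3.0]
#   base     = [int(7.0) // 2, int(3.0) // 2] = [3, 1]   whole batches per bucket
#   n_rand_batches = 5 - (3 + 1) = 1
#   probs    = [7.0 % 2 / 1 / 2, 3.0 % 2 / 1 / 2] = [0.5, 0.5]
# Four batches are sliced deterministically from the shuffled buckets; the one
# leftover batch is drawn from a bucket chosen by its fractional remainder.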


class PersonalizedDataLoader(DataLoader):
    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
        if latent_sampling_method == "random":
            self.collate_fn = collate_wrapper_random
        else:
            self.collate_fn = collate_wrapper


class BatchLoader:
    def __init__(self, data):
        self.cond_text = [entry.cond_text for entry in data]
        self.cond = [entry.cond for entry in data]
        self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
        if all(entry.weight is not None for entry in data):
            self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
        else:
            self.weight = None
        #self.emb_index = [entry.emb_index for entry in data]
        #print(self.latent_sample.device)

    def pin_memory(self):
        self.latent_sample = self.latent_sample.pin_memory()
        return self


def collate_wrapper(batch):
    return BatchLoader(batch)


class BatchLoaderRandom(BatchLoader):
    def __init__(self, data):
        super().__init__(data)

    def pin_memory(self):
        return self


def collate_wrapper_random(batch):
    return BatchLoaderRandom(batch)
# import os
# import numpy as np
# import PIL
# import torch
# from torch.utils.data import Dataset, DataLoader, Sampler
# from torchvision import transforms
# from collections import defaultdict
# from random import shuffle, choices
#
# import random
# import tqdm
# from modules import devices, shared, images
# import re
#
# re_numbers_at_start = re.compile(r"^[-\d]+\s*")
#
#
# class DatasetEntry:
# def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
# self.filename = filename
# self.filename_text = filename_text
# self.weight = weight
# self.latent_dist = latent_dist
# self.latent_sample = latent_sample
# self.cond = cond
# self.cond_text = cond_text
# self.pixel_values = pixel_values
#
#
# class PersonalizedBase(Dataset):
# def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
# re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
#
# self.placeholder_token = placeholder_token
#
# self.flip = transforms.RandomHorizontalFlip(p=flip_p)
#
# self.dataset = []
#
# with open(template_file, "r") as file:
# lines = [x.strip() for x in file.readlines()]
#
# self.lines = lines
#
# assert data_root, 'dataset directory not specified'
# assert os.path.isdir(data_root), "Dataset directory doesn't exist"
# assert os.listdir(data_root), "Dataset directory is empty"
#
# self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
#
# self.shuffle_tags = shuffle_tags
# self.tag_drop_out = tag_drop_out
# groups = defaultdict(list)
#
# print("Preparing dataset...")
# for path in tqdm.tqdm(self.image_paths):
# alpha_channel = None
# if shared.state.interrupted:
# raise Exception("interrupted")
# try:
# image = images.read(path)
# #Currently does not work for single color transparency
# #We would need to read image.info['transparency'] for that
# if use_weight and 'A' in image.getbands():
# alpha_channel = image.getchannel('A')
# image = image.convert('RGB')
# if not varsize:
# image = image.resize((width, height), PIL.Image.BICUBIC)
# except Exception:
# continue
#
# text_filename = f"{os.path.splitext(path)[0]}.txt"
# filename = os.path.basename(path)
#
# if os.path.exists(text_filename):
# with open(text_filename, "r", encoding="utf8") as file:
# filename_text = file.read()
# else:
# filename_text = os.path.splitext(filename)[0]
# filename_text = re.sub(re_numbers_at_start, '', filename_text)
# if re_word:
# tokens = re_word.findall(filename_text)
# filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
#
# npimage = np.array(image).astype(np.uint8)
# npimage = (npimage / 127.5 - 1.0).astype(np.float32)
#
# torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
# latent_sample = None
#
# with devices.autocast():
# latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
#
# #Perform latent sampling, even for random sampling.
# #We need the sample dimensions for the weights
# if latent_sampling_method == "deterministic":
# if isinstance(latent_dist, DiagonalGaussianDistribution):
# # Works only for DiagonalGaussianDistribution
# latent_dist.std = 0
# else:
# latent_sampling_method = "once"
# latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
#
# if use_weight and alpha_channel is not None:
# channels, *latent_size = latent_sample.shape
# weight_img = alpha_channel.resize(latent_size)
# npweight = np.array(weight_img).astype(np.float32)
# #Repeat for every channel in the latent sample
# weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
#                 #Normalize the weight to a minimum of 0 and a mean of 1 so the loss stays comparable to the default.
# weight -= weight.min()
# weight /= weight.mean()
# elif use_weight:
#                 #If an image does not have an alpha channel, add a ones weight map anyway so we can stack it later
# weight = torch.ones(latent_sample.shape)
# else:
# weight = None
#
# if latent_sampling_method == "random":
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
# else:
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
#
# if not (self.tag_drop_out != 0 or self.shuffle_tags):
# entry.cond_text = self.create_text(filename_text)
#
# if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
# with devices.autocast():
# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
# groups[image.size].append(len(self.dataset))
# self.dataset.append(entry)
# del torchdata
# del latent_dist
# del latent_sample
# del weight
#
# self.length = len(self.dataset)
# self.groups = list(groups.values())
# assert self.length > 0, "No images have been found in the dataset."
# self.batch_size = min(batch_size, self.length)
# self.gradient_step = min(gradient_step, self.length // self.batch_size)
# self.latent_sampling_method = latent_sampling_method
#
# if len(groups) > 1:
# print("Buckets:")
# for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
# print(f" {w}x{h}: {len(ids)}")
# print()
#
# def create_text(self, filename_text):
# text = random.choice(self.lines)
# tags = filename_text.split(',')
# if self.tag_drop_out != 0:
# tags = [t for t in tags if random.random() > self.tag_drop_out]
# if self.shuffle_tags:
# random.shuffle(tags)
# text = text.replace("[filewords]", ','.join(tags))
# text = text.replace("[name]", self.placeholder_token)
# return text
#
# def __len__(self):
# return self.length
#
# def __getitem__(self, i):
# entry = self.dataset[i]
# if self.tag_drop_out != 0 or self.shuffle_tags:
# entry.cond_text = self.create_text(entry.filename_text)
# if self.latent_sampling_method == "random":
# entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
# return entry
#
#
# class GroupedBatchSampler(Sampler):
# def __init__(self, data_source: PersonalizedBase, batch_size: int):
# super().__init__(data_source)
#
# n = len(data_source)
# self.groups = data_source.groups
# self.len = n_batch = n // batch_size
# expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
# self.base = [int(e) // batch_size for e in expected]
# self.n_rand_batches = nrb = n_batch - sum(self.base)
# self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
# self.batch_size = batch_size
#
# def __len__(self):
# return self.len
#
# def __iter__(self):
# b = self.batch_size
#
# for g in self.groups:
# shuffle(g)
#
# batches = []
# for g in self.groups:
# batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
# for _ in range(self.n_rand_batches):
# rand_group = choices(self.groups, self.probs)[0]
# batches.append(choices(rand_group, k=b))
#
# shuffle(batches)
#
# yield from batches
#
#
# class PersonalizedDataLoader(DataLoader):
# def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
# super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
# if latent_sampling_method == "random":
# self.collate_fn = collate_wrapper_random
# else:
# self.collate_fn = collate_wrapper
#
#
# class BatchLoader:
# def __init__(self, data):
# self.cond_text = [entry.cond_text for entry in data]
# self.cond = [entry.cond for entry in data]
# self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
# if all(entry.weight is not None for entry in data):
# self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
# else:
# self.weight = None
# #self.emb_index = [entry.emb_index for entry in data]
# #print(self.latent_sample.device)
#
# def pin_memory(self):
# self.latent_sample = self.latent_sample.pin_memory()
# return self
#
# def collate_wrapper(batch):
# return BatchLoader(batch)
#
# class BatchLoaderRandom(BatchLoader):
# def __init__(self, data):
# super().__init__(data)
#
# def pin_memory(self):
# return self
#
# def collate_wrapper_random(batch):
# return BatchLoaderRandom(batch)
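#
# A minimal wiring sketch for the module above (illustrative only; paths and
# hyperparameters are placeholders, assuming a loaded shared.sd_model):
#
#     ds = PersonalizedBase(
#         data_root="train/images", width=512, height=512, repeats=1,
#         model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model,
#         device=devices.device, template_file="textual_inversion_templates/style.txt",
#         batch_size=2, latent_sampling_method="once",
#     )
#     dl = PersonalizedDataLoader(ds, latent_sampling_method="once", batch_size=2)
#     for batch in dl:
#         # batch.latent_sample: stacked latents [B, C, H/8, W/8]
#         # batch.cond_text: per-image prompts built from the template
#         break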