Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-02-11 02:19:59 +00:00)

Commit: Free WebUI from its Prison

Congratulations WebUI. Say Hello to freedom.
@@ -1,250 +0,0 @@
import os
import gc
import time

import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack, devices

cached_ldsr_model: torch.nn.Module = None


# Create LDSR Class
class LDSR:
    def load_model_from_config(self, half_attention):
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print("Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            _, extension = os.path.splitext(self.modelPath)
            if extension.lower() == ".safetensors":
                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
            else:
                pl_sd = torch.load(self.modelPath, map_location="cpu")
            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
            config = OmegaConf.load(self.yamlPath)
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model)  # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}

    def __init__(self, model_path, yaml_path):
        self.modelPath = model_path
        self.yamlPath = yaml_path

    @staticmethod
    def run(model, selected_path, custom_steps, eta):
        example = get_cond(selected_path)

        n_runs = 1
        guider = None
        ckwargs = None
        ddim_use_x0_pred = False
        temperature = 1.
        eta = eta
        custom_shape = None

        height, width = example["image"].shape[1:3]
        split_input = height >= 128 and width >= 128

        if split_input:
            ks = 128
            stride = 64
            vqf = 4  #
            model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                        "vqf": vqf,
                                        "patch_distributed_vq": True,
                                        "tie_braker": False,
                                        "clip_max_weight": 0.5,
                                        "clip_min_weight": 0.01,
                                        "clip_max_tie_weight": 0.5,
                                        "clip_min_tie_weight": 0.01}
        else:
            if hasattr(model, "split_input_params"):
                delattr(model, "split_input_params")

        x_t = None
        logs = None
        for _ in range(n_runs):
            if custom_shape is not None:
                x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])

            logs = make_convolutional_sample(example, model,
                                             custom_steps=custom_steps,
                                             eta=eta, quantize_x0=False,
                                             custom_shape=custom_shape,
                                             temperature=temperature, noise_dropout=0.,
                                             corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
                                             ddim_use_x0_pred=ddim_use_x0_pred
                                             )
        return logs

    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
        model = self.load_model_from_config(half_attention)

        # Run settings
        diffusion_steps = int(steps)
        eta = 1.0

        gc.collect()
        devices.torch_gc()

        im_og = image
        width_og, height_og = im_og.size
        # If we can adjust the max upscale size, then the 4 below should be our variable
        down_sample_rate = target_scale / 4
        wd = width_og * down_sample_rate
        hd = height_og * down_sample_rate
        width_downsampled_pre = int(np.ceil(wd))
        height_downsampled_pre = int(np.ceil(hd))

        if down_sample_rate != 1:
            print(
                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
        else:
            print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")

        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))

        logs = self.run(model["model"], im_padded, diffusion_steps, eta)

        sample = logs["sample"]
        sample = sample.detach().cpu()
        sample = torch.clamp(sample, -1., 1.)
        sample = (sample + 1.) / 2. * 255
        sample = sample.numpy().astype(np.uint8)
        sample = np.transpose(sample, (0, 2, 3, 1))
        a = Image.fromarray(sample[0])

        # remove padding
        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))

        del model
        gc.collect()
        devices.torch_gc()

        return a


def get_cond(selected_path):
    example = {}
    up_f = 4
    c = selected_path.convert('RGB')
    c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
    c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
                                                    antialias=True)
    c_up = rearrange(c_up, '1 c h w -> 1 h w c')
    c = rearrange(c, '1 c h w -> 1 h w c')
    c = 2. * c - 1.

    c = c.to(shared.device)
    example["LR_image"] = c
    example["image"] = c_up

    return example


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
                    corrector_kwargs=None, x_t=None
                    ):
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]
    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
                                         mask=mask, x0=x0, temperature=temperature, verbose=False,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs, x_t=x_t)

    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
    log = {}

    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    z0 = None

    log["input"] = x
    log["reconstruction"] = xrec

    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            log[model.cond_stage_key] = model.to_rgb(xc)

    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == 'class_label':
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()

        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, mask=None, x0=z0,
                                                temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
                                                x_t=x_T)
        t1 = time.time()

    if ddim_use_x0_pred:
        sample = intermediates['pred_x0'][-1]

    x_sample = model.decode_first_stage(sample)

    try:
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0

    return log
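A minimal usage sketch of the class above, mirroring what the upscaler script further down does; the model and YAML paths here are placeholders, not files shipped with this commit:

from PIL import Image

ldsr = LDSR("models/LDSR/model.ckpt", "models/LDSR/project.yaml")  # hypothetical paths
result = ldsr.super_resolution(Image.open("input.png"), steps=100, target_scale=2)
result.save("upscaled.png")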
@@ -1,6 +0,0 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
@@ -1,70 +0,0 @@
import os

from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from modules_forge.utils import prepare_free_memory
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
import sd_hijack_autoencoder  # noqa: F401
import sd_hijack_ddpm_v1  # noqa: F401


class UpscalerLDSR(Upscaler):
    def __init__(self, user_path):
        self.name = "LDSR"
        self.user_path = user_path
        self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
        self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
        super().__init__()
        scaler_data = UpscalerData("LDSR", None, self)
        self.scalers = [scaler_data]

    def load_model(self, path: str):
        # Remove incorrect project.yaml file if too big
        yaml_path = os.path.join(self.model_path, "project.yaml")
        old_model_path = os.path.join(self.model_path, "model.pth")
        new_model_path = os.path.join(self.model_path, "model.ckpt")

        local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
        local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None)
        local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None)
        local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None)

        if os.path.exists(yaml_path):
            statinfo = os.stat(yaml_path)
            if statinfo.st_size >= 10485760:
                print("Removing invalid LDSR YAML file.")
                os.remove(yaml_path)

        if os.path.exists(old_model_path):
            print("Renaming model from model.pth to model.ckpt")
            os.rename(old_model_path, new_model_path)

        if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
            model = local_safetensors_path
        else:
            model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")

        yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")

        return LDSR(model, yaml)

    def do_upscale(self, img, path):
        prepare_free_memory(aggressive=True)
        try:
            ldsr = self.load_model(path)
        except Exception:
            errors.report(f"Failed loading LDSR model {path}", exc_info=True)
            return img
        ddim_steps = shared.opts.ldsr_steps
        return ldsr.super_resolution(img, ddim_steps, self.scale)


def on_ui_settings():
    import gradio as gr

    shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
    shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))


script_callbacks.on_ui_settings(on_ui_settings)
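A sketch of driving the registered upscaler programmatically; it assumes the Upscaler base class sets up model_path/model_download_path as used above, and the file path and scale are illustrative only:

from PIL import Image

upscaler = UpscalerLDSR(user_path="models/LDSR")   # hypothetical location
upscaler.scale = 2                                 # read back by do_upscale via self.scale
result = upscaler.do_upscale(Image.open("low_res.png"), path=upscaler.user_path)
print(result.size)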
@@ -1,293 +0,0 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager

from torch.optim.lr_scheduler import LambdaLR

from ldm.modules.ema import LitEma
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config

import ldm.models.autoencoder
from packaging import version


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys or []:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if missing:
            print(f"Missing Keys: {missing}")
        if unexpected:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = {}
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


ldm.models.autoencoder.VQModel = VQModel
ldm.models.autoencoder.VQModelInterface = VQModelInterface
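The two module-level assignments at the end are the whole hijack: once this file has been imported, anything that later resolves VQModel or VQModelInterface through ldm.models.autoencoder gets the classes defined here. A small sketch of the effect, assuming this module is imported before the LDSR model config is instantiated:

import ldm.models.autoencoder as ae

# Both attribute lookup and a fresh `from ... import ...` now see the re-added classes.
assert ae.VQModel is VQModel
from ldm.models.autoencoder import VQModelInterface as Restored
assert Restored is VQModelInterface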
File diff suppressed because it is too large
@@ -1,147 +0,0 @@
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
# where the license is as follows:
#
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE./

import torch
import torch.nn as nn
import numpy as np
from einops import rearrange


class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
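The nearest-codebook search in forward expands ||z - e||^2 = ||z||^2 + ||e||^2 - 2*z.e, and the "preserve gradients" line is the usual straight-through estimator. A tiny self-contained sketch of both steps with toy sizes (not the real configuration):

import torch

z = torch.randn(5, 3, requires_grad=True)   # 5 flattened latents, e_dim = 3
codebook = torch.randn(7, 3)                # n_e = 7 codebook entries
d = (z ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * z @ codebook.t()
z_q = codebook[d.argmin(dim=1)]             # nearest codebook vector per latent
z_q = z + (z_q - z).detach()                # forward value is quantized, gradient flows to z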
@@ -11,9 +11,12 @@ from transformers.models.clip.modeling_clip import CLIPVisionModelOutput
from annotator.util import HWC3
from typing import Callable, Tuple, Union

from modules.safe import Extra
from modules import devices

import contextlib

Extra = lambda x: contextlib.nullcontext()


def torch_handler(module: str, name: str):
    """ Allow all torch access. Bypass A1111 safety whitelist. """

@@ -19,7 +19,6 @@ import cv2
import logging

from typing import Any, Callable, Dict, List
from modules.safe import unsafe_torch_load
from lib_controlnet.logging import logger


@@ -28,7 +27,7 @@ def load_state_dict(ckpt_path, location="cpu"):
    if extension.lower() == ".safetensors":
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = unsafe_torch_load(ckpt_path, map_location=torch.device(location))
        state_dict = torch.load(ckpt_path, map_location=torch.device(location))
    state_dict = get_state_dict(state_dict)
    logger.info(f"Loaded state_dict from [{ckpt_path}]")
    return state_dict

@@ -24,7 +24,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import Any
@@ -725,7 +724,7 @@ class Api:

    def get_sd_models(self):
        import modules.sd_models as sd_models
        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()]

    def get_sd_vaes(self):
        import modules.sd_vae as sd_vae

@@ -9,7 +9,7 @@ import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from backend.nn.unet import default
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, saving_settings
from modules.textual_inversion.learn_schedule import LearnRateScheduler

@@ -1,25 +1,12 @@
import importlib
import logging
import os
import sys
import warnings
import os

from threading import Thread

from modules.timer import startup_timer


class HiddenPrints:
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout


def imports():
    logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...
    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -35,16 +22,8 @@ def imports():
    import gradio  # noqa: F401
    startup_timer.record("import gradio")

    with HiddenPrints():
        from modules import paths, timer, import_hook, errors  # noqa: F401
        startup_timer.record("setup paths")

        import ldm.modules.encoders.modules  # noqa: F401
        import ldm.modules.diffusionmodules.model
        startup_timer.record("import ldm")

        import sgm.modules.encoders.modules  # noqa: F401
        startup_timer.record("import sgm")
    from modules import paths, timer, import_hook, errors  # noqa: F401
    startup_timer.record("setup paths")

    from modules import shared_init
    shared_init.initialize()
@@ -137,15 +116,6 @@ def initialize_rest(*, reload_script_modules=False):
    sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")

    from modules import textual_inversion
    textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
    script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
    sd_hijack.list_optimizers()
    startup_timer.record("scripts list_optimizers")

    from modules import sd_unet
    sd_unet.list_unets()
    startup_timer.record("scripts list_unets")

@@ -391,15 +391,15 @@ def prepare_environment():
    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

    assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
    # stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
    # stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
    k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
    huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
    blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

    assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
    # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
    # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -456,8 +456,8 @@ def prepare_environment():
    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

    git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
    git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
    git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
    # git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
    # git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
    git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
    git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
    git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

@@ -2,45 +2,15 @@ import os
import sys
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd  # noqa: F401

import modules.safe  # noqa: F401


def mute_sdxl_imports():
    """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""

    class Dummy:
        pass

    module = Dummy()
    module.LPIPS = None
    sys.modules['taming.modules.losses.lpips'] = module

    module = Dummy()
    module.StableDataModuleFromConfig = None
    sys.modules['sgm.data'] = module


# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)

# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
    if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
        sd_path = os.path.abspath(possible_sd_path)
        break

assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"

mute_sdxl_imports()
sd_path = os.path.dirname(__file__)

path_dirs = [
    (sd_path, 'ldm', 'Stable Diffusion', []),
    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
    (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
    (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
    (os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
    (os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []),
    (os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
    (os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
]

paths = {}
@@ -53,13 +23,6 @@ for d, must_exist, what, options in path_dirs:
    d = os.path.abspath(d)
    if "atstart" in options:
        sys.path.insert(0, d)
    elif "sgm" in options:
        # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
        # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.

        sys.path.insert(0, d)
        import sgm  # noqa: F401
        sys.path.pop(0)
    else:
        sys.path.append(d)
    paths[what] = d

@@ -28,8 +28,6 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion

from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
@@ -295,23 +293,7 @@ class StableDiffusionProcessing:
        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning
        raise NotImplementedError('NotImplementedError: depth2img_image_conditioning')

    def edit_image_conditioning(self, source_image):
        conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
@@ -368,11 +350,6 @@ class StableDiffusionProcessing:
    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

modules/safe.py
@@ -1,195 +1,195 @@
# this code is adapted from the script contributed by anon from /h/

import pickle
import collections

import torch
import numpy
import _codecs
import zipfile
import re


# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors

TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage

def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None

    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'

        try:
            return TypedStorage(_internal=True)
        except TypeError:
            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument

    def find_class(self, module, name):
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")


# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")

def check_zip_filenames(filename, names):
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for _ in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            "-----> !!!! The file is most likely corrupted !!!! <-----\n"
            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
            exc_info=True,
        )
        return None
    except Exception:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            f"The file may be malicious, so the program is not going to read it.\n"
            f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
            exc_info=True,
        )
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None


unsafe_torch_load = torch.load
global_extra_handler = None
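RestrictedUnpickler only resolves globals on the explicit whitelist in find_class; anything else raises. A short standalone sketch of that behaviour, using only classes defined in this file:

import io
import pickle
import collections

safe_bytes = pickle.dumps(collections.OrderedDict(a=1))
print(RestrictedUnpickler(io.BytesIO(safe_bytes)).load())   # OrderedDict is whitelisted

evil_bytes = pickle.dumps(print)                            # any non-whitelisted global
try:
    RestrictedUnpickler(io.BytesIO(evil_bytes)).load()
except Exception as e:
    print(e)                                                # global 'builtins/print' is forbidden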
# # this code is adapted from the script contributed by anon from /h/
|
||||
#
|
||||
# import pickle
|
||||
# import collections
|
||||
#
|
||||
# import torch
|
||||
# import numpy
|
||||
# import _codecs
|
||||
# import zipfile
|
||||
# import re
|
||||
#
|
||||
#
|
||||
# # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||
# from modules import errors
|
||||
#
|
||||
# TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||
#
|
||||
# def encode(*args):
|
||||
# out = _codecs.encode(*args)
|
||||
# return out
|
||||
#
|
||||
#
|
||||
# class RestrictedUnpickler(pickle.Unpickler):
|
||||
# extra_handler = None
|
||||
#
|
||||
# def persistent_load(self, saved_id):
|
||||
# assert saved_id[0] == 'storage'
|
||||
#
|
||||
# try:
|
||||
# return TypedStorage(_internal=True)
|
||||
# except TypeError:
|
||||
# return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
|
||||
#
|
||||
# def find_class(self, module, name):
|
||||
# if self.extra_handler is not None:
|
||||
# res = self.extra_handler(module, name)
|
||||
# if res is not None:
|
||||
# return res
|
||||
#
|
||||
# if module == 'collections' and name == 'OrderedDict':
|
||||
# return getattr(collections, name)
|
||||
# if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
||||
# return getattr(torch._utils, name)
|
||||
# if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
|
||||
# return getattr(torch, name)
|
||||
# if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||
# return getattr(torch.nn.modules.container, name)
|
||||
# if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
|
||||
# return getattr(numpy.core.multiarray, name)
|
||||
# if module == 'numpy' and name in ['dtype', 'ndarray']:
|
||||
# return getattr(numpy, name)
|
||||
# if module == '_codecs' and name == 'encode':
|
||||
# return encode
|
||||
# if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
# import pytorch_lightning.callbacks
# return pytorch_lightning.callbacks.model_checkpoint
# if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
# import pytorch_lightning.callbacks.model_checkpoint
# return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
# if module == "__builtin__" and name == 'set':
# return set
#
# # Forbid everything else.
# raise Exception(f"global '{module}/{name}' is forbidden")
#
#
# # Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
# allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
# data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
#
# def check_zip_filenames(filename, names):
# for name in names:
# if allowed_zip_names_re.match(name):
# continue
#
# raise Exception(f"bad file inside {filename}: {name}")
#
#
# def check_pt(filename, extra_handler):
# try:
#
# # new pytorch format is a zip file
# with zipfile.ZipFile(filename) as z:
# check_zip_filenames(filename, z.namelist())
#
# # find filename of data.pkl in zip file: '<directory name>/data.pkl'
# data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
# if len(data_pkl_filenames) == 0:
# raise Exception(f"data.pkl not found in {filename}")
# if len(data_pkl_filenames) > 1:
# raise Exception(f"Multiple data.pkl found in {filename}")
# with z.open(data_pkl_filenames[0]) as file:
# unpickler = RestrictedUnpickler(file)
# unpickler.extra_handler = extra_handler
# unpickler.load()
#
# except zipfile.BadZipfile:
#
# # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
# with open(filename, "rb") as file:
# unpickler = RestrictedUnpickler(file)
# unpickler.extra_handler = extra_handler
# for _ in range(5):
# unpickler.load()
#
#
# def load(filename, *args, **kwargs):
# return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
#
#
# def load_with_extra(filename, extra_handler=None, *args, **kwargs):
# """
# this function is intended to be used by extensions that want to load models with
# some extra classes in them that the usual unpickler would find suspicious.
#
# Use the extra_handler argument to specify a function that takes module and field name as text,
# and returns that field's value:
#
# ```python
# def extra(module, name):
# if module == 'collections' and name == 'OrderedDict':
# return collections.OrderedDict
#
# return None
#
# safe.load_with_extra('model.pt', extra_handler=extra)
# ```
#
# The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
# definitely unsafe.
# """
#
# from modules import shared
#
# try:
# if not shared.cmd_opts.disable_safe_unpickle:
# check_pt(filename, extra_handler)
#
# except pickle.UnpicklingError:
# errors.report(
# f"Error verifying pickled file from {filename}\n"
# "-----> !!!! The file is most likely corrupted !!!! <-----\n"
# "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
# exc_info=True,
# )
# return None
# except Exception:
# errors.report(
# f"Error verifying pickled file from {filename}\n"
# f"The file may be malicious, so the program is not going to read it.\n"
# f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
# exc_info=True,
# )
# return None
#
# return unsafe_torch_load(filename, *args, **kwargs)
#
#
# class Extra:
# """
# A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
# (because it's not your code making the torch.load call). The intended use is like this:
#
# ```
# import torch
# from modules import safe
#
# def handler(module, name):
# if module == 'torch' and name in ['float64', 'float16']:
# return getattr(torch, name)
#
# return None
#
# with safe.Extra(handler):
# x = torch.load('model.pt')
# ```
# """
#
# def __init__(self, handler):
# self.handler = handler
#
# def __enter__(self):
# global global_extra_handler
#
# assert global_extra_handler is None, 'already inside an Extra() block'
# global_extra_handler = self.handler
#
# def __exit__(self, exc_type, exc_val, exc_tb):
# global global_extra_handler
#
# global_extra_handler = None
#
#
# unsafe_torch_load = torch.load
# global_extra_handler = None

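# The commented-out module above guarded torch.load with a restricted unpickler that
# whitelists the globals a checkpoint may reference. A minimal sketch of the same idea,
# using only the standard pickle API (the names below are illustrative, not the project's):
import collections
import io
import pickle


class WhitelistUnpickler(pickle.Unpickler):
    ALLOWED = {
        ("collections", "OrderedDict"): collections.OrderedDict,
    }

    def find_class(self, module, name):
        # Resolve only whitelisted globals; refuse anything else a pickle might smuggle in.
        try:
            return self.ALLOWED[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") from None


def demo_whitelist_unpickler():
    payload = pickle.dumps(collections.OrderedDict(a=1))
    return WhitelistUnpickler(io.BytesIO(payload)).load()  # OrderedDict([('a', 1)])
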
@@ -1,232 +1,232 @@
import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub

from modules import shared


class ReplaceHelper:
    def __init__(self):
        self.replaced = []

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def restore(self):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)

        self.replaced.clear()

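# ReplaceHelper is a record-and-restore monkeypatching utility: replace() swaps an attribute
# and remembers the original, restore() puts everything back. A hedged usage sketch
# (math.sqrt is only a stand-in target for illustration):
def _replace_helper_example():
    import math

    helper = ReplaceHelper()
    helper.replace(math, "sqrt", lambda x: 0.0)   # math.sqrt(4.0) now returns 0.0
    try:
        assert math.sqrt(4.0) == 0.0
    finally:
        helper.restore()                          # math.sqrt(4.0) returns 2.0 again
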
class DisableInitialization(ReplaceHelper):
    """
    When an object of this class enters a `with` block, it:
    - prevents torch's layer initialization functions from working
    - stops CLIP and OpenCLIP from downloading model weights
    - stops CLIP from making requests to check if there is a new version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __init__(self, disable_clip=True):
        super().__init__()
        self.disable_clip = disable_clip

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
            res.name_or_path = pretrained_model_name_or_path
            return res

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making request
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
                return None

            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()

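# The transformers_utils_hub_get_file_from_cache wrapper above boils down to a
# "local files first" policy: resolve from the local cache when possible and only fall back
# to a network fetch when that fails. The generic pattern, independent of the transformers
# internals it patches (fetch_fn is a stand-in callable that accepts local_files_only):
def _local_files_first(fetch_fn, *args, **kwargs):
    try:
        res = fetch_fn(*args, local_files_only=True, **kwargs)
        if res is not None:
            return res
    except Exception:
        pass
    return fetch_fn(*args, local_files_only=False, **kwargs)
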
class InitializeOnMeta(ReplaceHelper):
    """
    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
    which results in those parameters having no values and taking no memory. model.to() will be broken and
    will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.

    Usage:
    ```
    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)
    ```
    """

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        def set_device(x):
            x["device"] = "meta"
            return x

        linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
        conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
        mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()

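# InitializeOnMeta leans on PyTorch's meta device: constructing a layer with device="meta"
# records shapes and dtypes but allocates no storage, so a multi-gigabyte model can be
# "built" instantly and filled in later from a state dict. In plain PyTorch:
def _meta_device_example():
    layer = torch.nn.Linear(4096, 4096, device="meta")
    assert layer.weight.is_meta                    # no real storage behind the parameter
    assert layer.weight.shape == (4096, 4096)
    # .to() cannot materialize values for meta parameters, which is why the context manager
    # above also stubs torch.nn.Module.to while it is active.
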
class LoadStateDictOnMeta(ReplaceHelper):
    """
    Context manager that allows reading parameters from a state_dict into a model that has some of its parameters on the meta device.
    As those parameters are read from the state_dict, they are deleted from it, so by the end the state_dict is mostly empty, to save memory.
    Meant to be used together with InitializeOnMeta above.

    Usage:
    ```
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
        model.load_state_dict(state_dict, strict=False)
    ```
    """

    def __init__(self, state_dict, device, weight_dtype_conversion=None):
        super().__init__()
        self.state_dict = state_dict
        self.device = device
        self.weight_dtype_conversion = weight_dtype_conversion or {}
        self.default_dtype = self.weight_dtype_conversion.get('')

    def get_weight_dtype(self, key):
        key_first_term, _ = key.split('.', 1)
        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
            used_param_keys = []

            for name, param in module._parameters.items():
                if param is None:
                    continue

                key = prefix + name
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                    used_param_keys.append(key)

                if param.is_meta:
                    dtype = sd_param.dtype if sd_param is not None else param.dtype
                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

            for name in module._buffers:
                key = prefix + name

                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param
                    used_param_keys.append(key)

            original(module, state_dict, prefix, *args, **kwargs)

            for key in used_param_keys:
                state_dict.pop(key, None)

        def load_state_dict(original, module, state_dict, strict=True):
            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
            all weights on the meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.

            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).

            The dangerous thing about this is that if _load_from_state_dict is not called (if some exotic module overloads
            the function and does not call the original), the state dict will simply fail to load because the weights
            would be on the meta device.
            """

            if state_dict is sd:
                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

            original(module, state_dict, strict=strict)

        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
        group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
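

# The core trick in load_from_state_dict above is to materialize each meta parameter right
# before torch copies the real weight into it. Per parameter it is roughly this (a simplified
# sketch, not the exact code path used by the class):
def _materialize_meta_param(module, name, weight, device="cpu"):
    param = module._parameters[name]
    if param is None:
        return
    if param.is_meta:
        module._parameters[name] = torch.nn.parameter.Parameter(
            torch.zeros_like(param, device=device, dtype=weight.dtype),
            requires_grad=param.requires_grad,
        )
    with torch.no_grad():
        module._parameters[name].copy_(weight)
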
# import ldm.modules.encoders.modules
|
||||
# import open_clip
|
||||
# import torch
|
||||
# import transformers.utils.hub
|
||||
#
|
||||
# from modules import shared
|
||||
#
|
||||
#
|
||||
# class ReplaceHelper:
|
||||
# def __init__(self):
|
||||
# self.replaced = []
|
||||
#
|
||||
# def replace(self, obj, field, func):
|
||||
# original = getattr(obj, field, None)
|
||||
# if original is None:
|
||||
# return None
|
||||
#
|
||||
# self.replaced.append((obj, field, original))
|
||||
# setattr(obj, field, func)
|
||||
#
|
||||
# return original
|
||||
#
|
||||
# def restore(self):
|
||||
# for obj, field, original in self.replaced:
|
||||
# setattr(obj, field, original)
|
||||
#
|
||||
# self.replaced.clear()
|
||||
#
|
||||
#
|
||||
# class DisableInitialization(ReplaceHelper):
|
||||
# """
|
||||
# When an object of this class enters a `with` block, it starts:
|
||||
# - preventing torch's layer initialization functions from working
|
||||
# - changes CLIP and OpenCLIP to not download model weights
|
||||
# - changes CLIP to not make requests to check if there is a new version of a file you already have
|
||||
#
|
||||
# When it leaves the block, it reverts everything to how it was before.
|
||||
#
|
||||
# Use it like this:
|
||||
# ```
|
||||
# with DisableInitialization():
|
||||
# do_things()
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __init__(self, disable_clip=True):
|
||||
# super().__init__()
|
||||
# self.disable_clip = disable_clip
|
||||
#
|
||||
# def replace(self, obj, field, func):
|
||||
# original = getattr(obj, field, None)
|
||||
# if original is None:
|
||||
# return None
|
||||
#
|
||||
# self.replaced.append((obj, field, original))
|
||||
# setattr(obj, field, func)
|
||||
#
|
||||
# return original
|
||||
#
|
||||
# def __enter__(self):
|
||||
# def do_nothing(*args, **kwargs):
|
||||
# pass
|
||||
#
|
||||
# def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
|
||||
# return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
|
||||
#
|
||||
# def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
||||
# res.name_or_path = pretrained_model_name_or_path
|
||||
# return res
|
||||
#
|
||||
# def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
|
||||
# args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
|
||||
# return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
|
||||
#
|
||||
# def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
||||
#
|
||||
# # this file is always 404, prevent making request
|
||||
# if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
||||
# return None
|
||||
#
|
||||
# try:
|
||||
# res = original(url, *args, local_files_only=True, **kwargs)
|
||||
# if res is None:
|
||||
# res = original(url, *args, local_files_only=False, **kwargs)
|
||||
# return res
|
||||
# except Exception:
|
||||
# return original(url, *args, local_files_only=False, **kwargs)
|
||||
#
|
||||
# def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
|
||||
#
|
||||
# def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
|
||||
#
|
||||
# def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
||||
#
|
||||
# self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
# self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
# self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
#
|
||||
# if self.disable_clip:
|
||||
# self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
# self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
# self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
# self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
# self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
# self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
#
|
||||
#
|
||||
# class InitializeOnMeta(ReplaceHelper):
|
||||
# """
|
||||
# Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
||||
# which results in those parameters having no values and taking no memory. model.to() will be broken and
|
||||
# will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
||||
#
|
||||
# Usage:
|
||||
# ```
|
||||
# with sd_disable_initialization.InitializeOnMeta():
|
||||
# sd_model = instantiate_from_config(sd_config.model)
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __enter__(self):
|
||||
# if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
# return
|
||||
#
|
||||
# def set_device(x):
|
||||
# x["device"] = "meta"
|
||||
# return x
|
||||
#
|
||||
# linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
||||
# conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
||||
# mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
||||
# self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
#
|
||||
#
|
||||
# class LoadStateDictOnMeta(ReplaceHelper):
|
||||
# """
|
||||
# Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
|
||||
# As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
||||
# Meant to be used together with InitializeOnMeta above.
|
||||
#
|
||||
# Usage:
|
||||
# ```
|
||||
# with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
||||
# model.load_state_dict(state_dict, strict=False)
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __init__(self, state_dict, device, weight_dtype_conversion=None):
|
||||
# super().__init__()
|
||||
# self.state_dict = state_dict
|
||||
# self.device = device
|
||||
# self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||
# self.default_dtype = self.weight_dtype_conversion.get('')
|
||||
#
|
||||
# def get_weight_dtype(self, key):
|
||||
# key_first_term, _ = key.split('.', 1)
|
||||
# return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
|
||||
#
|
||||
# def __enter__(self):
|
||||
# if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
# return
|
||||
#
|
||||
# sd = self.state_dict
|
||||
# device = self.device
|
||||
#
|
||||
# def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
||||
# used_param_keys = []
|
||||
#
|
||||
# for name, param in module._parameters.items():
|
||||
# if param is None:
|
||||
# continue
|
||||
#
|
||||
# key = prefix + name
|
||||
# sd_param = sd.pop(key, None)
|
||||
# if sd_param is not None:
|
||||
# state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||
# used_param_keys.append(key)
|
||||
#
|
||||
# if param.is_meta:
|
||||
# dtype = sd_param.dtype if sd_param is not None else param.dtype
|
||||
# module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||
#
|
||||
# for name in module._buffers:
|
||||
# key = prefix + name
|
||||
#
|
||||
# sd_param = sd.pop(key, None)
|
||||
# if sd_param is not None:
|
||||
# state_dict[key] = sd_param
|
||||
# used_param_keys.append(key)
|
||||
#
|
||||
# original(module, state_dict, prefix, *args, **kwargs)
|
||||
#
|
||||
# for key in used_param_keys:
|
||||
# state_dict.pop(key, None)
|
||||
#
|
||||
# def load_state_dict(original, module, state_dict, strict=True):
|
||||
# """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
|
||||
# because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
|
||||
# all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
|
||||
#
|
||||
# In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
|
||||
#
|
||||
# The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
|
||||
# the function and does not call the original) the state dict will just fail to load because weights
|
||||
# would be on the meta device.
|
||||
# """
|
||||
#
|
||||
# if state_dict is sd:
|
||||
# state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||
#
|
||||
# original(module, state_dict, strict=strict)
|
||||
#
|
||||
# module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
|
||||
# module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
|
||||
# linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||
# conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||
# mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||
# layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
|
||||
# group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
|
||||
@@ -1,124 +1,3 @@
import torch
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None

ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)

sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)


def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)


def apply_optimizations(option=None):
    return


def undo_optimizations():
    return


def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""

    pass


def weighted_loss(sd_model, pred, target, mean=True):
    # Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    # Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    # Return the loss, as mean if specified
    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        # Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        # Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        # Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            # Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss


def apply_weighted_forward(sd_model):
    # Add new function 'weighted_forward' that can be called to calc weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass

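# weighted_forward lets training code scale the diffusion loss per element (for example to
# emphasize masked regions) without reimplementing forward(): it temporarily swaps get_loss
# for the weight-aware version above. A hedged usage sketch; sd_model, the latents x, the
# conditioning c and the weight map w are assumed to come from the caller:
def _weighted_forward_example(sd_model, x, c, w):
    apply_weighted_forward(sd_model)
    try:
        loss = sd_model.weighted_forward(x, c, w)   # w scales the un-reduced loss before the mean
        loss.backward()
    finally:
        undo_weighted_forward(sd_model)
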
class StableDiffusionModelHijack:
    fixes = None
    layers = None
@@ -156,74 +35,234 @@ class StableDiffusionModelHijack:
        pass


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key
        self.weight = self.wrapped.weight

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
                emb = devices.cond_cast_unet(vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)

            vecs.append(tensor)

        return torch.stack(vecs)


class TextualInversionEmbeddings(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)

        self.embeddings = model_hijack
        self.textual_inversion_key = textual_inversion_key

    @property
    def wrapped(self):
        return super().forward

    def forward(self, input_ids):
        return EmbeddingsWithFixes.forward(self, input_ids)


def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular

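# add_circular_option_to_conv_2d forces every newly constructed Conv2d to pad in
# padding_mode='circular', which is what makes seamlessly tileable images possible:
# convolution borders wrap around instead of being zero-padded. In plain PyTorch:
def _circular_padding_example():
    x = torch.arange(16.0).reshape(1, 1, 4, 4)
    conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode='circular', bias=False)
    # With circular padding the left edge "sees" the right edge, so tiling the output
    # produces no visible seam.
    return conv(x)   # shape (1, 1, 4, 4)
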
model_hijack = StableDiffusionModelHijack()


def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)


ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
# import torch
|
||||
# from torch.nn.functional import silu
|
||||
# from types import MethodType
|
||||
#
|
||||
# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||
# from modules.hypernetworks import hypernetwork
|
||||
# from modules.shared import cmd_opts
|
||||
# from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||
#
|
||||
# import ldm.modules.attention
|
||||
# import ldm.modules.diffusionmodules.model
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
# import ldm.models.diffusion.ddpm
|
||||
# import ldm.models.diffusion.ddim
|
||||
# import ldm.models.diffusion.plms
|
||||
# import ldm.modules.encoders.modules
|
||||
#
|
||||
# import sgm.modules.attention
|
||||
# import sgm.modules.diffusionmodules.model
|
||||
# import sgm.modules.diffusionmodules.openaimodel
|
||||
# import sgm.modules.encoders.modules
|
||||
#
|
||||
# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
#
|
||||
# # new memory efficient cross attention blocks do not support hypernets and we already
|
||||
# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||
# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||
# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
#
|
||||
# # silence new console spam from SD2
|
||||
# ldm.modules.attention.print = shared.ldm_print
|
||||
# ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
# ldm.util.print = shared.ldm_print
|
||||
# ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||
#
|
||||
# optimizers = []
|
||||
# current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
#
|
||||
# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||
# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
||||
#
|
||||
# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||
# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
||||
#
|
||||
#
|
||||
# def list_optimizers():
|
||||
# new_optimizers = script_callbacks.list_optimizers_callback()
|
||||
#
|
||||
# new_optimizers = [x for x in new_optimizers if x.is_available()]
|
||||
#
|
||||
# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
|
||||
#
|
||||
# optimizers.clear()
|
||||
# optimizers.extend(new_optimizers)
|
||||
#
|
||||
#
|
||||
# def apply_optimizations(option=None):
|
||||
# return
|
||||
#
|
||||
#
|
||||
# def undo_optimizations():
|
||||
# return
|
||||
#
|
||||
#
|
||||
# def fix_checkpoint():
|
||||
# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||
# checkpoints to be added when not training (there's a warning)"""
|
||||
#
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def weighted_loss(sd_model, pred, target, mean=True):
|
||||
# #Calculate the weight normally, but ignore the mean
|
||||
# loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||
#
|
||||
# #Check if we have weights available
|
||||
# weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||
# if weight is not None:
|
||||
# loss *= weight
|
||||
#
|
||||
# #Return the loss, as mean if specified
|
||||
# return loss.mean() if mean else loss
|
||||
#
|
||||
# def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
||||
# try:
|
||||
# #Temporarily append weights to a place accessible during loss calc
|
||||
# sd_model._custom_loss_weight = w
|
||||
#
|
||||
# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||
# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||
# if not hasattr(sd_model, '_old_get_loss'):
|
||||
# sd_model._old_get_loss = sd_model.get_loss
|
||||
# sd_model.get_loss = MethodType(weighted_loss, sd_model)
|
||||
#
|
||||
# #Run the standard forward function, but with the patched 'get_loss'
|
||||
# return sd_model.forward(x, c, *args, **kwargs)
|
||||
# finally:
|
||||
# try:
|
||||
# #Delete temporary weights if appended
|
||||
# del sd_model._custom_loss_weight
|
||||
# except AttributeError:
|
||||
# pass
|
||||
#
|
||||
# #If we have an old loss function, reset the loss function to the original one
|
||||
# if hasattr(sd_model, '_old_get_loss'):
|
||||
# sd_model.get_loss = sd_model._old_get_loss
|
||||
# del sd_model._old_get_loss
|
||||
#
|
||||
# def apply_weighted_forward(sd_model):
|
||||
# #Add new function 'weighted_forward' that can be called to calc weighted loss
|
||||
# sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
|
||||
#
|
||||
# def undo_weighted_forward(sd_model):
|
||||
# try:
|
||||
# del sd_model.weighted_forward
|
||||
# except AttributeError:
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# class StableDiffusionModelHijack:
|
||||
# fixes = None
|
||||
# layers = None
|
||||
# circular_enabled = False
|
||||
# clip = None
|
||||
# optimization_method = None
|
||||
#
|
||||
# def __init__(self):
|
||||
# self.extra_generation_params = {}
|
||||
# self.comments = []
|
||||
#
|
||||
# def apply_optimizations(self, option=None):
|
||||
# pass
|
||||
#
|
||||
# def convert_sdxl_to_ssd(self, m):
|
||||
# pass
|
||||
#
|
||||
# def hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
# def undo_hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
# def apply_circular(self, enable):
|
||||
# pass
|
||||
#
|
||||
# def clear_comments(self):
|
||||
# self.comments = []
|
||||
# self.extra_generation_params = {}
|
||||
#
|
||||
# def get_prompt_lengths(self, text, cond_stage_model):
|
||||
# pass
|
||||
#
|
||||
# def redo_hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# class EmbeddingsWithFixes(torch.nn.Module):
|
||||
# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||
# super().__init__()
|
||||
# self.wrapped = wrapped
|
||||
# self.embeddings = embeddings
|
||||
# self.textual_inversion_key = textual_inversion_key
|
||||
# self.weight = self.wrapped.weight
|
||||
#
|
||||
# def forward(self, input_ids):
|
||||
# batch_fixes = self.embeddings.fixes
|
||||
# self.embeddings.fixes = None
|
||||
#
|
||||
# inputs_embeds = self.wrapped(input_ids)
|
||||
#
|
||||
# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||
# return inputs_embeds
|
||||
#
|
||||
# vecs = []
|
||||
# for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
# for offset, embedding in fixes:
|
||||
# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||
# emb = devices.cond_cast_unet(vec)
|
||||
# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
||||
#
|
||||
# vecs.append(tensor)
|
||||
#
|
||||
# return torch.stack(vecs)
|
||||
#
|
||||
#
|
||||
# class TextualInversionEmbeddings(torch.nn.Embedding):
|
||||
# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
|
||||
# super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||
#
|
||||
# self.embeddings = model_hijack
|
||||
# self.textual_inversion_key = textual_inversion_key
|
||||
#
|
||||
# @property
|
||||
# def wrapped(self):
|
||||
# return super().forward
|
||||
#
|
||||
# def forward(self, input_ids):
|
||||
# return EmbeddingsWithFixes.forward(self, input_ids)
|
||||
#
|
||||
#
|
||||
# def add_circular_option_to_conv_2d():
|
||||
# conv2d_constructor = torch.nn.Conv2d.__init__
|
||||
#
|
||||
# def conv2d_constructor_circular(self, *args, **kwargs):
|
||||
# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
||||
#
|
||||
# torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# def register_buffer(self, name, attr):
|
||||
# """
|
||||
# Fix register buffer bug for Mac OS.
|
||||
# """
|
||||
#
|
||||
# if type(attr) == torch.Tensor:
|
||||
# if attr.device != devices.device:
|
||||
# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
||||
#
|
||||
# setattr(self, name, attr)
|
||||
#
|
||||
#
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||
# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||
|
||||
@@ -1,46 +1,46 @@
from torch.utils.checkpoint import checkpoint

import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel


def BasicTransformerBlock_forward(self, x, context=None):
    return checkpoint(self._forward, x, context)


def AttentionBlock_forward(self, x):
    return checkpoint(self._forward, x)


def ResBlock_forward(self, x, emb):
    return checkpoint(self._forward, x, emb)


stored = []


def add():
    if len(stored) != 0:
        return

    stored.extend([
        ldm.modules.attention.BasicTransformerBlock.forward,
        ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
        ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
    ])

    ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward


def remove():
    if len(stored) == 0:
        return

    ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]

    stored.clear()

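# add()/remove() above toggle gradient checkpointing for the attention and ResBlock forwards:
# activations are dropped during the forward pass and recomputed during backward, trading
# extra compute for a large reduction in training VRAM. The underlying primitive in isolation:
def _checkpoint_example():
    import torch

    layer = torch.nn.Linear(64, 64)
    x = torch.randn(8, 64, requires_grad=True)
    # Intermediates inside `layer` are not kept; they are recomputed when backward() needs them.
    y = checkpoint(layer, x)
    y.sum().backward()
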
# from torch.utils.checkpoint import checkpoint
|
||||
#
|
||||
# import ldm.modules.attention
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
#
|
||||
#
|
||||
# def BasicTransformerBlock_forward(self, x, context=None):
|
||||
# return checkpoint(self._forward, x, context)
|
||||
#
|
||||
#
|
||||
# def AttentionBlock_forward(self, x):
|
||||
# return checkpoint(self._forward, x)
|
||||
#
|
||||
#
|
||||
# def ResBlock_forward(self, x, emb):
|
||||
# return checkpoint(self._forward, x, emb)
|
||||
#
|
||||
#
|
||||
# stored = []
|
||||
#
|
||||
#
|
||||
# def add():
|
||||
# if len(stored) != 0:
|
||||
# return
|
||||
#
|
||||
# stored.extend([
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward,
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
|
||||
# ])
|
||||
#
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
||||
#
|
||||
#
|
||||
# def remove():
|
||||
# if len(stored) == 0:
|
||||
# return
|
||||
#
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
|
||||
#
|
||||
# stored.clear()
|
||||
#
|
||||
File diff suppressed because it is too large
@@ -1,154 +1,154 @@
import torch
from packaging import version
from einops import repeat
import math

from modules import devices
from modules.sd_hijack_utils import CondFunc


class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()

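# TorchHijackForUnet.cat papers over the skip-connection shape mismatches that appear when
# the latent size is a multiple of 8 but not of 64: if the two tensors' spatial sizes differ,
# the first is interpolated to match before concatenation. For example:
def _cat_with_resize_example():
    a = torch.randn(1, 4, 13, 13)   # skip connection
    b = torch.randn(1, 4, 12, 12)   # upsampled feature map
    return th.cat((a, b), dim=1)    # `a` is resized to 12x12 first -> shape (1, 8, 12, 12)
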
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
    """Always make sure inputs to unet are in correct dtype."""
    if isinstance(cond, dict):
        for y in cond.keys():
            if isinstance(cond[y], list):
                cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
            else:
                cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]

    with devices.autocast():
        result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
        if devices.unet_needs_upcast:
            return result.float()
        else:
            return result


# Monkey patch to create timestep embed tensor on device, avoiding a block.
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding


# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
# Prevents a lot of unnecessary aten::copy_ calls
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x
    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
    if self.use_linear:
        x = self.proj_in(x)
    for i, block in enumerate(self.transformer_blocks):
        x = block(x, context=context[i])
    if self.use_linear:
        x = self.proj_out(x)
    x = x.view(b, h, w, c).permute(0, 3, 1, 2)
    if not self.use_linear:
        x = self.proj_out(x)
    return x + x_in


class GELUHijack(torch.nn.GELU, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        torch.nn.GELU.__init__(self, *args, **kwargs)
    def forward(self, x):
        if devices.unet_needs_upcast:
            return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
        else:
            return torch.nn.GELU.forward(self, x)


ddpm_edit_hijack = None
def hijack_ddpm_edit():
    global ddpm_edit_hijack
    if not ddpm_edit_hijack:
        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)


unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)

if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)

first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)

CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)


def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
    if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
        dtype = torch.float32
    else:
        dtype = devices.dtype_unet
    return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)


CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
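

# Every patch in this file goes through CondFunc from modules.sd_hijack_utils: it wraps the
# function addressed by a dotted import path and only routes calls to the substitute when the
# condition callable (which, like the substitute, receives the original function first) returns
# a truthy value; otherwise the call falls through to the original. Schematically (the target
# path and predicate below are illustrative):
def _upcast_groupnorm(orig_func, self, x, *args, **kwargs):
    return orig_func(self.float(), x.float(), *args, **kwargs)


def _groupnorm_needs_upcast(orig_func, self, x, *args, **kwargs):
    return x.dtype == torch.float16


# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', _upcast_groupnorm, _groupnorm_needs_upcast)
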
# import torch
|
||||
# from packaging import version
|
||||
# from einops import repeat
|
||||
# import math
|
||||
#
|
||||
# from modules import devices
|
||||
# from modules.sd_hijack_utils import CondFunc
|
||||
#
|
||||
#
|
||||
# class TorchHijackForUnet:
|
||||
# """
|
||||
# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||
# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
||||
# """
|
||||
#
|
||||
# def __getattr__(self, item):
|
||||
# if item == 'cat':
|
||||
# return self.cat
|
||||
#
|
||||
# if hasattr(torch, item):
|
||||
# return getattr(torch, item)
|
||||
#
|
||||
# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
#
|
||||
# def cat(self, tensors, *args, **kwargs):
|
||||
# if len(tensors) == 2:
|
||||
# a, b = tensors
|
||||
# if a.shape[-2:] != b.shape[-2:]:
|
||||
# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
||||
#
|
||||
# tensors = (a, b)
|
||||
#
|
||||
# return torch.cat(tensors, *args, **kwargs)
|
||||
#
|
||||
#
|
||||
# th = TorchHijackForUnet()
|
||||
#
|
||||
#
|
||||
# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||
# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||
# """Always make sure inputs to unet are in correct dtype."""
|
||||
# if isinstance(cond, dict):
|
||||
# for y in cond.keys():
|
||||
# if isinstance(cond[y], list):
|
||||
# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
# else:
|
||||
# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||
#
|
||||
# with devices.autocast():
|
||||
# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
||||
# if devices.unet_needs_upcast:
|
||||
# return result.float()
|
||||
# else:
|
||||
# return result
|
||||
#
|
||||
#
|
||||
# # Monkey patch to create timestep embed tensor on device, avoiding a block.
|
||||
# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
||||
# """
|
||||
# Create sinusoidal timestep embeddings.
|
||||
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
# These may be fractional.
|
||||
# :param dim: the dimension of the output.
|
||||
# :param max_period: controls the minimum frequency of the embeddings.
|
||||
# :return: an [N x dim] Tensor of positional embeddings.
|
||||
# """
|
||||
# if not repeat_only:
|
||||
# half = dim // 2
|
||||
# freqs = torch.exp(
|
||||
# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
||||
# )
|
||||
# args = timesteps[:, None].float() * freqs[None]
|
||||
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
# if dim % 2:
|
||||
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
# else:
|
||||
# embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
# return embedding
|
||||
#
|
||||
#
|
||||
# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
||||
# # Prevents a lot of unnecessary aten::copy_ calls
|
||||
# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
||||
# # note: if no context is given, cross-attention defaults to self-attention
|
||||
# if not isinstance(context, list):
|
||||
# context = [context]
|
||||
# b, c, h, w = x.shape
|
||||
# x_in = x
|
||||
# x = self.norm(x)
|
||||
# if not self.use_linear:
|
||||
# x = self.proj_in(x)
|
||||
# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||
# if self.use_linear:
|
||||
# x = self.proj_in(x)
|
||||
# for i, block in enumerate(self.transformer_blocks):
|
||||
# x = block(x, context=context[i])
|
||||
# if self.use_linear:
|
||||
# x = self.proj_out(x)
|
||||
# x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
||||
# if not self.use_linear:
|
||||
# x = self.proj_out(x)
|
||||
# return x + x_in
|
||||
#
|
||||
#
|
||||
# class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
# def forward(self, x):
|
||||
# if devices.unet_needs_upcast:
|
||||
# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||
# else:
|
||||
# return torch.nn.GELU.forward(self, x)
|
||||
#
|
||||
#
|
||||
# ddpm_edit_hijack = None
|
||||
# def hijack_ddpm_edit():
|
||||
# global ddpm_edit_hijack
|
||||
# if not ddpm_edit_hijack:
|
||||
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
||||
#
|
||||
#
|
||||
# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
||||
# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
#
|
||||
# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
#
|
||||
# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||
#
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
||||
# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
||||
#
|
||||
#
|
||||
# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
||||
# if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
||||
# dtype = torch.float32
|
||||
# else:
|
||||
# dtype = devices.dtype_unet
|
||||
# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||
#
|
||||
#
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||
# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||
|
||||
@@ -10,7 +10,6 @@ import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf, ListConfig
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
import gc
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
@@ -415,89 +414,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
"""
|
||||
Gives the ldm.modules.midas.api.load_model function automatic downloading.
|
||||
|
||||
When the 512-depth-ema model, and other future models like it, is loaded,
|
||||
it calls midas.api.load_model to load the associated midas depth model.
|
||||
This function applies a wrapper to download the model to the correct
|
||||
location automatically.
|
||||
"""
|
||||
|
||||
midas_path = os.path.join(paths.models_path, 'midas')
|
||||
|
||||
# stable-diffusion-stability-ai hard-codes the midas model path to
|
||||
# a location that differs from where other scripts using this model look.
|
||||
# HACK: Overriding the path here.
|
||||
for k, v in midas.api.ISL_PATHS.items():
|
||||
file_name = os.path.basename(v)
|
||||
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
|
||||
|
||||
midas_urls = {
|
||||
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
||||
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
|
||||
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
|
||||
}
|
||||
|
||||
midas.api.load_model_inner = midas.api.load_model
|
||||
|
||||
def load_model_wrapper(model_type):
|
||||
path = midas.api.ISL_PATHS[model_type]
|
||||
if not os.path.exists(path):
|
||||
if not os.path.exists(midas_path):
|
||||
os.mkdir(midas_path)
|
||||
|
||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||
request.urlretrieve(midas_urls[model_type], path)
|
||||
print(f"{model_type} downloaded")
|
||||
|
||||
return midas.api.load_model_inner(model_type)
|
||||
|
||||
midas.api.load_model = load_model_wrapper
|
||||
pass
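# A minimal usage sketch, not part of the original file: once enable_midas_autodownload()
# has run, loading any model type listed in midas_urls goes through load_model_wrapper,
# which fetches the weights into models/midas on first use before delegating to the
# original midas.api.load_model.
enable_midas_autodownload()
midas.api.load_model("dpt_hybrid")  # downloads dpt_hybrid weights if the file is missing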
|
||||
|
||||
def patch_given_betas():
|
||||
import ldm.models.diffusion.ddpm
|
||||
|
||||
def patched_register_schedule(*args, **kwargs):
|
||||
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
|
||||
|
||||
if isinstance(args[1], ListConfig):
|
||||
args = (args[0], np.array(args[1]), *args[2:])
|
||||
|
||||
original_register_schedule(*args, **kwargs)
|
||||
|
||||
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
||||
pass
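# A minimal sketch of the case the patch above guards against: an OmegaConf config can hand
# given_betas to DDPM.register_schedule as a ListConfig, and patched_register_schedule
# converts it to an ndarray before calling the original function.
import numpy as np
from omegaconf import OmegaConf

cfg = OmegaConf.create({"given_betas": [0.0001, 0.0005, 0.001]})
betas = np.array(cfg.given_betas)  # the same conversion applied to args[1] above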
|
||||
|
||||
def repair_config(sd_config, state_dict=None):
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
if hasattr(sd_config.model.params, 'unet_config'):
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
if hasattr(sd_config.model.params, 'first_stage_config'):
|
||||
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
||||
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
|
||||
|
||||
# For UnCLIP-L, override the hardcoded karlo directory
|
||||
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
|
||||
karlo_path = os.path.join(paths.models_path, 'karlo')
|
||||
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
||||
|
||||
# Do not use checkpoint for inference.
|
||||
# This helps prevent extra performance overhead on checking parameters.
|
||||
# The perf overhead is about 100ms/it on 4090 for SDXL.
|
||||
if hasattr(sd_config.model.params, "network_config"):
|
||||
sd_config.model.params.network_config.params.use_checkpoint = False
|
||||
if hasattr(sd_config.model.params, "unet_config"):
|
||||
sd_config.model.params.unet_config.params.use_checkpoint = False
|
||||
|
||||
pass
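# A minimal sketch, assuming the default v1 config path used elsewhere in the codebase:
# repair_config runs on the loaded OmegaConf object before the model is instantiated, so
# fp16/EMA/checkpointing fields stay consistent with the command-line options.
from omegaconf import OmegaConf

sd_config = OmegaConf.load("configs/v1-inference.yaml")  # placeholder path
repair_config(sd_config)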
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
|
||||
@@ -1,137 +1,137 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared, paths, sd_disable_initialization, devices
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||
|
||||
|
||||
config_default = shared.sd_default_config
|
||||
# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||
|
||||
|
||||
def is_using_v_parameterization_for_sd2(state_dict):
|
||||
"""
|
||||
Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||
"""
|
||||
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
|
||||
device = devices.device
|
||||
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
image_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
model_channels=320,
|
||||
attention_resolutions=[4, 2, 1],
|
||||
num_res_blocks=2,
|
||||
channel_mult=[1, 2, 4, 4],
|
||||
num_head_channels=64,
|
||||
use_spatial_transformer=True,
|
||||
use_linear_in_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=1024,
|
||||
legacy=False
|
||||
)
|
||||
unet.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||
unet.load_state_dict(unet_sd, strict=True)
|
||||
unet.to(device=device, dtype=devices.dtype_unet)
|
||||
|
||||
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||
|
||||
with devices.autocast():
|
||||
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
||||
|
||||
return out < -1
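# Informal note on the "out < -1" heuristic above (not from the original file): an
# eps-prediction model outputs the added noise, so (prediction - x) averages near zero on
# this constant dummy input, while a v-prediction model targets v = alpha_t*eps - sigma_t*x,
# which at t=999 (sigma_t close to 1) is dominated by -x, pushing the mean well below -1.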
|
||||
|
||||
def guess_model_config_from_state_dict(sd, filename):
|
||||
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
|
||||
if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
||||
return config_sd3
|
||||
|
||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_sdxl_inpainting
|
||||
else:
|
||||
return config_sdxl
|
||||
|
||||
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||
return config_sdxl_refiner
|
||||
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
return config_depth_model
|
||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||
return config_unclip
|
||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
||||
return config_unopenclip
|
||||
|
||||
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_sd2_inpainting
|
||||
# elif is_using_v_parameterization_for_sd2(sd):
|
||||
# return config_sd2v
|
||||
else:
|
||||
return config_sd2v
|
||||
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_inpainting
|
||||
if diffusion_model_input.shape[1] == 8:
|
||||
return config_instruct_pix2pix
|
||||
|
||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||
return config_alt_diffusion_m18
|
||||
return config_alt_diffusion
|
||||
|
||||
return config_default
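# A minimal sketch with a fabricated state dict, not a real checkpoint: a UNet whose first
# input convolution takes 9 channels (4 latent + 4 masked-image latent + 1 mask) is
# classified as an SD1 inpainting model by the heuristics above.
import torch

fake_sd = {"model.diffusion_model.input_blocks.0.0.weight": torch.zeros(320, 9, 3, 3)}
assert guess_model_config_from_state_dict(fake_sd, "fake.ckpt") == config_inpainting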
|
||||
|
||||
def find_checkpoint_config(state_dict, info):
|
||||
if info is None:
|
||||
return guess_model_config_from_state_dict(state_dict, "")
|
||||
|
||||
config = find_checkpoint_config_near_filename(info)
|
||||
if config is not None:
|
||||
return config
|
||||
|
||||
return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||
|
||||
|
||||
def find_checkpoint_config_near_filename(info):
|
||||
if info is None:
|
||||
return None
|
||||
|
||||
config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return None
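# A minimal sketch (the path is a placeholder): a YAML file next to the checkpoint with the
# same base name takes priority over guessing the config from the state dict.
class FakeCheckpointInfo:
    filename = "models/Stable-diffusion/my_model.safetensors"

cfg = find_checkpoint_config_near_filename(FakeCheckpointInfo())
# -> "models/Stable-diffusion/my_model.yaml" if that file exists on disk, otherwise None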
|
||||
# import os
|
||||
#
|
||||
# import torch
|
||||
#
|
||||
# from modules import shared, paths, sd_disable_initialization, devices
|
||||
#
|
||||
# sd_configs_path = shared.sd_configs_path
|
||||
# # sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
# # sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||
#
|
||||
#
|
||||
# config_default = shared.sd_default_config
|
||||
# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||
# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||
# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||
# config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||
# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||
#
|
||||
#
|
||||
# def is_using_v_parameterization_for_sd2(state_dict):
|
||||
# """
|
||||
# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||
# """
|
||||
#
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
#
|
||||
# device = devices.device
|
||||
#
|
||||
# with sd_disable_initialization.DisableInitialization():
|
||||
# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||
# use_checkpoint=False,
|
||||
# use_fp16=False,
|
||||
# image_size=32,
|
||||
# in_channels=4,
|
||||
# out_channels=4,
|
||||
# model_channels=320,
|
||||
# attention_resolutions=[4, 2, 1],
|
||||
# num_res_blocks=2,
|
||||
# channel_mult=[1, 2, 4, 4],
|
||||
# num_head_channels=64,
|
||||
# use_spatial_transformer=True,
|
||||
# use_linear_in_transformer=True,
|
||||
# transformer_depth=1,
|
||||
# context_dim=1024,
|
||||
# legacy=False
|
||||
# )
|
||||
# unet.eval()
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||
# unet.load_state_dict(unet_sd, strict=True)
|
||||
# unet.to(device=device, dtype=devices.dtype_unet)
|
||||
#
|
||||
# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||
# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||
#
|
||||
# with devices.autocast():
|
||||
# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
||||
#
|
||||
# return out < -1
|
||||
#
|
||||
#
|
||||
# def guess_model_config_from_state_dict(sd, filename):
|
||||
# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
#
|
||||
# if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
||||
# return config_sd3
|
||||
#
|
||||
# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_sdxl_inpainting
|
||||
# else:
|
||||
# return config_sdxl
|
||||
#
|
||||
# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||
# return config_sdxl_refiner
|
||||
# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
# return config_depth_model
|
||||
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||
# return config_unclip
|
||||
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
||||
# return config_unopenclip
|
||||
#
|
||||
# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_sd2_inpainting
|
||||
# # elif is_using_v_parameterization_for_sd2(sd):
|
||||
# # return config_sd2v
|
||||
# else:
|
||||
# return config_sd2v
|
||||
#
|
||||
# if diffusion_model_input is not None:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_inpainting
|
||||
# if diffusion_model_input.shape[1] == 8:
|
||||
# return config_instruct_pix2pix
|
||||
#
|
||||
# if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||
# return config_alt_diffusion_m18
|
||||
# return config_alt_diffusion
|
||||
#
|
||||
# return config_default
|
||||
#
|
||||
#
|
||||
# def find_checkpoint_config(state_dict, info):
|
||||
# if info is None:
|
||||
# return guess_model_config_from_state_dict(state_dict, "")
|
||||
#
|
||||
# config = find_checkpoint_config_near_filename(info)
|
||||
# if config is not None:
|
||||
# return config
|
||||
#
|
||||
# return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||
#
|
||||
#
|
||||
# def find_checkpoint_config_near_filename(info):
|
||||
# if info is None:
|
||||
# return None
|
||||
#
|
||||
# config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||
# if os.path.exists(config):
|
||||
# return config
|
||||
#
|
||||
# return None
|
||||
#
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -6,7 +5,7 @@ if TYPE_CHECKING:
|
||||
from modules.sd_models import CheckpointInfo
|
||||
|
||||
|
||||
class WebuiSdModel(LatentDiffusion):
|
||||
class WebuiSdModel:
|
||||
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
|
||||
|
||||
lowvram: bool
|
||||
|
||||
@@ -1,115 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import sgm.models.diffusion
|
||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
from modules import torch_utils
|
||||
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
|
||||
for embedder in self.conditioner.embedders:
|
||||
embedder.ucg_rate = 0.0
|
||||
|
||||
width = getattr(batch, 'width', 1024) or 1024
|
||||
height = getattr(batch, 'height', 1024) or 1024
|
||||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||
|
||||
devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
|
||||
|
||||
sdxl_conds = {
|
||||
"txt": batch,
|
||||
"original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
"crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
||||
"target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
"aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
||||
}
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
||||
c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
||||
|
||||
return c
|
||||
|
||||
|
||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
|
||||
if self.model.diffusion_model.in_channels == 9:
|
||||
x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||
|
||||
return self.model(x, t, cond, *args, **kwargs)
|
||||
|
||||
|
||||
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||
return x
|
||||
|
||||
|
||||
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
||||
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
||||
|
||||
|
||||
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
||||
res = []
|
||||
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
||||
encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
||||
res.append(encoded)
|
||||
|
||||
return torch.cat(res, dim=1)
|
||||
|
||||
|
||||
def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
||||
return embedder.tokenize(texts)
|
||||
|
||||
raise AssertionError('no tokenizer available')
|
||||
|
||||
|
||||
|
||||
def process_texts(self, texts):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||
return embedder.process_texts(texts)
|
||||
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
||||
return embedder.get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
|
||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||
sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||
|
||||
|
||||
def extend_sdxl(model):
|
||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||
|
||||
dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
||||
model.model.diffusion_model.dtype = dtype
|
||||
model.model.conditioning_key = 'crossattn'
|
||||
model.cond_stage_key = 'txt'
|
||||
# model.cond_stage_model will be set in sd_hijack
|
||||
|
||||
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||
|
||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
||||
|
||||
model.conditioner.wrapped = torch.nn.Module()
|
||||
|
||||
|
||||
sgm.modules.attention.print = shared.ldm_print
|
||||
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||
sgm.modules.encoders.modules.print = shared.ldm_print
|
||||
|
||||
# this gets the code to load the vanilla attention that we override
|
||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
||||
# from __future__ import annotations
|
||||
#
|
||||
# import torch
|
||||
#
|
||||
# import sgm.models.diffusion
|
||||
# import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
# import sgm.modules.diffusionmodules.discretizer
|
||||
# from modules import devices, shared, prompt_parser
|
||||
# from modules import torch_utils
|
||||
#
|
||||
# from backend import memory_management
|
||||
#
|
||||
#
|
||||
# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
#
|
||||
# for embedder in self.conditioner.embedders:
|
||||
# embedder.ucg_rate = 0.0
|
||||
#
|
||||
# width = getattr(batch, 'width', 1024) or 1024
|
||||
# height = getattr(batch, 'height', 1024) or 1024
|
||||
# is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||
# aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||
#
|
||||
# devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
|
||||
#
|
||||
# sdxl_conds = {
|
||||
# "txt": batch,
|
||||
# "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
# "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
||||
# "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
# "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
||||
# }
|
||||
#
|
||||
# force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
||||
# c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
||||
#
|
||||
# return c
|
||||
#
|
||||
#
|
||||
# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
|
||||
# if self.model.diffusion_model.in_channels == 9:
|
||||
# x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||
#
|
||||
# return self.model(x, t, cond, *args, **kwargs)
|
||||
#
|
||||
#
|
||||
# def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||
# return x
|
||||
#
|
||||
#
|
||||
# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||
# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
||||
# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
||||
#
|
||||
#
|
||||
# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
||||
# res = []
|
||||
#
|
||||
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
||||
# encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
||||
# res.append(encoded)
|
||||
#
|
||||
# return torch.cat(res, dim=1)
|
||||
#
|
||||
#
|
||||
# def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
||||
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
||||
# return embedder.tokenize(texts)
|
||||
#
|
||||
# raise AssertionError('no tokenizer available')
|
||||
#
|
||||
#
|
||||
#
|
||||
# def process_texts(self, texts):
|
||||
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||
# return embedder.process_texts(texts)
|
||||
#
|
||||
#
|
||||
# def get_target_prompt_token_count(self, token_count):
|
||||
# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
||||
# return embedder.get_target_prompt_token_count(token_count)
|
||||
#
|
||||
#
|
||||
# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
|
||||
# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||
# sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||
# sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||
# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||
#
|
||||
#
|
||||
# def extend_sdxl(model):
|
||||
# """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||
#
|
||||
# dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
||||
# model.model.diffusion_model.dtype = dtype
|
||||
# model.model.conditioning_key = 'crossattn'
|
||||
# model.cond_stage_key = 'txt'
|
||||
# # model.cond_stage_model will be set in sd_hijack
|
||||
#
|
||||
# model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||
#
|
||||
# discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||
# model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
||||
#
|
||||
# model.conditioner.wrapped = torch.nn.Module()
|
||||
#
|
||||
#
|
||||
# sgm.modules.attention.print = shared.ldm_print
|
||||
# sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||
# sgm.modules.encoders.modules.print = shared.ldm_print
|
||||
#
|
||||
# # this gets the code to load the vanilla attention that we override
|
||||
# sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||
# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
||||
|
||||
@@ -35,9 +35,7 @@ def refresh_vae_list():
|
||||
|
||||
|
||||
def cross_attention_optimizations():
|
||||
import modules.sd_hijack
|
||||
|
||||
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
||||
return ["Automatic"]
|
||||
|
||||
|
||||
def sd_unet_items():
|
||||
|
||||
@@ -1,245 +1,243 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
from torchvision import transforms
|
||||
from collections import defaultdict
|
||||
from random import shuffle, choices
|
||||
|
||||
import random
|
||||
import tqdm
|
||||
from modules import devices, shared, images
|
||||
import re
|
||||
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
|
||||
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
|
||||
self.filename = filename
|
||||
self.filename_text = filename_text
|
||||
self.weight = weight
|
||||
self.latent_dist = latent_dist
|
||||
self.latent_sample = latent_sample
|
||||
self.cond = cond
|
||||
self.cond_text = cond_text
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
||||
self.dataset = []
|
||||
|
||||
with open(template_file, "r") as file:
|
||||
lines = [x.strip() for x in file.readlines()]
|
||||
|
||||
self.lines = lines
|
||||
|
||||
assert data_root, 'dataset directory not specified'
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.tag_drop_out = tag_drop_out
|
||||
groups = defaultdict(list)
|
||||
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
alpha_channel = None
|
||||
if shared.state.interrupted:
|
||||
raise Exception("interrupted")
|
||||
try:
|
||||
image = images.read(path)
|
||||
#Currently does not work for single color transparency
|
||||
#We would need to read image.info['transparency'] for that
|
||||
if use_weight and 'A' in image.getbands():
|
||||
alpha_channel = image.getchannel('A')
|
||||
image = image.convert('RGB')
|
||||
if not varsize:
|
||||
image = image.resize((width, height), PIL.Image.BICUBIC)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
text_filename = f"{os.path.splitext(path)[0]}.txt"
|
||||
filename = os.path.basename(path)
|
||||
|
||||
if os.path.exists(text_filename):
|
||||
with open(text_filename, "r", encoding="utf8") as file:
|
||||
filename_text = file.read()
|
||||
else:
|
||||
filename_text = os.path.splitext(filename)[0]
|
||||
filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||
if re_word:
|
||||
tokens = re_word.findall(filename_text)
|
||||
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||
latent_sample = None
|
||||
|
||||
with devices.autocast():
|
||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||
|
||||
#Perform latent sampling, even for random sampling.
|
||||
#We need the sample dimensions for the weights
|
||||
if latent_sampling_method == "deterministic":
|
||||
if isinstance(latent_dist, DiagonalGaussianDistribution):
|
||||
# Works only for DiagonalGaussianDistribution
|
||||
latent_dist.std = 0
|
||||
else:
|
||||
latent_sampling_method = "once"
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
|
||||
if use_weight and alpha_channel is not None:
|
||||
channels, *latent_size = latent_sample.shape
|
||||
weight_img = alpha_channel.resize(latent_size)
|
||||
npweight = np.array(weight_img).astype(np.float32)
|
||||
#Repeat for every channel in the latent sample
|
||||
weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
|
||||
#Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
|
||||
weight -= weight.min()
|
||||
weight /= weight.mean()
|
||||
elif use_weight:
|
||||
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
|
||||
weight = torch.ones(latent_sample.shape)
|
||||
else:
|
||||
weight = None
|
||||
|
||||
if latent_sampling_method == "random":
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
|
||||
else:
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
|
||||
|
||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
|
||||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
with devices.autocast():
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
groups[image.size].append(len(self.dataset))
|
||||
self.dataset.append(entry)
|
||||
del torchdata
|
||||
del latent_dist
|
||||
del latent_sample
|
||||
del weight
|
||||
|
||||
self.length = len(self.dataset)
|
||||
self.groups = list(groups.values())
|
||||
assert self.length > 0, "No images have been found in the dataset."
|
||||
self.batch_size = min(batch_size, self.length)
|
||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||
self.latent_sampling_method = latent_sampling_method
|
||||
|
||||
if len(groups) > 1:
|
||||
print("Buckets:")
|
||||
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
||||
print(f" {w}x{h}: {len(ids)}")
|
||||
print()
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
tags = filename_text.split(',')
|
||||
if self.tag_drop_out != 0:
|
||||
tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||
if self.shuffle_tags:
|
||||
random.shuffle(tags)
|
||||
text = text.replace("[filewords]", ','.join(tags))
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
return text
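# A minimal sketch of the substitution done by create_text above; the template line and the
# tags are made up: tags from the caption file replace [filewords], and the placeholder
# token replaces [name].
line = "a photo of [name], [filewords]"
tags = ["1girl", "red hair", "outdoors"]
text = line.replace("[filewords]", ",".join(tags)).replace("[name]", "my-embedding")
# -> "a photo of my-embedding, 1girl,red hair,outdoors"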
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
entry = self.dataset[i]
|
||||
if self.tag_drop_out != 0 or self.shuffle_tags:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
if self.latent_sampling_method == "random":
|
||||
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||
return entry
|
||||
|
||||
|
||||
class GroupedBatchSampler(Sampler):
|
||||
def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
||||
super().__init__(data_source)
|
||||
|
||||
n = len(data_source)
|
||||
self.groups = data_source.groups
|
||||
self.len = n_batch = n // batch_size
|
||||
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
||||
self.base = [int(e) // batch_size for e in expected]
|
||||
self.n_rand_batches = nrb = n_batch - sum(self.base)
|
||||
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
def __iter__(self):
|
||||
b = self.batch_size
|
||||
|
||||
for g in self.groups:
|
||||
shuffle(g)
|
||||
|
||||
batches = []
|
||||
for g in self.groups:
|
||||
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
||||
for _ in range(self.n_rand_batches):
|
||||
rand_group = choices(self.groups, self.probs)[0]
|
||||
batches.append(choices(rand_group, k=b))
|
||||
|
||||
shuffle(batches)
|
||||
|
||||
yield from batches
|
||||
|
||||
|
||||
class PersonalizedDataLoader(DataLoader):
|
||||
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
||||
if latent_sampling_method == "random":
|
||||
self.collate_fn = collate_wrapper_random
|
||||
else:
|
||||
self.collate_fn = collate_wrapper
|
||||
|
||||
|
||||
class BatchLoader:
|
||||
def __init__(self, data):
|
||||
self.cond_text = [entry.cond_text for entry in data]
|
||||
self.cond = [entry.cond for entry in data]
|
||||
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||
if all(entry.weight is not None for entry in data):
|
||||
self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
|
||||
else:
|
||||
self.weight = None
|
||||
#self.emb_index = [entry.emb_index for entry in data]
|
||||
#print(self.latent_sample.device)
|
||||
|
||||
def pin_memory(self):
|
||||
self.latent_sample = self.latent_sample.pin_memory()
|
||||
return self
|
||||
|
||||
def collate_wrapper(batch):
|
||||
return BatchLoader(batch)
|
||||
|
||||
class BatchLoaderRandom(BatchLoader):
|
||||
def __init__(self, data):
|
||||
super().__init__(data)
|
||||
|
||||
def pin_memory(self):
|
||||
return self
|
||||
|
||||
def collate_wrapper_random(batch):
|
||||
return BatchLoaderRandom(batch)
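# A minimal usage sketch; directory and file names are placeholders and shared.sd_model is
# assumed to be a loaded checkpoint. GroupedBatchSampler keeps every batch within a single
# image resolution, which is why PersonalizedDataLoader passes it as batch_sampler.
ds = PersonalizedBase(
    data_root="train/images", width=512, height=512, repeats=1,
    model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model,
    device=devices.device, template_file="train/template.txt",
    batch_size=2, varsize=True,
)
dl = PersonalizedDataLoader(ds, latent_sampling_method="once", batch_size=2)
for batch in dl:
    latents, prompts = batch.latent_sample, batch.cond_text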
# import os
|
||||
# import numpy as np
|
||||
# import PIL
|
||||
# import torch
|
||||
# from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
# from torchvision import transforms
|
||||
# from collections import defaultdict
|
||||
# from random import shuffle, choices
|
||||
#
|
||||
# import random
|
||||
# import tqdm
|
||||
# from modules import devices, shared, images
|
||||
# import re
|
||||
#
|
||||
# re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
#
|
||||
#
|
||||
# class DatasetEntry:
|
||||
# def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
|
||||
# self.filename = filename
|
||||
# self.filename_text = filename_text
|
||||
# self.weight = weight
|
||||
# self.latent_dist = latent_dist
|
||||
# self.latent_sample = latent_sample
|
||||
# self.cond = cond
|
||||
# self.cond_text = cond_text
|
||||
# self.pixel_values = pixel_values
|
||||
#
|
||||
#
|
||||
# class PersonalizedBase(Dataset):
|
||||
# def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
|
||||
# re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
|
||||
#
|
||||
# self.placeholder_token = placeholder_token
|
||||
#
|
||||
# self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
#
|
||||
# self.dataset = []
|
||||
#
|
||||
# with open(template_file, "r") as file:
|
||||
# lines = [x.strip() for x in file.readlines()]
|
||||
#
|
||||
# self.lines = lines
|
||||
#
|
||||
# assert data_root, 'dataset directory not specified'
|
||||
# assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
# assert os.listdir(data_root), "Dataset directory is empty"
|
||||
#
|
||||
# self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
#
|
||||
# self.shuffle_tags = shuffle_tags
|
||||
# self.tag_drop_out = tag_drop_out
|
||||
# groups = defaultdict(list)
|
||||
#
|
||||
# print("Preparing dataset...")
|
||||
# for path in tqdm.tqdm(self.image_paths):
|
||||
# alpha_channel = None
|
||||
# if shared.state.interrupted:
|
||||
# raise Exception("interrupted")
|
||||
# try:
|
||||
# image = images.read(path)
|
||||
# #Currently does not work for single color transparency
|
||||
# #We would need to read image.info['transparency'] for that
|
||||
# if use_weight and 'A' in image.getbands():
|
||||
# alpha_channel = image.getchannel('A')
|
||||
# image = image.convert('RGB')
|
||||
# if not varsize:
|
||||
# image = image.resize((width, height), PIL.Image.BICUBIC)
|
||||
# except Exception:
|
||||
# continue
|
||||
#
|
||||
# text_filename = f"{os.path.splitext(path)[0]}.txt"
|
||||
# filename = os.path.basename(path)
|
||||
#
|
||||
# if os.path.exists(text_filename):
|
||||
# with open(text_filename, "r", encoding="utf8") as file:
|
||||
# filename_text = file.read()
|
||||
# else:
|
||||
# filename_text = os.path.splitext(filename)[0]
|
||||
# filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||
# if re_word:
|
||||
# tokens = re_word.findall(filename_text)
|
||||
# filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||
#
|
||||
# npimage = np.array(image).astype(np.uint8)
|
||||
# npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
#
|
||||
# torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||
# latent_sample = None
|
||||
#
|
||||
# with devices.autocast():
|
||||
# latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||
#
|
||||
# #Perform latent sampling, even for random sampling.
|
||||
# #We need the sample dimensions for the weights
|
||||
# if latent_sampling_method == "deterministic":
|
||||
# if isinstance(latent_dist, DiagonalGaussianDistribution):
|
||||
# # Works only for DiagonalGaussianDistribution
|
||||
# latent_dist.std = 0
|
||||
# else:
|
||||
# latent_sampling_method = "once"
|
||||
# latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
#
|
||||
# if use_weight and alpha_channel is not None:
|
||||
# channels, *latent_size = latent_sample.shape
|
||||
# weight_img = alpha_channel.resize(latent_size)
|
||||
# npweight = np.array(weight_img).astype(np.float32)
|
||||
# #Repeat for every channel in the latent sample
|
||||
# weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
|
||||
# #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
|
||||
# weight -= weight.min()
|
||||
# weight /= weight.mean()
|
||||
# elif use_weight:
|
||||
# #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
|
||||
# weight = torch.ones(latent_sample.shape)
|
||||
# else:
|
||||
# weight = None
|
||||
#
|
||||
# if latent_sampling_method == "random":
|
||||
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
|
||||
# else:
|
||||
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
|
||||
#
|
||||
# if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
# entry.cond_text = self.create_text(filename_text)
|
||||
#
|
||||
# if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
# with devices.autocast():
|
||||
# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
# groups[image.size].append(len(self.dataset))
|
||||
# self.dataset.append(entry)
|
||||
# del torchdata
|
||||
# del latent_dist
|
||||
# del latent_sample
|
||||
# del weight
|
||||
#
|
||||
# self.length = len(self.dataset)
|
||||
# self.groups = list(groups.values())
|
||||
# assert self.length > 0, "No images have been found in the dataset."
|
||||
# self.batch_size = min(batch_size, self.length)
|
||||
# self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||
# self.latent_sampling_method = latent_sampling_method
|
||||
#
|
||||
# if len(groups) > 1:
|
||||
# print("Buckets:")
|
||||
# for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
||||
# print(f" {w}x{h}: {len(ids)}")
|
||||
# print()
|
||||
#
|
||||
# def create_text(self, filename_text):
|
||||
# text = random.choice(self.lines)
|
||||
# tags = filename_text.split(',')
|
||||
# if self.tag_drop_out != 0:
|
||||
# tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||
# if self.shuffle_tags:
|
||||
# random.shuffle(tags)
|
||||
# text = text.replace("[filewords]", ','.join(tags))
|
||||
# text = text.replace("[name]", self.placeholder_token)
|
||||
# return text
|
||||
#
|
||||
# def __len__(self):
|
||||
# return self.length
|
||||
#
|
||||
# def __getitem__(self, i):
|
||||
# entry = self.dataset[i]
|
||||
# if self.tag_drop_out != 0 or self.shuffle_tags:
|
||||
# entry.cond_text = self.create_text(entry.filename_text)
|
||||
# if self.latent_sampling_method == "random":
|
||||
# entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||
# return entry
|
||||
#
|
||||
#
|
||||
# class GroupedBatchSampler(Sampler):
|
||||
# def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
||||
# super().__init__(data_source)
|
||||
#
|
||||
# n = len(data_source)
|
||||
# self.groups = data_source.groups
|
||||
# self.len = n_batch = n // batch_size
|
||||
# expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
||||
# self.base = [int(e) // batch_size for e in expected]
|
||||
# self.n_rand_batches = nrb = n_batch - sum(self.base)
|
||||
# self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
||||
# self.batch_size = batch_size
|
||||
#
|
||||
# def __len__(self):
|
||||
# return self.len
|
||||
#
|
||||
# def __iter__(self):
|
||||
# b = self.batch_size
|
||||
#
|
||||
# for g in self.groups:
|
||||
# shuffle(g)
|
||||
#
|
||||
# batches = []
|
||||
# for g in self.groups:
|
||||
# batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
||||
# for _ in range(self.n_rand_batches):
|
||||
# rand_group = choices(self.groups, self.probs)[0]
|
||||
# batches.append(choices(rand_group, k=b))
|
||||
#
|
||||
# shuffle(batches)
|
||||
#
|
||||
# yield from batches
|
||||
#
|
||||
#
|
||||
# class PersonalizedDataLoader(DataLoader):
|
||||
# def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||
# super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
||||
# if latent_sampling_method == "random":
|
||||
# self.collate_fn = collate_wrapper_random
|
||||
# else:
|
||||
# self.collate_fn = collate_wrapper
|
||||
#
|
||||
#
|
||||
# class BatchLoader:
|
||||
# def __init__(self, data):
|
||||
# self.cond_text = [entry.cond_text for entry in data]
|
||||
# self.cond = [entry.cond for entry in data]
|
||||
# self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||
# if all(entry.weight is not None for entry in data):
|
||||
# self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
|
||||
# else:
|
||||
# self.weight = None
|
||||
# #self.emb_index = [entry.emb_index for entry in data]
|
||||
# #print(self.latent_sample.device)
|
||||
#
|
||||
# def pin_memory(self):
|
||||
# self.latent_sample = self.latent_sample.pin_memory()
|
||||
# return self
|
||||
#
|
||||
# def collate_wrapper(batch):
|
||||
# return BatchLoader(batch)
|
||||
#
|
||||
# class BatchLoaderRandom(BatchLoader):
|
||||
# def __init__(self, data):
|
||||
# super().__init__(data)
|
||||
#
|
||||
# def pin_memory(self):
|
||||
# return self
|
||||
#
|
||||
# def collate_wrapper_random(batch):
|
||||
# return BatchLoaderRandom(batch)