From bccf9fb23a2cbc9d42091a856b3d35b2c4414d51 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Mon, 5 Aug 2024 03:58:34 -0700 Subject: [PATCH] Free WebUI from its Prison Congratulations WebUI. Say Hello to freedom. --- extensions-builtin/LDSR/ldsr_model_arch.py | 250 --- extensions-builtin/LDSR/preload.py | 6 - extensions-builtin/LDSR/scripts/ldsr_model.py | 70 - .../LDSR/sd_hijack_autoencoder.py | 293 ---- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 1443 ----------------- extensions-builtin/LDSR/vqvae_quantize.py | 147 -- .../legacy_preprocessors/preprocessor.py | 5 +- .../lib_controlnet/utils.py | 3 +- modules/api/api.py | 3 +- modules/hypernetworks/hypernetwork.py | 2 +- modules/initialize.py | 34 +- modules/launch_utils.py | 12 +- modules/paths.py | 45 +- modules/processing.py | 25 +- modules/safe.py | 390 ++--- modules/sd_disable_initialization.py | 464 +++--- modules/sd_hijack.py | 419 ++--- modules/sd_hijack_checkpoint.py | 92 +- modules/sd_hijack_optimizations.py | 1354 ++++++++-------- modules/sd_hijack_unet.py | 308 ++-- modules/sd_models.py | 81 +- modules/sd_models_config.py | 274 ++-- modules/sd_models_types.py | 3 +- modules/sd_models_xl.py | 230 +-- modules/shared_items.py | 4 +- modules/textual_inversion/dataset.py | 488 +++--- 26 files changed, 2053 insertions(+), 4392 deletions(-) delete mode 100644 extensions-builtin/LDSR/ldsr_model_arch.py delete mode 100644 extensions-builtin/LDSR/preload.py delete mode 100644 extensions-builtin/LDSR/scripts/ldsr_model.py delete mode 100644 extensions-builtin/LDSR/sd_hijack_autoencoder.py delete mode 100644 extensions-builtin/LDSR/sd_hijack_ddpm_v1.py delete mode 100644 extensions-builtin/LDSR/vqvae_quantize.py diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py deleted file mode 100644 index 7cac36ce..00000000 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ /dev/null @@ -1,250 +0,0 @@ -import os -import gc -import time - -import numpy as np -import torch -import torchvision -from PIL import Image -from einops import rearrange, repeat -from omegaconf import OmegaConf -import safetensors.torch - -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.util import instantiate_from_config, ismap -from modules import shared, sd_hijack, devices - -cached_ldsr_model: torch.nn.Module = None - - -# Create LDSR Class -class LDSR: - def load_model_from_config(self, half_attention): - global cached_ldsr_model - - if shared.opts.ldsr_cached and cached_ldsr_model is not None: - print("Loading model from cache") - model: torch.nn.Module = cached_ldsr_model - else: - print(f"Loading model from {self.modelPath}") - _, extension = os.path.splitext(self.modelPath) - if extension.lower() == ".safetensors": - pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu") - else: - pl_sd = torch.load(self.modelPath, map_location="cpu") - sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd - config = OmegaConf.load(self.yamlPath) - config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" - model: torch.nn.Module = instantiate_from_config(config.model) - model.load_state_dict(sd, strict=False) - model = model.to(shared.device) - if half_attention: - model = model.half() - if shared.cmd_opts.opt_channelslast: - model = model.to(memory_format=torch.channels_last) - - sd_hijack.model_hijack.hijack(model) # apply optimization - model.eval() - - if shared.opts.ldsr_cached: - cached_ldsr_model = model - - return {"model": model} - - 
def __init__(self, model_path, yaml_path): - self.modelPath = model_path - self.yamlPath = yaml_path - - @staticmethod - def run(model, selected_path, custom_steps, eta): - example = get_cond(selected_path) - - n_runs = 1 - guider = None - ckwargs = None - ddim_use_x0_pred = False - temperature = 1. - eta = eta - custom_shape = None - - height, width = example["image"].shape[1:3] - split_input = height >= 128 and width >= 128 - - if split_input: - ks = 128 - stride = 64 - vqf = 4 # - model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), - "vqf": vqf, - "patch_distributed_vq": True, - "tie_braker": False, - "clip_max_weight": 0.5, - "clip_min_weight": 0.01, - "clip_max_tie_weight": 0.5, - "clip_min_tie_weight": 0.01} - else: - if hasattr(model, "split_input_params"): - delattr(model, "split_input_params") - - x_t = None - logs = None - for _ in range(n_runs): - if custom_shape is not None: - x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) - x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) - - logs = make_convolutional_sample(example, model, - custom_steps=custom_steps, - eta=eta, quantize_x0=False, - custom_shape=custom_shape, - temperature=temperature, noise_dropout=0., - corrector=guider, corrector_kwargs=ckwargs, x_T=x_t, - ddim_use_x0_pred=ddim_use_x0_pred - ) - return logs - - def super_resolution(self, image, steps=100, target_scale=2, half_attention=False): - model = self.load_model_from_config(half_attention) - - # Run settings - diffusion_steps = int(steps) - eta = 1.0 - - - gc.collect() - devices.torch_gc() - - im_og = image - width_og, height_og = im_og.size - # If we can adjust the max upscale size, then the 4 below should be our variable - down_sample_rate = target_scale / 4 - wd = width_og * down_sample_rate - hd = height_og * down_sample_rate - width_downsampled_pre = int(np.ceil(wd)) - height_downsampled_pre = int(np.ceil(hd)) - - if down_sample_rate != 1: - print( - f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') - im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) - else: - print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") - - # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts - pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size - im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) - - logs = self.run(model["model"], im_padded, diffusion_steps, eta) - - sample = logs["sample"] - sample = sample.detach().cpu() - sample = torch.clamp(sample, -1., 1.) - sample = (sample + 1.) / 2. * 255 - sample = sample.numpy().astype(np.uint8) - sample = np.transpose(sample, (0, 2, 3, 1)) - a = Image.fromarray(sample[0]) - - # remove padding - a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4)) - - del model - gc.collect() - devices.torch_gc() - - return a - - -def get_cond(selected_path): - example = {} - up_f = 4 - c = selected_path.convert('RGB') - c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) - c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], - antialias=True) - c_up = rearrange(c_up, '1 c h w -> 1 h w c') - c = rearrange(c, '1 c h w -> 1 h w c') - c = 2. * c - 1. 
- - c = c.to(shared.device) - example["LR_image"] = c - example["image"] = c_up - - return example - - -@torch.no_grad() -def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, - mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None, - corrector_kwargs=None, x_t=None - ): - ddim = DDIMSampler(model) - bs = shape[0] - shape = shape[1:] - print(f"Sampling with eta = {eta}; steps: {steps}") - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, - normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, - mask=mask, x0=x0, temperature=temperature, verbose=False, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, x_t=x_t) - - return samples, intermediates - - -@torch.no_grad() -def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, - corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): - log = {} - - z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=not (hasattr(model, 'split_input_params') - and model.cond_stage_key == 'coordinates_bbox'), - return_original_cond=True) - - if custom_shape is not None: - z = torch.randn(custom_shape) - print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") - - z0 = None - - log["input"] = x - log["reconstruction"] = xrec - - if ismap(xc): - log["original_conditioning"] = model.to_rgb(xc) - if hasattr(model, 'cond_stage_key'): - log[model.cond_stage_key] = model.to_rgb(xc) - - else: - log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) - if model.cond_stage_model: - log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) - if model.cond_stage_key == 'class_label': - log[model.cond_stage_key] = xc[model.cond_stage_key] - - with model.ema_scope("Plotting"): - t0 = time.time() - - sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, - eta=eta, - quantize_x0=quantize_x0, mask=None, x0=z0, - temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs, - x_t=x_T) - t1 = time.time() - - if ddim_use_x0_pred: - sample = intermediates['pred_x0'][-1] - - x_sample = model.decode_first_stage(sample) - - try: - x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) - log["sample_noquant"] = x_sample_noquant - log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) - except Exception: - pass - - log["sample"] = x_sample - log["time"] = t1 - t0 - - return log diff --git a/extensions-builtin/LDSR/preload.py b/extensions-builtin/LDSR/preload.py deleted file mode 100644 index d746007c..00000000 --- a/extensions-builtin/LDSR/preload.py +++ /dev/null @@ -1,6 +0,0 @@ -import os -from modules import paths - - -def preload(parser): - parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR')) diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py deleted file mode 100644 index e9bb9271..00000000 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ /dev/null @@ -1,70 +0,0 @@ -import os - -from modules.modelloader import load_file_from_url -from modules.upscaler import Upscaler, UpscalerData -from modules_forge.utils import prepare_free_memory -from ldsr_model_arch import LDSR 
-from modules import shared, script_callbacks, errors -import sd_hijack_autoencoder # noqa: F401 -import sd_hijack_ddpm_v1 # noqa: F401 - - -class UpscalerLDSR(Upscaler): - def __init__(self, user_path): - self.name = "LDSR" - self.user_path = user_path - self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" - self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" - super().__init__() - scaler_data = UpscalerData("LDSR", None, self) - self.scalers = [scaler_data] - - def load_model(self, path: str): - # Remove incorrect project.yaml file if too big - yaml_path = os.path.join(self.model_path, "project.yaml") - old_model_path = os.path.join(self.model_path, "model.pth") - new_model_path = os.path.join(self.model_path, "model.ckpt") - - local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"]) - local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None) - local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None) - local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None) - - if os.path.exists(yaml_path): - statinfo = os.stat(yaml_path) - if statinfo.st_size >= 10485760: - print("Removing invalid LDSR YAML file.") - os.remove(yaml_path) - - if os.path.exists(old_model_path): - print("Renaming model from model.pth to model.ckpt") - os.rename(old_model_path, new_model_path) - - if local_safetensors_path is not None and os.path.exists(local_safetensors_path): - model = local_safetensors_path - else: - model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt") - - yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml") - - return LDSR(model, yaml) - - def do_upscale(self, img, path): - prepare_free_memory(aggressive=True) - try: - ldsr = self.load_model(path) - except Exception: - errors.report(f"Failed loading LDSR model {path}", exc_info=True) - return img - ddim_steps = shared.opts.ldsr_steps - return ldsr.super_resolution(img, ddim_steps, self.scale) - - -def on_ui_settings(): - import gradio as gr - - shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. 
Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) - shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling"))) - - -script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py deleted file mode 100644 index c29d274d..00000000 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ /dev/null @@ -1,293 +0,0 @@ -# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo -# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo -# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder -import numpy as np -import torch -import pytorch_lightning as pl -import torch.nn.functional as F -from contextlib import contextmanager - -from torch.optim.lr_scheduler import LambdaLR - -from ldm.modules.ema import LitEma -from vqvae_quantize import VectorQuantizer2 as VectorQuantizer -from ldm.modules.diffusionmodules.model import Encoder, Decoder -from ldm.util import instantiate_from_config - -import ldm.models.autoencoder -from packaging import version - -class VQModel(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=None, - image_key="image", - colorize_nlabels=None, - monitor=None, - batch_resize_range=None, - scheduler_config=None, - lr_g_factor=1.0, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - use_ema=False - ): - super().__init__() - self.embed_dim = embed_dim - self.n_embed = n_embed - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, - remap=remap, - sane_index_shape=sane_index_shape) - self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - self.batch_resize_range = batch_resize_range - if self.batch_resize_range is not None: - print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") - - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or []) - self.scheduler_config = scheduler_config - self.lr_g_factor = lr_g_factor - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def init_from_ckpt(self, path, ignore_keys=None): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in 
keys: - for ik in ignore_keys or []: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if missing: - print(f"Missing Keys: {missing}") - if unexpected: - print(f"Unexpected Keys: {unexpected}") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - quant, emb_loss, info = self.quantize(h) - return quant, emb_loss, info - - def encode_to_prequant(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, quant): - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - def decode_code(self, code_b): - quant_b = self.quantize.embed_code(code_b) - dec = self.decode(quant_b) - return dec - - def forward(self, input, return_pred_indices=False): - quant, diff, (_,_,ind) = self.encode(input) - dec = self.decode(quant) - if return_pred_indices: - return dec, diff, ind - return dec, diff - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - if self.batch_resize_range is not None: - lower_size = self.batch_resize_range[0] - upper_size = self.batch_resize_range[1] - if self.global_step <= 4: - # do the first few batches with max size to avoid later oom - new_resize = upper_size - else: - new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) - if new_resize != x.shape[2]: - x = F.interpolate(x, size=new_resize, mode="bicubic") - x = x.detach() - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - # https://github.com/pytorch/pytorch/issues/37142 - # try not to fool the heuristics - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - - if optimizer_idx == 0: - # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train", - predicted_indices=ind) - - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return aeloss - - if optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - self._validation_step(batch, batch_idx, suffix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, suffix=""): - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] - self.log(f"val{suffix}/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - self.log(f"val{suffix}/aeloss", aeloss, - prog_bar=True, logger=True, 
on_step=False, on_epoch=True, sync_dist=True) - if version.parse(pl.__version__) >= version.parse('1.4.0'): - del log_dict_ae[f"val{suffix}/rec_loss"] - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr_d = self.learning_rate - lr_g = self.lr_g_factor*self.learning_rate - print("lr_d", lr_d) - print("lr_g", lr_g) - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr_g, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr_d, betas=(0.5, 0.9)) - - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }, - { - 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }, - ] - return [opt_ae, opt_disc], scheduler - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = {} - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if only_inputs: - log["inputs"] = x - return log - xrec, _ = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["inputs"] = x - log["reconstructions"] = xrec - if plot_ema: - with self.ema_scope(): - xrec_ema, _ = self(x) - if x.shape[1] > 3: - xrec_ema = self.to_rgb(xrec_ema) - log["reconstructions_ema"] = xrec_ema - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
- return x - - -class VQModelInterface(VQModel): - def __init__(self, embed_dim, *args, **kwargs): - super().__init__(*args, embed_dim=embed_dim, **kwargs) - self.embed_dim = embed_dim - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - -ldm.models.autoencoder.VQModel = VQModel -ldm.models.autoencoder.VQModelInterface = VQModelInterface diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py deleted file mode 100644 index 51ab1821..00000000 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ /dev/null @@ -1,1443 +0,0 @@ -# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo) -# Original filename: ldm/models/diffusion/ddpm.py -# The purpose to reinstate the old DDPM logic which works with VQ, whereas the V2 one doesn't -# Some models such as LDSR require VQ to work correctly -# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module - -import torch -import torch.nn as nn -import numpy as np -import pytorch_lightning as pl -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange, repeat -from contextlib import contextmanager -from functools import partial -from tqdm import tqdm -from torchvision.utils import make_grid -from pytorch_lightning.utilities.distributed import rank_zero_only - -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from ldm.models.diffusion.ddim import DDIMSampler - -import ldm.models.diffusion.ddpm - -__conditioning_keys__ = {'concat': 'c_concat', - 'crossattn': 'c_crossattn', - 'adm': 'y'} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def uniform_on_device(r1, r2, shape, device): - return (r1 - r2) * torch.rand(*shape, device=device) + r2 - - -class DDPMV1(pl.LightningModule): - # classic DDPM with Gaussian diffusion, in image space - def __init__(self, - unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=None, - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0., - ): - super().__init__() - assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' - self.parameterization = parameterization - print(f"{self.__class__.__name__}: Running in 
{self.parameterization}-prediction mode") - self.cond_stage_model = None - self.clip_denoised = clip_denoised - self.log_every_t = log_every_t - self.first_stage_key = first_stage_key - self.image_size = image_size # try conv? - self.channels = channels - self.use_positional_encodings = use_positional_encodings - self.model = DiffusionWrapperV1(unet_config, conditioning_key) - count_params(self.model, verbose=True) - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.use_scheduler = scheduler_config is not None - if self.use_scheduler: - self.scheduler_config = scheduler_config - - self.v_posterior = v_posterior - self.original_elbo_weight = original_elbo_weight - self.l_simple_weight = l_simple_weight - - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet) - - self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, - linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) - - self.loss_type = loss_type - - self.learn_logvar = learn_logvar - self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) - if self.learn_logvar: - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - - - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if exists(given_betas): - betas = given_betas - else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1. 
- alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', to_torch( - (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) - - if self.parameterization == "eps": - lvlb_weights = self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) - elif self.parameterization == "x0": - lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) - else: - raise NotImplementedError("mu not supported") - # TODO how to choose this term - lvlb_weights[0] = lvlb_weights[1] - self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) - assert not torch.isnan(self.lvlb_weights).all() - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def init_from_ckpt(self, path, ignore_keys=None, only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys or []: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if missing: - print(f"Missing Keys: {missing}") - if unexpected: - print(f"Unexpected Keys: {unexpected}") - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. - """ - mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) - variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, x, t, clip_denoised: bool): - model_out = self.model(x, t) - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - if clip_denoised: - x_recon.clamp_(-1., 1.) 
- - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def p_sample_loop(self, shape, return_intermediates=False): - device = self.betas.device - b = shape[0] - img = torch.randn(shape, device=device) - intermediates = [img] - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised) - if i % self.log_every_t == 0 or i == self.num_timesteps - 1: - intermediates.append(img) - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, batch_size=16, return_intermediates=False): - image_size = self.image_size - channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) - - def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': - loss = (target - pred).abs() - if mean: - loss = loss.mean() - elif self.loss_type == 'l2': - if mean: - loss = torch.nn.functional.mse_loss(target, pred) - else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') - else: - raise NotImplementedError("unknown loss type '{loss_type}'") - - return loss - - def p_losses(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_out = self.model(x_noisy, t) - - loss_dict = {} - if self.parameterization == "eps": - target = noise - elif self.parameterization == "x0": - target = x_start - else: - raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") - - loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - - log_prefix = 'train' if self.training else 'val' - - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) - loss_simple = loss.mean() * self.l_simple_weight - - loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) - - loss = loss_simple + self.original_elbo_weight * loss_vlb - - loss_dict.update({f'{log_prefix}/loss': loss}) - - return loss, loss_dict - - def forward(self, x, *args, **kwargs): - # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size - # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() - return self.p_losses(x, t, *args, **kwargs) - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') - x = 
x.to(memory_format=torch.contiguous_format).float() - return x - - def shared_step(self, batch): - x = self.get_input(batch, self.first_stage_key) - loss, loss_dict = self(x) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - loss, loss_dict = self.shared_step(batch) - - self.log_dict(loss_dict, prog_bar=True, - logger=True, on_step=True, on_epoch=True) - - self.log("global_step", self.global_step, - prog_bar=True, logger=True, on_step=True, on_epoch=False) - - if self.use_scheduler: - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) - - return loss - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - _, loss_dict_no_ema = self.shared_step(batch) - with self.ema_scope(): - _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} - self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) - self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - def _get_rows_from_list(self, samples): - n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = {} - x = self.get_input(batch, self.first_stage_key) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - x = x.to(self.device)[:N] - log["inputs"] = x - - # get diffusion row - diffusion_row = [] - x_start = x[:n_row] - - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(x_start) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - diffusion_row.append(x_noisy) - - log["diffusion_row"] = self._get_rows_from_list(diffusion_row) - - if sample: - # get denoise row - with self.ema_scope("Plotting"): - samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) - - log["samples"] = samples - log["denoise_row"] = self._get_rows_from_list(denoise_row) - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.learn_logvar: - params = params + [self.logvar] - opt = torch.optim.AdamW(params, lr=lr) - return opt - - -class LatentDiffusionV1(DDPMV1): - """main class""" - def __init__(self, - first_stage_config, - cond_stage_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - *args, **kwargs): - self.num_timesteps_cond = default(num_timesteps_cond, 1) - self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs['timesteps'] - # for backwards compatibility after implementation of DiffusionWrapper - if conditioning_key is None: - conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__': - 
conditioning_key = None - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(*args, conditioning_key=conditioning_key, **kwargs) - self.concat_mode = concat_mode - self.cond_stage_trainable = cond_stage_trainable - self.cond_stage_key = cond_stage_key - try: - self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except Exception: - self.num_downs = 0 - if not scale_by_std: - self.scale_factor = scale_factor - else: - self.register_buffer('scale_factor', torch.tensor(scale_factor)) - self.instantiate_first_stage(first_stage_config) - self.instantiate_cond_stage(cond_stage_config) - self.cond_stage_forward = cond_stage_forward - self.clip_denoised = False - self.bbox_tokenizer = None - - self.restarted_from_ckpt = False - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys) - self.restarted_from_ckpt = True - - def make_cond_schedule(self, ): - self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) - ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() - self.cond_ids[:self.num_timesteps_cond] = ids - - @rank_zero_only - @torch.no_grad() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - # only for very first batch - if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: - assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' - # set rescale weight to 1./std of encodings - print("### USING STD-RESCALING ###") - x = super().get_input(batch, self.first_stage_key) - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - del self.scale_factor - self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) - print(f"setting self.scale_factor to {self.scale_factor}") - print("### USING STD-RESCALING ###") - - def register_schedule(self, - given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) - - self.shorten_cond_schedule = self.num_timesteps_cond > 1 - if self.shorten_cond_schedule: - self.make_cond_schedule() - - def instantiate_first_stage(self, config): - model = instantiate_from_config(config) - self.first_stage_model = model.eval() - self.first_stage_model.train = disabled_train - for param in self.first_stage_model.parameters(): - param.requires_grad = False - - def instantiate_cond_stage(self, config): - if not self.cond_stage_trainable: - if config == "__is_first_stage__": - print("Using first stage also as cond stage.") - self.cond_stage_model = self.first_stage_model - elif config == "__is_unconditional__": - print(f"Training {self.__class__.__name__} as an unconditional model.") - self.cond_stage_model = None - # self.be_unconditional = True - else: - model = instantiate_from_config(config) - self.cond_stage_model = model.eval() - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - else: - assert config != '__is_first_stage__' - assert config != '__is_unconditional__' - model = instantiate_from_config(config) - self.cond_stage_model = model - - def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): - denoise_row = [] - for zd in tqdm(samples, desc=desc): - denoise_row.append(self.decode_first_stage(zd.to(self.device), - force_not_quantize=force_no_decoder_quantization)) - n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - def get_first_stage_encoding(self, encoder_posterior): - if isinstance(encoder_posterior, DiagonalGaussianDistribution): - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") - return self.scale_factor * z - - def get_learned_conditioning(self, c): - if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): - c = self.cond_stage_model.encode(c) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - else: - c = self.cond_stage_model(c) - else: - assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) - return c - - def meshgrid(self, h, w): - y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) - x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) - - arr = torch.cat([y, x], dim=-1) - return arr - - def delta_border(self, h, w): - """ - :param h: height - :param w: width - :return: normalized distance to image border, - with min distance = 0 at border and max dist = 0.5 at image center - """ - lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) - arr = self.meshgrid(h, w) / lower_right_corner - dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] - dist_right_down = 
torch.min(1 - arr, dim=-1, keepdims=True)[0] - edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] - return edge_dist - - def get_weighting(self, h, w, Ly, Lx, device): - weighting = self.delta_border(h, w) - weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], - self.split_input_params["clip_max_weight"], ) - weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) - - if self.split_input_params["tie_braker"]: - L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, - self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"]) - - L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) - weighting = weighting * L_weighting - return weighting - - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code - """ - :param x: img of size (bs, c, h, w) - :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) - """ - bs, nc, h, w = x.shape - - # number of crops in image - Ly = (h - kernel_size[0]) // stride[0] + 1 - Lx = (w - kernel_size[1]) // stride[1] + 1 - - if uf == 1 and df == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) - - weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) - - elif uf > 1 and df == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, padding=0, - stride=(stride[0] * uf, stride[1] * uf)) - fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) - - weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) - - elif df > 1 and uf == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, padding=0, - stride=(stride[0] // df, stride[1] // df)) - fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) - - weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) - - else: - raise NotImplementedError - - return fold, unfold, normalization, weighting - - @torch.no_grad() - def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, - cond_key=None, return_original_cond=False, bs=None): - x = super().get_input(batch, k) - if bs is not None: - x = x[:bs] - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - - if self.model.conditioning_key is not None: - if cond_key is None: - cond_key = 
self.cond_stage_key - if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox']: - xc = batch[cond_key] - elif cond_key == 'class_label': - xc = batch - else: - xc = super().get_input(batch, cond_key).to(self.device) - else: - xc = x - if not self.cond_stage_trainable or force_c_encode: - if isinstance(xc, dict) or isinstance(xc, list): - # import pudb; pudb.set_trace() - c = self.get_learned_conditioning(xc) - else: - c = self.get_learned_conditioning(xc.to(self.device)) - else: - c = xc - if bs is not None: - c = c[:bs] - - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} - - else: - c = None - xc = None - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - c = {'pos_x': pos_x, 'pos_y': pos_y} - out = [z, c] - if return_first_stage_outputs: - xrec = self.decode_first_stage(z) - out.extend([x, xrec]) - if return_original_cond: - out.append(xc) - return out - - @torch.no_grad() - def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1. / self.scale_factor * z - - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] - else: - - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. 
reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - # same as above but without decorator - def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1. / self.scale_factor * z - - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] - else: - - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - @torch.no_grad() - def encode_first_stage(self, x): - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. 
(64, 64) - df = self.split_input_params["vqf"] - self.split_input_params['original_image_size'] = x.shape[-2:] - bs, nc, h, w = x.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) - z = unfold(x) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) - o = o * weighting - - # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization - return decoded - - else: - return self.first_stage_model.encode(x) - else: - return self.first_stage_model.encode(x) - - def shared_step(self, batch, **kwargs): - x, c = self.get_input(batch, self.first_stage_key) - loss = self(x, c) - return loss - - def forward(self, x, c, *args, **kwargs): - t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() - if self.model.conditioning_key is not None: - assert c is not None - if self.cond_stage_trainable: - c = self.get_learned_conditioning(c) - if self.shorten_cond_schedule: # TODO: drop this option - tc = self.cond_ids[t].to(self.device) - c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) - return self.p_losses(x, c, t, *args, **kwargs) - - def apply_model(self, x_noisy, t, cond, return_ids=False): - - if isinstance(cond, dict): - # hybrid case, cond is expected to be a dict - pass - else: - if not isinstance(cond, list): - cond = [cond] - key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' - cond = {key: cond} - - if hasattr(self, "split_input_params"): - assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. 
(64, 64) - - h, w = x_noisy.shape[-2:] - - fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) - - z = unfold(x_noisy) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] - - if self.cond_stage_key in ["image", "LR_image", "segmentation", - 'bbox_img'] and self.model.conditioning_key: # todo check for completeness - c_key = next(iter(cond.keys())) # get key - c = next(iter(cond.values())) # get value - assert (len(c) == 1) # todo extend to list with more than one elem - c = c[0] # get element - - c = unfold(c) - c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] - - elif self.cond_stage_key == 'coordinates_bbox': - assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size' - - # assuming padding of unfold is always 0 and its dilation is always 1 - n_patches_per_row = int((w - ks[0]) / stride[0] + 1) - full_img_h, full_img_w = self.split_input_params['original_image_size'] - # as we are operating on latents, we need the factor from the original image size to the - # spatial latent size to properly rescale the crops for regenerating the bbox annotations - num_downs = self.first_stage_model.encoder.num_resolutions - 1 - rescale_latent = 2 ** (num_downs) - - # get top left positions of patches as conforming for the bbbox tokenizer, therefore we - # need to rescale the tl patch coordinates to be in between (0,1) - tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, - rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) - for patch_nr in range(z.shape[-1])] - - # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) - patch_limits = [(x_tl, y_tl, - rescale_latent * ks[0] / full_img_w, - rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] - # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] - - # tokenize crop coordinates for the bounding boxes of the respective patches - patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) - for bbox in patch_limits] # list of length l with tensors of shape (1, 2) - print(patch_limits_tknzd[0].shape) - # cut tknzd crop position from conditioning - assert isinstance(cond, dict), 'cond must be dict to be fed into model' - cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) - print(cut_cond.shape) - - adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) - adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') - print(adapted_cond.shape) - adapted_cond = self.get_learned_conditioning(adapted_cond) - print(adapted_cond.shape) - adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) - print(adapted_cond.shape) - - cond_list = [{'c_crossattn': [e]} for e in adapted_cond] - - else: - cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient - - # apply model by loop over crops - output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] - assert not isinstance(output_list[0], - tuple) # todo cant deal with multiple model outputs check this never happens - - o = torch.stack(output_list, 
axis=-1) - o = o * weighting - # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - x_recon = fold(o) / normalization - - else: - x_recon = self.model(x_noisy, t, **cond) - - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon - - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - - def _prior_bpd(self, x_start): - """ - Get the prior KL term for the variational lower-bound, measured in - bits-per-dim. - This term can't be optimized, as it only depends on the encoder. - :param x_start: the [N x C x ...] tensor of inputs. - :return: a batch of [N] KL values (in bits), one per batch element. - """ - batch_size = x_start.shape[0] - t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) - return mean_flat(kl_prior) / np.log(2.0) - - def p_losses(self, x_start, cond, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_output = self.apply_model(x_noisy, t, cond) - - loss_dict = {} - prefix = 'train' if self.training else 'val' - - if self.parameterization == "x0": - target = x_start - elif self.parameterization == "eps": - target = noise - else: - raise NotImplementedError() - - loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) - - logvar_t = self.logvar[t].to(self.device) - loss = loss_simple / torch.exp(logvar_t) + logvar_t - # loss = loss_simple / torch.exp(self.logvar) + self.logvar - if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) - - loss = self.l_simple_weight * loss.mean() - - loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) - loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += (self.original_elbo_weight * loss_vlb) - loss_dict.update({f'{prefix}/loss': loss}) - - return loss, loss_dict - - def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, - return_x0=False, score_corrector=None, corrector_kwargs=None): - t_in = t - model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) - - if score_corrector is not None: - assert self.parameterization == "eps" - model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) - - if return_codebook_ids: - model_out, logits = model_out - - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - else: - raise NotImplementedError() - - if clip_denoised: - x_recon.clamp_(-1., 1.) 
- if quantize_denoised: - x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) - if return_codebook_ids: - return model_mean, posterior_variance, posterior_log_variance, logits - elif return_x0: - return model_mean, posterior_variance, posterior_log_variance, x_recon - else: - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, - return_codebook_ids=False, quantize_denoised=False, return_x0=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): - b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) - if return_codebook_ids: - raise DeprecationWarning("Support dropped.") - model_mean, _, model_log_variance, logits = outputs - elif return_x0: - model_mean, _, model_log_variance, x0 = outputs - else: - model_mean, _, model_log_variance = outputs - - noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - - if return_codebook_ids: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) - if return_x0: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 - else: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, - img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., - score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, - log_every_t=None): - if not log_every_t: - log_every_t = self.log_every_t - timesteps = self.num_timesteps - if batch_size is not None: - b = batch_size if batch_size is not None else shape[0] - shape = [batch_size] + list(shape) - else: - b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T - intermediates = [] - if cond is not None: - if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - [x[:batch_size] for x in cond[key]] for key in cond} - else: - cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - total=timesteps) if verbose else reversed( - range(0, timesteps)) - if type(temperature) == float: - temperature = [temperature] * timesteps - - for i in iterator: - ts = torch.full((b,), i, device=self.device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img, x0_partial = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, return_x0=True, - temperature=temperature[i], noise_dropout=noise_dropout, - 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(x0_partial) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - return img, intermediates - - @torch.no_grad() - def p_sample_loop(self, cond, shape, return_intermediates=False, - x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, start_T=None, - log_every_t=None): - - if not log_every_t: - log_every_t = self.log_every_t - device = self.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - intermediates = [img] - if timesteps is None: - timesteps = self.num_timesteps - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( - range(0, timesteps)) - - if mask is not None: - assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match - - for i in iterator: - ts = torch.full((b,), i, device=device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised) - if mask is not None: - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(img) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, - verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None,**kwargs): - if shape is None: - shape = (batch_size, self.channels, self.image_size, self.image_size) - if cond is not None: - if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - [x[:batch_size] for x in cond[key]] for key in cond} - else: - cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - return self.p_sample_loop(cond, - shape, - return_intermediates=return_intermediates, x_T=x_T, - verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, - mask=mask, x0=x0) - - @torch.no_grad() - def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): - - if ddim: - ddim_sampler = DDIMSampler(self) - shape = (self.channels, self.image_size, self.image_size) - samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, - shape,cond,verbose=False,**kwargs) - - else: - samples, intermediates = self.sample(cond=cond, batch_size=batch_size, - return_intermediates=True,**kwargs) - - return samples, intermediates - - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, **kwargs): - - use_ddim = ddim_steps is not None - - log = {} - z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, - 
return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - log["inputs"] = x - log["reconstruction"] = xrec - if self.model.conditioning_key is not None: - if hasattr(self.cond_stage_model, "decode"): - xc = self.cond_stage_model.decode(c) - log["conditioning"] = xc - elif self.cond_stage_key in ["caption"]: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) - log["conditioning"] = xc - elif self.cond_stage_key == 'class_label': - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) - log['conditioning'] = xc - elif isimage(xc): - log["conditioning"] = xc - if ismap(xc): - log["original_conditioning"] = self.to_rgb(xc) - - if plot_diffusion_rows: - # get diffusion row - diffusion_row = [] - z_start = z[:n_row] - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(z_start) - z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - diffusion_row.append(self.decode_first_stage(z_noisy)) - - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') - diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - log["diffusion_row"] = diffusion_grid - - if sample: - # get denoise row - with self.ema_scope("Plotting"): - samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, - ddim_steps=ddim_steps,eta=ddim_eta) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) - x_samples = self.decode_first_stage(samples) - log["samples"] = x_samples - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row"] = denoise_grid - - if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( - self.first_stage_model, IdentityFirstStage): - # also display when quantizing x0 while sampling - with self.ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, - ddim_steps=ddim_steps,eta=ddim_eta, - quantize_denoised=True) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, - # quantize_denoised=True) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_x0_quantized"] = x_samples - - if inpaint: - # make a simple center square - h, w = z.shape[2], z.shape[3] - mask = torch.ones(N, h, w).to(self.device) - # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. - mask = mask[:, None, ...] 
- with self.ema_scope("Plotting Inpaint"): - - samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_inpainting"] = x_samples - log["mask"] = mask - - # outpaint - with self.ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_outpainting"] = x_samples - - if plot_progressive_rows: - with self.ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) - prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") - log["progressive_row"] = prog_row - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") - params = params + list(self.cond_stage_model.parameters()) - if self.learn_logvar: - print('Diffusion model optimizing logvar') - params.append(self.logvar) - opt = torch.optim.AdamW(params, lr=lr) - if self.use_scheduler: - assert 'target' in self.scheduler_config - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] - return [opt], scheduler - return opt - - @torch.no_grad() - def to_rgb(self, x): - x = x.float() - if not hasattr(self, "colorize"): - self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) - x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
- return x - - -class DiffusionWrapperV1(pl.LightningModule): - def __init__(self, diff_model_config, conditioning_key): - super().__init__() - self.diffusion_model = instantiate_from_config(diff_model_config) - self.conditioning_key = conditioning_key - assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] - - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): - if self.conditioning_key is None: - out = self.diffusion_model(x, t) - elif self.conditioning_key == 'concat': - xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) - elif self.conditioning_key == 'crossattn': - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == 'hybrid': - xc = torch.cat([x] + c_concat, dim=1) - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == 'adm': - cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) - else: - raise NotImplementedError() - - return out - - -class Layout2ImgDiffusionV1(LatentDiffusionV1): - # TODO: move all layout-specific hacks to this class - def __init__(self, cond_stage_key, *args, **kwargs): - assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) - - def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(*args, batch=batch, N=N, **kwargs) - - key = 'train' if self.training else 'validation' - dset = self.trainer.datamodule.datasets[key] - mapper = dset.conditional_builders[self.cond_stage_key] - - bbox_imgs = [] - map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) - for tknzd_bbox in batch[self.cond_stage_key][:N]: - bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) - bbox_imgs.append(bboximg) - - cond_img = torch.stack(bbox_imgs, dim=0) - logs['bbox_image'] = cond_img - return logs - -ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1 -ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1 -ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1 -ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1 diff --git a/extensions-builtin/LDSR/vqvae_quantize.py b/extensions-builtin/LDSR/vqvae_quantize.py deleted file mode 100644 index dd14b8fd..00000000 --- a/extensions-builtin/LDSR/vqvae_quantize.py +++ /dev/null @@ -1,147 +0,0 @@ -# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py, -# where the license is as follows: -# -# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
-# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR -# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE -# OR OTHER DEALINGS IN THE SOFTWARE./ - -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange - - -class VectorQuantizer2(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly - avoids costly matrix multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", - sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices.") - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" - assert rescale_logits is False, "Only for interface compatible with Gumbel" - assert return_logits is False, "Only for interface compatible with Gumbel" - # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, 'b c h w -> b h w c').contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ - torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ - torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \ - torch.mean((z_q - z.detach()) ** 2) - else: - loss = 
torch.mean((z_q.detach() - z) ** 2) + self.beta * \ - torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape( - z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q diff --git a/extensions-builtin/forge_legacy_preprocessors/legacy_preprocessors/preprocessor.py b/extensions-builtin/forge_legacy_preprocessors/legacy_preprocessors/preprocessor.py index af6738b3..833d65e8 100644 --- a/extensions-builtin/forge_legacy_preprocessors/legacy_preprocessors/preprocessor.py +++ b/extensions-builtin/forge_legacy_preprocessors/legacy_preprocessors/preprocessor.py @@ -11,9 +11,12 @@ from transformers.models.clip.modeling_clip import CLIPVisionModelOutput from annotator.util import HWC3 from typing import Callable, Tuple, Union -from modules.safe import Extra from modules import devices +import contextlib + +Extra = lambda x: contextlib.nullcontext() + def torch_handler(module: str, name: str): """ Allow all torch access. Bypass A1111 safety whitelist. 
""" diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py index e5bf5277..9a81f4e6 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py @@ -19,7 +19,6 @@ import cv2 import logging from typing import Any, Callable, Dict, List -from modules.safe import unsafe_torch_load from lib_controlnet.logging import logger @@ -28,7 +27,7 @@ def load_state_dict(ckpt_path, location="cpu"): if extension.lower() == ".safetensors": state_dict = safetensors.torch.load_file(ckpt_path, device=location) else: - state_dict = unsafe_torch_load(ckpt_path, map_location=torch.device(location)) + state_dict = torch.load(ckpt_path, map_location=torch.device(location)) state_dict = get_state_dict(state_dict) logger.info(f"Loaded state_dict from [{ckpt_path}]") return state_dict diff --git a/modules/api/api.py b/modules/api/api.py index 78d10969..25ce7ca0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -24,7 +24,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin -from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import Any @@ -725,7 +724,7 @@ class Api: def get_sd_models(self): import modules.sd_models as sd_models - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()] def get_sd_vaes(self): import modules.sd_vae as sd_vae diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 17454665..2b5205a1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -9,7 +9,7 @@ import modules.textual_inversion.dataset import torch import tqdm from einops import rearrange, repeat -from ldm.util import default +from backend.nn.unet import default from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, saving_settings from modules.textual_inversion.learn_schedule import LearnRateScheduler diff --git a/modules/initialize.py b/modules/initialize.py index ec4d58a4..3cf71c66 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -1,25 +1,12 @@ import importlib import logging -import os import sys import warnings import os -from threading import Thread - from modules.timer import startup_timer -class HiddenPrints: - def __enter__(self): - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, 'w') - - def __exit__(self, exc_type, exc_val, exc_tb): - sys.stdout.close() - sys.stdout = self._original_stdout - - def imports(): logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... 
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -35,16 +22,8 @@ def imports(): import gradio # noqa: F401 startup_timer.record("import gradio") - with HiddenPrints(): - from modules import paths, timer, import_hook, errors # noqa: F401 - startup_timer.record("setup paths") - - import ldm.modules.encoders.modules # noqa: F401 - import ldm.modules.diffusionmodules.model - startup_timer.record("import ldm") - - import sgm.modules.encoders.modules # noqa: F401 - startup_timer.record("import sgm") + from modules import paths, timer, import_hook, errors # noqa: F401 + startup_timer.record("setup paths") from modules import shared_init shared_init.initialize() @@ -137,15 +116,6 @@ def initialize_rest(*, reload_script_modules=False): sd_vae.refresh_vae_list() startup_timer.record("refresh VAE") - from modules import textual_inversion - textual_inversion.textual_inversion.list_textual_inversion_templates() - startup_timer.record("refresh textual inversion templates") - - from modules import script_callbacks, sd_hijack_optimizations, sd_hijack - script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers) - sd_hijack.list_optimizers() - startup_timer.record("scripts list_optimizers") - from modules import sd_unet sd_unet.list_unets() startup_timer.record("scripts list_unets") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index f933de64..8c1823a1 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -391,15 +391,15 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") - stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") + # stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") + # stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") - stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") - stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") + # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") + # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', 
"48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -456,8 +456,8 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) - git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) - git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) + # git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + # git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index 501ff658..83bbbc79 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -2,45 +2,15 @@ import os import sys from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401 -import modules.safe # noqa: F401 - -def mute_sdxl_imports(): - """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" - - class Dummy: - pass - - module = Dummy() - module.LPIPS = None - sys.modules['taming.modules.losses.lpips'] = module - - module = Dummy() - module.StableDataModuleFromConfig = None - sys.modules['sgm.data'] = module - - -# data_path = cmd_opts_pre.data sys.path.insert(0, script_path) -# search for directory of stable diffusion in following places -sd_path = None -possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)] -for possible_sd_path in possible_sd_paths: - if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): - sd_path = os.path.abspath(possible_sd_path) - break - -assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}" - -mute_sdxl_imports() +sd_path = os.path.dirname(__file__) path_dirs = [ - (sd_path, 'ldm', 'Stable Diffusion', []), - (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), - (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), - (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), - (os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []), + (os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []), + (os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), + (os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []), ] paths = {} @@ -53,13 +23,6 @@ for d, must_exist, what, options in path_dirs: d = os.path.abspath(d) if "atstart" in options: sys.path.insert(0, d) - elif "sgm" in options: - # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we - # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir. 
- - sys.path.insert(0, d) - import sgm # noqa: F401 - sys.path.pop(0) else: sys.path.append(d) paths[what] = d diff --git a/modules/processing.py b/modules/processing.py index 2d0f13fa..09cd8d9a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -28,8 +28,6 @@ import modules.images as images import modules.styles import modules.sd_models as sd_models import modules.sd_vae as sd_vae -from ldm.data.util import AddMiDaS -from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType @@ -295,23 +293,7 @@ class StableDiffusionProcessing: return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height) def depth2img_image_conditioning(self, source_image): - # Use the AddMiDaS helper to Format our source image to suit the MiDaS model - transformer = AddMiDaS(model_type="dpt_hybrid") - transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")}) - midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) - midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - - conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) - conditioning = torch.nn.functional.interpolate( - self.sd_model.depth_model(midas_in), - size=conditioning_image.shape[2:], - mode="bicubic", - align_corners=False, - ) - - (depth_min, depth_max) = torch.aminmax(conditioning) - conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1. - return conditioning + raise NotImplementedError('NotImplementedError: depth2img_image_conditioning') def edit_image_conditioning(self, source_image): conditioning_image = shared.sd_model.encode_first_stage(source_image).mode() @@ -368,11 +350,6 @@ class StableDiffusionProcessing: def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): source_image = devices.cond_cast_float(source_image) - # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely - # identify itself with a field common to all models. The conditioning_key is also hybrid. 
- if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image) - if self.sd_model.cond_stage_key == "edit": return self.edit_image_conditioning(source_image) diff --git a/modules/safe.py b/modules/safe.py index d1e242e8..c483e2a8 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -1,195 +1,195 @@ -# this code is adapted from the script contributed by anon from /h/ - -import pickle -import collections - -import torch -import numpy -import _codecs -import zipfile -import re - - -# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage -from modules import errors - -TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage - -def encode(*args): - out = _codecs.encode(*args) - return out - - -class RestrictedUnpickler(pickle.Unpickler): - extra_handler = None - - def persistent_load(self, saved_id): - assert saved_id[0] == 'storage' - - try: - return TypedStorage(_internal=True) - except TypeError: - return TypedStorage() # PyTorch before 2.0 does not have the _internal argument - - def find_class(self, module, name): - if self.extra_handler is not None: - res = self.extra_handler(module, name) - if res is not None: - return res - - if module == 'collections' and name == 'OrderedDict': - return getattr(collections, name) - if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: - return getattr(torch._utils, name) - if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']: - return getattr(torch, name) - if module == 'torch.nn.modules.container' and name in ['ParameterDict']: - return getattr(torch.nn.modules.container, name) - if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']: - return getattr(numpy.core.multiarray, name) - if module == 'numpy' and name in ['dtype', 'ndarray']: - return getattr(numpy, name) - if module == '_codecs' and name == 'encode': - return encode - if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': - import pytorch_lightning.callbacks - return pytorch_lightning.callbacks.model_checkpoint - if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': - import pytorch_lightning.callbacks.model_checkpoint - return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint - if module == "__builtin__" and name == 'set': - return set - - # Forbid everything else. 
- raise Exception(f"global '{module}/{name}' is forbidden") - - -# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/' -allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$") -data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") - -def check_zip_filenames(filename, names): - for name in names: - if allowed_zip_names_re.match(name): - continue - - raise Exception(f"bad file inside {filename}: {name}") - - -def check_pt(filename, extra_handler): - try: - - # new pytorch format is a zip file - with zipfile.ZipFile(filename) as z: - check_zip_filenames(filename, z.namelist()) - - # find filename of data.pkl in zip file: '/data.pkl' - data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] - if len(data_pkl_filenames) == 0: - raise Exception(f"data.pkl not found in {filename}") - if len(data_pkl_filenames) > 1: - raise Exception(f"Multiple data.pkl found in {filename}") - with z.open(data_pkl_filenames[0]) as file: - unpickler = RestrictedUnpickler(file) - unpickler.extra_handler = extra_handler - unpickler.load() - - except zipfile.BadZipfile: - - # if it's not a zip file, it's an old pytorch format, with five objects written to pickle - with open(filename, "rb") as file: - unpickler = RestrictedUnpickler(file) - unpickler.extra_handler = extra_handler - for _ in range(5): - unpickler.load() - - -def load(filename, *args, **kwargs): - return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs) - - -def load_with_extra(filename, extra_handler=None, *args, **kwargs): - """ - this function is intended to be used by extensions that want to load models with - some extra classes in them that the usual unpickler would find suspicious. - - Use the extra_handler argument to specify a function that takes module and field name as text, - and returns that field's value: - - ```python - def extra(module, name): - if module == 'collections' and name == 'OrderedDict': - return collections.OrderedDict - - return None - - safe.load_with_extra('model.pt', extra_handler=extra) - ``` - - The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is - definitely unsafe. - """ - - from modules import shared - - try: - if not shared.cmd_opts.disable_safe_unpickle: - check_pt(filename, extra_handler) - - except pickle.UnpicklingError: - errors.report( - f"Error verifying pickled file from {filename}\n" - "-----> !!!! The file is most likely corrupted !!!! <-----\n" - "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", - exc_info=True, - ) - return None - except Exception: - errors.report( - f"Error verifying pickled file from {filename}\n" - f"The file may be malicious, so the program is not going to read it.\n" - f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", - exc_info=True, - ) - return None - - return unsafe_torch_load(filename, *args, **kwargs) - - -class Extra: - """ - A class for temporarily setting the global handler for when you can't explicitly call load_with_extra - (because it's not your code making the torch.load call). 
The intended use is like this: - -``` -import torch -from modules import safe - -def handler(module, name): - if module == 'torch' and name in ['float64', 'float16']: - return getattr(torch, name) - - return None - -with safe.Extra(handler): - x = torch.load('model.pt') -``` - """ - - def __init__(self, handler): - self.handler = handler - - def __enter__(self): - global global_extra_handler - - assert global_extra_handler is None, 'already inside an Extra() block' - global_extra_handler = self.handler - - def __exit__(self, exc_type, exc_val, exc_tb): - global global_extra_handler - - global_extra_handler = None - - -unsafe_torch_load = torch.load -global_extra_handler = None +# # this code is adapted from the script contributed by anon from /h/ +# +# import pickle +# import collections +# +# import torch +# import numpy +# import _codecs +# import zipfile +# import re +# +# +# # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +# from modules import errors +# +# TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage +# +# def encode(*args): +# out = _codecs.encode(*args) +# return out +# +# +# class RestrictedUnpickler(pickle.Unpickler): +# extra_handler = None +# +# def persistent_load(self, saved_id): +# assert saved_id[0] == 'storage' +# +# try: +# return TypedStorage(_internal=True) +# except TypeError: +# return TypedStorage() # PyTorch before 2.0 does not have the _internal argument +# +# def find_class(self, module, name): +# if self.extra_handler is not None: +# res = self.extra_handler(module, name) +# if res is not None: +# return res +# +# if module == 'collections' and name == 'OrderedDict': +# return getattr(collections, name) +# if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: +# return getattr(torch._utils, name) +# if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']: +# return getattr(torch, name) +# if module == 'torch.nn.modules.container' and name in ['ParameterDict']: +# return getattr(torch.nn.modules.container, name) +# if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']: +# return getattr(numpy.core.multiarray, name) +# if module == 'numpy' and name in ['dtype', 'ndarray']: +# return getattr(numpy, name) +# if module == '_codecs' and name == 'encode': +# return encode +# if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': +# import pytorch_lightning.callbacks +# return pytorch_lightning.callbacks.model_checkpoint +# if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': +# import pytorch_lightning.callbacks.model_checkpoint +# return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint +# if module == "__builtin__" and name == 'set': +# return set +# +# # Forbid everything else. 
+# raise Exception(f"global '{module}/{name}' is forbidden") +# +# +# # Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/' +# allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$") +# data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") +# +# def check_zip_filenames(filename, names): +# for name in names: +# if allowed_zip_names_re.match(name): +# continue +# +# raise Exception(f"bad file inside {filename}: {name}") +# +# +# def check_pt(filename, extra_handler): +# try: +# +# # new pytorch format is a zip file +# with zipfile.ZipFile(filename) as z: +# check_zip_filenames(filename, z.namelist()) +# +# # find filename of data.pkl in zip file: '/data.pkl' +# data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] +# if len(data_pkl_filenames) == 0: +# raise Exception(f"data.pkl not found in {filename}") +# if len(data_pkl_filenames) > 1: +# raise Exception(f"Multiple data.pkl found in {filename}") +# with z.open(data_pkl_filenames[0]) as file: +# unpickler = RestrictedUnpickler(file) +# unpickler.extra_handler = extra_handler +# unpickler.load() +# +# except zipfile.BadZipfile: +# +# # if it's not a zip file, it's an old pytorch format, with five objects written to pickle +# with open(filename, "rb") as file: +# unpickler = RestrictedUnpickler(file) +# unpickler.extra_handler = extra_handler +# for _ in range(5): +# unpickler.load() +# +# +# def load(filename, *args, **kwargs): +# return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs) +# +# +# def load_with_extra(filename, extra_handler=None, *args, **kwargs): +# """ +# this function is intended to be used by extensions that want to load models with +# some extra classes in them that the usual unpickler would find suspicious. +# +# Use the extra_handler argument to specify a function that takes module and field name as text, +# and returns that field's value: +# +# ```python +# def extra(module, name): +# if module == 'collections' and name == 'OrderedDict': +# return collections.OrderedDict +# +# return None +# +# safe.load_with_extra('model.pt', extra_handler=extra) +# ``` +# +# The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is +# definitely unsafe. +# """ +# +# from modules import shared +# +# try: +# if not shared.cmd_opts.disable_safe_unpickle: +# check_pt(filename, extra_handler) +# +# except pickle.UnpicklingError: +# errors.report( +# f"Error verifying pickled file from {filename}\n" +# "-----> !!!! The file is most likely corrupted !!!! <-----\n" +# "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", +# exc_info=True, +# ) +# return None +# except Exception: +# errors.report( +# f"Error verifying pickled file from {filename}\n" +# f"The file may be malicious, so the program is not going to read it.\n" +# f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", +# exc_info=True, +# ) +# return None +# +# return unsafe_torch_load(filename, *args, **kwargs) +# +# +# class Extra: +# """ +# A class for temporarily setting the global handler for when you can't explicitly call load_with_extra +# (because it's not your code making the torch.load call). 
The intended use is like this: +# +# ``` +# import torch +# from modules import safe +# +# def handler(module, name): +# if module == 'torch' and name in ['float64', 'float16']: +# return getattr(torch, name) +# +# return None +# +# with safe.Extra(handler): +# x = torch.load('model.pt') +# ``` +# """ +# +# def __init__(self, handler): +# self.handler = handler +# +# def __enter__(self): +# global global_extra_handler +# +# assert global_extra_handler is None, 'already inside an Extra() block' +# global_extra_handler = self.handler +# +# def __exit__(self, exc_type, exc_val, exc_tb): +# global global_extra_handler +# +# global_extra_handler = None +# +# +# unsafe_torch_load = torch.load +# global_extra_handler = None diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 273a7edd..3bff9255 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -1,232 +1,232 @@ -import ldm.modules.encoders.modules -import open_clip -import torch -import transformers.utils.hub - -from modules import shared - - -class ReplaceHelper: - def __init__(self): - self.replaced = [] - - def replace(self, obj, field, func): - original = getattr(obj, field, None) - if original is None: - return None - - self.replaced.append((obj, field, original)) - setattr(obj, field, func) - - return original - - def restore(self): - for obj, field, original in self.replaced: - setattr(obj, field, original) - - self.replaced.clear() - - -class DisableInitialization(ReplaceHelper): - """ - When an object of this class enters a `with` block, it starts: - - preventing torch's layer initialization functions from working - - changes CLIP and OpenCLIP to not download model weights - - changes CLIP to not make requests to check if there is a new version of a file you already have - - When it leaves the block, it reverts everything to how it was before. 
- - Use it like this: - ``` - with DisableInitialization(): - do_things() - ``` - """ - - def __init__(self, disable_clip=True): - super().__init__() - self.disable_clip = disable_clip - - def replace(self, obj, field, func): - original = getattr(obj, field, None) - if original is None: - return None - - self.replaced.append((obj, field, original)) - setattr(obj, field, func) - - return original - - def __enter__(self): - def do_nothing(*args, **kwargs): - pass - - def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): - return self.create_model_and_transforms(*args, pretrained=None, **kwargs) - - def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): - res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) - res.name_or_path = pretrained_model_name_or_path - return res - - def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): - args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug - return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs) - - def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): - - # this file is always 404, prevent making request - if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': - return None - - try: - res = original(url, *args, local_files_only=True, **kwargs) - if res is None: - res = original(url, *args, local_files_only=False, **kwargs) - return res - except Exception: - return original(url, *args, local_files_only=False, **kwargs) - - def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): - return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs) - - def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): - return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs) - - def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): - return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) - - self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) - self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) - self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) - - if self.disable_clip: - self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) - self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) - self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) - self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) - self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) - self.transformers_utils_hub_get_from_cache = 
self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) - - def __exit__(self, exc_type, exc_val, exc_tb): - self.restore() - - -class InitializeOnMeta(ReplaceHelper): - """ - Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device, - which results in those parameters having no values and taking no memory. model.to() will be broken and - will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict. - - Usage: - ``` - with sd_disable_initialization.InitializeOnMeta(): - sd_model = instantiate_from_config(sd_config.model) - ``` - """ - - def __enter__(self): - if shared.cmd_opts.disable_model_loading_ram_optimization: - return - - def set_device(x): - x["device"] = "meta" - return x - - linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs))) - conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs))) - mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs))) - self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None) - - def __exit__(self, exc_type, exc_val, exc_tb): - self.restore() - - -class LoadStateDictOnMeta(ReplaceHelper): - """ - Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device. - As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory. - Meant to be used together with InitializeOnMeta above. - - Usage: - ``` - with sd_disable_initialization.LoadStateDictOnMeta(state_dict): - model.load_state_dict(state_dict, strict=False) - ``` - """ - - def __init__(self, state_dict, device, weight_dtype_conversion=None): - super().__init__() - self.state_dict = state_dict - self.device = device - self.weight_dtype_conversion = weight_dtype_conversion or {} - self.default_dtype = self.weight_dtype_conversion.get('') - - def get_weight_dtype(self, key): - key_first_term, _ = key.split('.', 1) - return self.weight_dtype_conversion.get(key_first_term, self.default_dtype) - - def __enter__(self): - if shared.cmd_opts.disable_model_loading_ram_optimization: - return - - sd = self.state_dict - device = self.device - - def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): - used_param_keys = [] - - for name, param in module._parameters.items(): - if param is None: - continue - - key = prefix + name - sd_param = sd.pop(key, None) - if sd_param is not None: - state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) - used_param_keys.append(key) - - if param.is_meta: - dtype = sd_param.dtype if sd_param is not None else param.dtype - module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) - - for name in module._buffers: - key = prefix + name - - sd_param = sd.pop(key, None) - if sd_param is not None: - state_dict[key] = sd_param - used_param_keys.append(key) - - original(module, state_dict, prefix, *args, **kwargs) - - for key in used_param_keys: - state_dict.pop(key, None) - - def load_state_dict(original, module, state_dict, strict=True): - """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help - because the same values are stored in multiple 
copies of the dict. The trick used here is to give torch a dict with - all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. - - In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd). - - The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads - the function and does not call the original) the state dict will just fail to load because weights - would be on the meta device. - """ - - if state_dict is sd: - state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} - - original(module, state_dict, strict=strict) - - module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) - module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) - linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) - conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) - mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) - layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs)) - group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs)) - - def __exit__(self, exc_type, exc_val, exc_tb): - self.restore() +# import ldm.modules.encoders.modules +# import open_clip +# import torch +# import transformers.utils.hub +# +# from modules import shared +# +# +# class ReplaceHelper: +# def __init__(self): +# self.replaced = [] +# +# def replace(self, obj, field, func): +# original = getattr(obj, field, None) +# if original is None: +# return None +# +# self.replaced.append((obj, field, original)) +# setattr(obj, field, func) +# +# return original +# +# def restore(self): +# for obj, field, original in self.replaced: +# setattr(obj, field, original) +# +# self.replaced.clear() +# +# +# class DisableInitialization(ReplaceHelper): +# """ +# When an object of this class enters a `with` block, it starts: +# - preventing torch's layer initialization functions from working +# - changes CLIP and OpenCLIP to not download model weights +# - changes CLIP to not make requests to check if there is a new version of a file you already have +# +# When it leaves the block, it reverts everything to how it was before. 
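+# (every replacement is recorded by the ReplaceHelper base class, and __exit__ calls restore() to put the originals back)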
+# +# Use it like this: +# ``` +# with DisableInitialization(): +# do_things() +# ``` +# """ +# +# def __init__(self, disable_clip=True): +# super().__init__() +# self.disable_clip = disable_clip +# +# def replace(self, obj, field, func): +# original = getattr(obj, field, None) +# if original is None: +# return None +# +# self.replaced.append((obj, field, original)) +# setattr(obj, field, func) +# +# return original +# +# def __enter__(self): +# def do_nothing(*args, **kwargs): +# pass +# +# def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): +# return self.create_model_and_transforms(*args, pretrained=None, **kwargs) +# +# def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): +# res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) +# res.name_or_path = pretrained_model_name_or_path +# return res +# +# def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): +# args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug +# return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs) +# +# def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): +# +# # this file is always 404, prevent making request +# if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': +# return None +# +# try: +# res = original(url, *args, local_files_only=True, **kwargs) +# if res is None: +# res = original(url, *args, local_files_only=False, **kwargs) +# return res +# except Exception: +# return original(url, *args, local_files_only=False, **kwargs) +# +# def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): +# return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs) +# +# def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): +# return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs) +# +# def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): +# return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) +# +# self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) +# self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) +# self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) +# +# if self.disable_clip: +# self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) +# self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) +# self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) +# self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) +# self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) +# 
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) +# +# def __exit__(self, exc_type, exc_val, exc_tb): +# self.restore() +# +# +# class InitializeOnMeta(ReplaceHelper): +# """ +# Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device, +# which results in those parameters having no values and taking no memory. model.to() will be broken and +# will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict. +# +# Usage: +# ``` +# with sd_disable_initialization.InitializeOnMeta(): +# sd_model = instantiate_from_config(sd_config.model) +# ``` +# """ +# +# def __enter__(self): +# if shared.cmd_opts.disable_model_loading_ram_optimization: +# return +# +# def set_device(x): +# x["device"] = "meta" +# return x +# +# linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs))) +# conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs))) +# mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs))) +# self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None) +# +# def __exit__(self, exc_type, exc_val, exc_tb): +# self.restore() +# +# +# class LoadStateDictOnMeta(ReplaceHelper): +# """ +# Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device. +# As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory. +# Meant to be used together with InitializeOnMeta above. 
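+# Parameters that InitializeOnMeta left on the meta device are materialized on the target device here as their values are read from state_dict.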
+# +# Usage: +# ``` +# with sd_disable_initialization.LoadStateDictOnMeta(state_dict): +# model.load_state_dict(state_dict, strict=False) +# ``` +# """ +# +# def __init__(self, state_dict, device, weight_dtype_conversion=None): +# super().__init__() +# self.state_dict = state_dict +# self.device = device +# self.weight_dtype_conversion = weight_dtype_conversion or {} +# self.default_dtype = self.weight_dtype_conversion.get('') +# +# def get_weight_dtype(self, key): +# key_first_term, _ = key.split('.', 1) +# return self.weight_dtype_conversion.get(key_first_term, self.default_dtype) +# +# def __enter__(self): +# if shared.cmd_opts.disable_model_loading_ram_optimization: +# return +# +# sd = self.state_dict +# device = self.device +# +# def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): +# used_param_keys = [] +# +# for name, param in module._parameters.items(): +# if param is None: +# continue +# +# key = prefix + name +# sd_param = sd.pop(key, None) +# if sd_param is not None: +# state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) +# used_param_keys.append(key) +# +# if param.is_meta: +# dtype = sd_param.dtype if sd_param is not None else param.dtype +# module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) +# +# for name in module._buffers: +# key = prefix + name +# +# sd_param = sd.pop(key, None) +# if sd_param is not None: +# state_dict[key] = sd_param +# used_param_keys.append(key) +# +# original(module, state_dict, prefix, *args, **kwargs) +# +# for key in used_param_keys: +# state_dict.pop(key, None) +# +# def load_state_dict(original, module, state_dict, strict=True): +# """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help +# because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with +# all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. +# +# In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd). +# +# The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads +# the function and does not call the original) the state dict will just fail to load because weights +# would be on the meta device. 
+# """ +# +# if state_dict is sd: +# state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} +# +# original(module, state_dict, strict=strict) +# +# module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) +# module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) +# linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) +# conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) +# mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) +# layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs)) +# group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs)) +# +# def __exit__(self, exc_type, exc_val, exc_tb): +# self.restore() diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f292073b..23d5ed6a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -1,124 +1,3 @@ -import torch -from torch.nn.functional import silu -from types import MethodType - -from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches -from modules.hypernetworks import hypernetwork -from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 - -import ldm.modules.attention -import ldm.modules.diffusionmodules.model -import ldm.modules.diffusionmodules.openaimodel -import ldm.models.diffusion.ddpm -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms -import ldm.modules.encoders.modules - -import sgm.modules.attention -import sgm.modules.diffusionmodules.model -import sgm.modules.diffusionmodules.openaimodel -import sgm.modules.encoders.modules - -attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward -diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity -diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward - -# new memory efficient cross attention blocks do not support hypernets and we already -# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention -ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention -ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention - -# silence new console spam from SD2 -ldm.modules.attention.print = shared.ldm_print -ldm.modules.diffusionmodules.model.print = shared.ldm_print -ldm.util.print = shared.ldm_print -ldm.models.diffusion.ddpm.print = shared.ldm_print - -optimizers = [] -current_optimizer: sd_hijack_optimizations.SdOptimization = None - -ldm_patched_forward = 
sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) -ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) - -sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) -sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) - - -def list_optimizers(): - new_optimizers = script_callbacks.list_optimizers_callback() - - new_optimizers = [x for x in new_optimizers if x.is_available()] - - new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) - - optimizers.clear() - optimizers.extend(new_optimizers) - - -def apply_optimizations(option=None): - return - - -def undo_optimizations(): - return - - -def fix_checkpoint(): - """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want - checkpoints to be added when not training (there's a warning)""" - - pass - - -def weighted_loss(sd_model, pred, target, mean=True): - #Calculate the weight normally, but ignore the mean - loss = sd_model._old_get_loss(pred, target, mean=False) - - #Check if we have weights available - weight = getattr(sd_model, '_custom_loss_weight', None) - if weight is not None: - loss *= weight - - #Return the loss, as mean if specified - return loss.mean() if mean else loss - -def weighted_forward(sd_model, x, c, w, *args, **kwargs): - try: - #Temporarily append weights to a place accessible during loss calc - sd_model._custom_loss_weight = w - - #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely - #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set - if not hasattr(sd_model, '_old_get_loss'): - sd_model._old_get_loss = sd_model.get_loss - sd_model.get_loss = MethodType(weighted_loss, sd_model) - - #Run the standard forward function, but with the patched 'get_loss' - return sd_model.forward(x, c, *args, **kwargs) - finally: - try: - #Delete temporary weights if appended - del sd_model._custom_loss_weight - except AttributeError: - pass - - #If we have an old loss function, reset the loss function to the original one - if hasattr(sd_model, '_old_get_loss'): - sd_model.get_loss = sd_model._old_get_loss - del sd_model._old_get_loss - -def apply_weighted_forward(sd_model): - #Add new function 'weighted_forward' that can be called to calc weighted loss - sd_model.weighted_forward = MethodType(weighted_forward, sd_model) - -def undo_weighted_forward(sd_model): - try: - del sd_model.weighted_forward - except AttributeError: - pass - - class StableDiffusionModelHijack: fixes = None layers = None @@ -156,74 +35,234 @@ class StableDiffusionModelHijack: pass -class EmbeddingsWithFixes(torch.nn.Module): - def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): - super().__init__() - self.wrapped = wrapped - self.embeddings = embeddings - self.textual_inversion_key = textual_inversion_key - self.weight = self.wrapped.weight - - def forward(self, input_ids): - batch_fixes = self.embeddings.fixes - self.embeddings.fixes = None - - inputs_embeds = self.wrapped(input_ids) - - if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: - return inputs_embeds - - vecs = [] - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, embedding in fixes: - vec = embedding.vec[self.textual_inversion_key] if 
isinstance(embedding.vec, dict) else embedding.vec - emb = devices.cond_cast_unet(vec) - emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) - - vecs.append(tensor) - - return torch.stack(vecs) - - -class TextualInversionEmbeddings(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): - super().__init__(num_embeddings, embedding_dim, **kwargs) - - self.embeddings = model_hijack - self.textual_inversion_key = textual_inversion_key - - @property - def wrapped(self): - return super().forward - - def forward(self, input_ids): - return EmbeddingsWithFixes.forward(self, input_ids) - - -def add_circular_option_to_conv_2d(): - conv2d_constructor = torch.nn.Conv2d.__init__ - - def conv2d_constructor_circular(self, *args, **kwargs): - return conv2d_constructor(self, *args, padding_mode='circular', **kwargs) - - torch.nn.Conv2d.__init__ = conv2d_constructor_circular - - model_hijack = StableDiffusionModelHijack() - -def register_buffer(self, name, attr): - """ - Fix register buffer bug for Mac OS. - """ - - if type(attr) == torch.Tensor: - if attr.device != devices.device: - attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) - - setattr(self, name, attr) - - -ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer -ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer +# import torch +# from torch.nn.functional import silu +# from types import MethodType +# +# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches +# from modules.hypernetworks import hypernetwork +# from modules.shared import cmd_opts +# from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 +# +# import ldm.modules.attention +# import ldm.modules.diffusionmodules.model +# import ldm.modules.diffusionmodules.openaimodel +# import ldm.models.diffusion.ddpm +# import ldm.models.diffusion.ddim +# import ldm.models.diffusion.plms +# import ldm.modules.encoders.modules +# +# import sgm.modules.attention +# import sgm.modules.diffusionmodules.model +# import sgm.modules.diffusionmodules.openaimodel +# import sgm.modules.encoders.modules +# +# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +# +# # new memory efficient cross attention blocks do not support hypernets and we already +# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention +# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention +# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention +# +# # silence new console spam from SD2 +# ldm.modules.attention.print = shared.ldm_print +# ldm.modules.diffusionmodules.model.print = shared.ldm_print +# ldm.util.print = shared.ldm_print +# ldm.models.diffusion.ddpm.print = shared.ldm_print +# +# optimizers = [] +# current_optimizer: sd_hijack_optimizations.SdOptimization = None +# +# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) +# 
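# patches.patch swaps the wrapped forward into UNetModel and returns the original so it can be restored later +#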
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) +# +# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) +# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) +# +# +# def list_optimizers(): +# new_optimizers = script_callbacks.list_optimizers_callback() +# +# new_optimizers = [x for x in new_optimizers if x.is_available()] +# +# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) +# +# optimizers.clear() +# optimizers.extend(new_optimizers) +# +# +# def apply_optimizations(option=None): +# return +# +# +# def undo_optimizations(): +# return +# +# +# def fix_checkpoint(): +# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want +# checkpoints to be added when not training (there's a warning)""" +# +# pass +# +# +# def weighted_loss(sd_model, pred, target, mean=True): +# #Calculate the weight normally, but ignore the mean +# loss = sd_model._old_get_loss(pred, target, mean=False) +# +# #Check if we have weights available +# weight = getattr(sd_model, '_custom_loss_weight', None) +# if weight is not None: +# loss *= weight +# +# #Return the loss, as mean if specified +# return loss.mean() if mean else loss +# +# def weighted_forward(sd_model, x, c, w, *args, **kwargs): +# try: +# #Temporarily append weights to a place accessible during loss calc +# sd_model._custom_loss_weight = w +# +# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely +# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set +# if not hasattr(sd_model, '_old_get_loss'): +# sd_model._old_get_loss = sd_model.get_loss +# sd_model.get_loss = MethodType(weighted_loss, sd_model) +# +# #Run the standard forward function, but with the patched 'get_loss' +# return sd_model.forward(x, c, *args, **kwargs) +# finally: +# try: +# #Delete temporary weights if appended +# del sd_model._custom_loss_weight +# except AttributeError: +# pass +# +# #If we have an old loss function, reset the loss function to the original one +# if hasattr(sd_model, '_old_get_loss'): +# sd_model.get_loss = sd_model._old_get_loss +# del sd_model._old_get_loss +# +# def apply_weighted_forward(sd_model): +# #Add new function 'weighted_forward' that can be called to calc weighted loss +# sd_model.weighted_forward = MethodType(weighted_forward, sd_model) +# +# def undo_weighted_forward(sd_model): +# try: +# del sd_model.weighted_forward +# except AttributeError: +# pass +# +# +# class StableDiffusionModelHijack: +# fixes = None +# layers = None +# circular_enabled = False +# clip = None +# optimization_method = None +# +# def __init__(self): +# self.extra_generation_params = {} +# self.comments = [] +# +# def apply_optimizations(self, option=None): +# pass +# +# def convert_sdxl_to_ssd(self, m): +# pass +# +# def hijack(self, m): +# pass +# +# def undo_hijack(self, m): +# pass +# +# def apply_circular(self, enable): +# pass +# +# def clear_comments(self): +# self.comments = [] +# self.extra_generation_params = {} +# +# def get_prompt_lengths(self, text, cond_stage_model): +# pass +# +# def redo_hijack(self, m): +# pass +# +# +# class EmbeddingsWithFixes(torch.nn.Module): +# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): +# super().__init__() +# self.wrapped = wrapped +# 
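# `embeddings` carries the .fixes attribute with per-prompt textual inversion vectors that forward() splices into the token embeddings +#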
self.embeddings = embeddings +# self.textual_inversion_key = textual_inversion_key +# self.weight = self.wrapped.weight +# +# def forward(self, input_ids): +# batch_fixes = self.embeddings.fixes +# self.embeddings.fixes = None +# +# inputs_embeds = self.wrapped(input_ids) +# +# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: +# return inputs_embeds +# +# vecs = [] +# for fixes, tensor in zip(batch_fixes, inputs_embeds): +# for offset, embedding in fixes: +# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec +# emb = devices.cond_cast_unet(vec) +# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) +# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) +# +# vecs.append(tensor) +# +# return torch.stack(vecs) +# +# +# class TextualInversionEmbeddings(torch.nn.Embedding): +# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): +# super().__init__(num_embeddings, embedding_dim, **kwargs) +# +# self.embeddings = model_hijack +# self.textual_inversion_key = textual_inversion_key +# +# @property +# def wrapped(self): +# return super().forward +# +# def forward(self, input_ids): +# return EmbeddingsWithFixes.forward(self, input_ids) +# +# +# def add_circular_option_to_conv_2d(): +# conv2d_constructor = torch.nn.Conv2d.__init__ +# +# def conv2d_constructor_circular(self, *args, **kwargs): +# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs) +# +# torch.nn.Conv2d.__init__ = conv2d_constructor_circular +# +# +# +# +# +# def register_buffer(self, name, attr): +# """ +# Fix register buffer bug for Mac OS. +# """ +# +# if type(attr) == torch.Tensor: +# if attr.device != devices.device: +# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) +# +# setattr(self, name, attr) +# +# +# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer +# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py index 2604d969..8b11b443 100644 --- a/modules/sd_hijack_checkpoint.py +++ b/modules/sd_hijack_checkpoint.py @@ -1,46 +1,46 @@ -from torch.utils.checkpoint import checkpoint - -import ldm.modules.attention -import ldm.modules.diffusionmodules.openaimodel - - -def BasicTransformerBlock_forward(self, x, context=None): - return checkpoint(self._forward, x, context) - - -def AttentionBlock_forward(self, x): - return checkpoint(self._forward, x) - - -def ResBlock_forward(self, x, emb): - return checkpoint(self._forward, x, emb) - - -stored = [] - - -def add(): - if len(stored) != 0: - return - - stored.extend([ - ldm.modules.attention.BasicTransformerBlock.forward, - ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, - ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward - ]) - - ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward - ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward - ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward - - -def remove(): - if len(stored) == 0: - return - - ldm.modules.attention.BasicTransformerBlock.forward = stored[0] - ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] - ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] - - 
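# forget the saved originals so a later add() can re-install the checkpointed forwards -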
stored.clear() - +# from torch.utils.checkpoint import checkpoint +# +# import ldm.modules.attention +# import ldm.modules.diffusionmodules.openaimodel +# +# +# def BasicTransformerBlock_forward(self, x, context=None): +# return checkpoint(self._forward, x, context) +# +# +# def AttentionBlock_forward(self, x): +# return checkpoint(self._forward, x) +# +# +# def ResBlock_forward(self, x, emb): +# return checkpoint(self._forward, x, emb) +# +# +# stored = [] +# +# +# def add(): +# if len(stored) != 0: +# return +# +# stored.extend([ +# ldm.modules.attention.BasicTransformerBlock.forward, +# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, +# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward +# ]) +# +# ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward +# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward +# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward +# +# +# def remove(): +# if len(stored) == 0: +# return +# +# ldm.modules.attention.BasicTransformerBlock.forward = stored[0] +# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] +# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] +# +# stored.clear() +# diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 0269f1f5..696835ad 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,677 +1,677 @@ -from __future__ import annotations -import math -import psutil -import platform - -import torch -from torch import einsum - -from ldm.util import default -from einops import rearrange - -from modules import shared, errors, devices, sub_quadratic_attention -from modules.hypernetworks import hypernetwork - -import ldm.modules.attention -import ldm.modules.diffusionmodules.model - -import sgm.modules.attention -import sgm.modules.diffusionmodules.model - -diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward - - -class SdOptimization: - name: str = None - label: str | None = None - cmd_opt: str | None = None - priority: int = 0 - - def title(self): - if self.label is None: - return self.name - - return f"{self.name} - {self.label}" - - def is_available(self): - return True - - def apply(self): - pass - - def undo(self): - ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward - - sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward - - -class SdOptimizationXformers(SdOptimization): - name = "xformers" - cmd_opt = "xformers" - priority = 100 - - def is_available(self): - return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) - - def apply(self): - ldm.modules.attention.CrossAttention.forward = xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward - sgm.modules.attention.CrossAttention.forward = xformers_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward - - -class 
SdOptimizationSdpNoMem(SdOptimization): - name = "sdp-no-mem" - label = "scaled dot product without memory efficient attention" - cmd_opt = "opt_sdp_no_mem_attention" - priority = 80 - - def is_available(self): - return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) - - def apply(self): - ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward - sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward - - -class SdOptimizationSdp(SdOptimizationSdpNoMem): - name = "sdp" - label = "scaled dot product" - cmd_opt = "opt_sdp_attention" - priority = 70 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward - sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward - - -class SdOptimizationSubQuad(SdOptimization): - name = "sub-quadratic" - cmd_opt = "opt_sub_quad_attention" - - @property - def priority(self): - return 1000 if shared.device.type == 'mps' else 10 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward - sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward - - -class SdOptimizationV1(SdOptimization): - name = "V1" - label = "original v1" - cmd_opt = "opt_split_attention_v1" - priority = 10 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - - -class SdOptimizationInvokeAI(SdOptimization): - name = "InvokeAI" - cmd_opt = "opt_split_attention_invokeai" - - @property - def priority(self): - return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI - sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI - - -class SdOptimizationDoggettx(SdOptimization): - name = "Doggettx" - cmd_opt = "opt_split_attention" - priority = 90 - - def apply(self): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward - sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward - sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward - - -def list_optimizers(res): - res.extend([ - SdOptimizationXformers(), - SdOptimizationSdpNoMem(), - SdOptimizationSdp(), - SdOptimizationSubQuad(), - SdOptimizationV1(), - SdOptimizationInvokeAI(), - SdOptimizationDoggettx(), - ]) - - -if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: - try: - import xformers.ops - shared.xformers_available = True - except Exception: - errors.report("Cannot import xformers", exc_info=True) - - -def get_available_vram(): - if shared.device.type == 'cuda': - stats = 
torch.cuda.memory_stats(shared.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - return mem_free_total - else: - return psutil.virtual_memory().available - - -# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - del context, context_k, context_v, x - - q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) - del q_in, k_in, v_in - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v.float() - - with devices.without_autocast(disable=not shared.opts.upcast_attn): - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale - - s2 = s1.softmax(dim=-1) - del s1 - - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - del q, k, v - - r1 = r1.to(dtype) - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - - dtype = q_in.dtype - if shared.opts.upcast_attn: - q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() - - with devices.without_autocast(disable=not shared.opts.upcast_attn): - k_in = k_in * self.scale - - del context, x - - q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = get_available_vram() - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps - for i in range(0, q.shape[1], slice_size): - end = min(i + slice_size, q.shape[1]) - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v - - r1 = r1.to(dtype) - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -mem_total_gb = psutil.virtual_memory().total // (1 << 30) - - -def einsum_op_compvis(q, k, v): - s = einsum('b i d, b j d -> b i j', q, k) - s = s.softmax(dim=-1, dtype=s.dtype) - return einsum('b i j, b j d -> b i d', s, v) - - -def einsum_op_slice_0(q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[0], slice_size): - end = i + slice_size - r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) - return r - - -def einsum_op_slice_1(q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) - return r - - -def einsum_op_mps_v1(q, k, v): - if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 - return einsum_op_compvis(q, k, v) - else: - slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - if slice_size % 4096 == 0: - slice_size -= 1 - return einsum_op_slice_1(q, k, v, slice_size) - - -def einsum_op_mps_v2(q, k, v): - if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: - return einsum_op_compvis(q, k, v) - else: - return einsum_op_slice_0(q, k, v, 1) - - -def einsum_op_tensor_mem(q, k, v, max_tensor_mb): - size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) - if size_mb <= max_tensor_mb: - return einsum_op_compvis(q, k, v) - div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() - if div <= q.shape[0]: - return einsum_op_slice_0(q, k, v, q.shape[0] // div) - return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) - - -def einsum_op_cuda(q, k, v): - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - # Divide factor of safety as there's copying and fragmentation - return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - - -def einsum_op(q, k, v): - if q.device.type == 'cuda': - return einsum_op_cuda(q, k, v) - - if q.device.type == 'mps': - if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18: - return einsum_op_mps_v1(q, k, v) - return einsum_op_mps_v2(q, k, v) - - # Smaller slices are faster due to L2/L3/SLC caches. - # Tested on i7 with 8MB L3 cache. 
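- # CPU fallback: cap each q @ k^T slice at roughly 32 MB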
- return einsum_op_tensor_mem(q, k, v, 32) - - -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, context_k, context_v, x - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() - - with devices.without_autocast(disable=not shared.opts.upcast_attn): - k = k * self.scale - - q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) - r = einsum_op(q, k, v) - r = r.to(dtype) - return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) - -# -- End of code from https://github.com/invoke-ai/InvokeAI -- - - -# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 -# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): - assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." - - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, context_k, context_v, x - - q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - - if q.device.type == 'mps': - q, k, v = q.contiguous(), k.contiguous(), v.contiguous() - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k = q.float(), k.float() - - x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) - - x = x.to(dtype) - - x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) - - out_proj, dropout = self.to_out - x = out_proj(x) - x = dropout(x) - - return x - - -def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): - bytes_per_token = torch.finfo(q.dtype).bits//8 - batch_x_heads, q_tokens, _ = q.shape - _, k_tokens, _ = k.shape - qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - - if chunk_threshold is None: - if q.device.type == 'mps': - chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) - else: - chunk_threshold_bytes = int(get_available_vram() * 0.7) - elif chunk_threshold == 0: - chunk_threshold_bytes = None - else: - chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) - - if kv_chunk_size_min is None and chunk_threshold_bytes is not None: - kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) - elif kv_chunk_size_min == 0: - kv_chunk_size_min = None - - if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: - # the big matmul fits into our memory limit; do everything in 1 chunk, - # i.e. 
send it down the unchunked fast-path - kv_chunk_size = k_tokens - - with devices.without_autocast(disable=q.dtype == v.dtype): - return sub_quadratic_attention.efficient_dot_product_attention( - q, - k, - v, - query_chunk_size=q_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min = kv_chunk_size_min, - use_checkpoint=use_checkpoint, - ) - - -def get_xformers_flash_attention_op(q, k, v): - if not shared.cmd_opts.xformers_flash_attention: - return None - - try: - flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp - fw, bw = flash_attention_op - if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): - return flash_attention_op - except Exception as e: - errors.display_once(e, "enabling flash attention") - - return None - - -def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): - h = self.heads - q_in = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - - q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in)) - - del q_in, k_in, v_in - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v.float() - - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) - - out = out.to(dtype) - - b, n, h, d = out.shape - out = out.reshape(b, n, h * d) - return self.to_out(out) - - -# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py -# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): - batch_size, sequence_length, inner_dim = x.shape - - if mask is not None: - mask = self.prepare_attention_mask(mask, sequence_length, batch_size) - mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) - - h = self.heads - q_in = self.to_q(x) - context = default(context, x) - - context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - - head_dim = inner_dim // h - q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) - k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) - v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) - - del q_in, k_in, v_in - - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v.float() - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - hidden_states = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) - hidden_states = hidden_states.to(dtype) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - return hidden_states - - -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): - return scaled_dot_product_attention_forward(self, x, context, mask) - - -def 
cross_attention_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h*w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - mem_free_total = get_available_vram() - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 - - -def xformers_attnblock_forward(self, x): - try: - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - b, c, h, w = q.shape - q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) - dtype = q.dtype - if shared.opts.upcast_attn: - q, k = q.float(), k.float() - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) - out = out.to(dtype) - out = rearrange(out, 'b (h w) c -> b c h w', h=h) - out = self.proj_out(out) - return x + out - except NotImplementedError: - return cross_attention_attnblock_forward(self, x) - - -def sdp_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - b, c, h, w = q.shape - q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) - dtype = q.dtype - if shared.opts.upcast_attn: - q, k, v = q.float(), k.float(), v.float() - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) - out = out.to(dtype) - out = rearrange(out, 'b (h w) c -> b c h w', h=h) - out = self.proj_out(out) - return x + out - - -def sdp_no_mem_attnblock_forward(self, x): - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): - return sdp_attnblock_forward(self, x) - - -def sub_quad_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - b, c, h, w = q.shape - q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) - out = rearrange(out, 'b (h w) c -> b c h w', h=h) - out = self.proj_out(out) - return x + out +# from __future__ import annotations +# import math +# import psutil +# import platform +# +# import torch +# from torch import einsum +# +# from ldm.util import default +# from einops import 
rearrange +# +# from modules import shared, errors, devices, sub_quadratic_attention +# from modules.hypernetworks import hypernetwork +# +# import ldm.modules.attention +# import ldm.modules.diffusionmodules.model +# +# import sgm.modules.attention +# import sgm.modules.diffusionmodules.model +# +# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +# sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward +# +# +# class SdOptimization: +# name: str = None +# label: str | None = None +# cmd_opt: str | None = None +# priority: int = 0 +# +# def title(self): +# if self.label is None: +# return self.name +# +# return f"{self.name} - {self.label}" +# +# def is_available(self): +# return True +# +# def apply(self): +# pass +# +# def undo(self): +# ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward +# +# sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward +# +# +# class SdOptimizationXformers(SdOptimization): +# name = "xformers" +# cmd_opt = "xformers" +# priority = 100 +# +# def is_available(self): +# return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = xformers_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = xformers_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward +# +# +# class SdOptimizationSdpNoMem(SdOptimization): +# name = "sdp-no-mem" +# label = "scaled dot product without memory efficient attention" +# cmd_opt = "opt_sdp_no_mem_attention" +# priority = 80 +# +# def is_available(self): +# return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward +# +# +# class SdOptimizationSdp(SdOptimizationSdpNoMem): +# name = "sdp" +# label = "scaled dot product" +# cmd_opt = "opt_sdp_attention" +# priority = 70 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward +# +# +# class SdOptimizationSubQuad(SdOptimization): +# name = "sub-quadratic" +# cmd_opt = "opt_sub_quad_attention" +# +# @property +# def priority(self): +# return 1000 if shared.device.type == 'mps' else 10 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = 
sub_quad_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward +# +# +# class SdOptimizationV1(SdOptimization): +# name = "V1" +# label = "original v1" +# cmd_opt = "opt_split_attention_v1" +# priority = 10 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 +# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 +# +# +# class SdOptimizationInvokeAI(SdOptimization): +# name = "InvokeAI" +# cmd_opt = "opt_split_attention_invokeai" +# +# @property +# def priority(self): +# return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI +# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI +# +# +# class SdOptimizationDoggettx(SdOptimization): +# name = "Doggettx" +# cmd_opt = "opt_split_attention" +# priority = 90 +# +# def apply(self): +# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward +# ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward +# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward +# sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward +# +# +# def list_optimizers(res): +# res.extend([ +# SdOptimizationXformers(), +# SdOptimizationSdpNoMem(), +# SdOptimizationSdp(), +# SdOptimizationSubQuad(), +# SdOptimizationV1(), +# SdOptimizationInvokeAI(), +# SdOptimizationDoggettx(), +# ]) +# +# +# if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: +# try: +# import xformers.ops +# shared.xformers_available = True +# except Exception: +# errors.report("Cannot import xformers", exc_info=True) +# +# +# def get_available_vram(): +# if shared.device.type == 'cuda': +# stats = torch.cuda.memory_stats(shared.device) +# mem_active = stats['active_bytes.all.current'] +# mem_reserved = stats['reserved_bytes.all.current'] +# mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) +# mem_free_torch = mem_reserved - mem_active +# mem_free_total = mem_free_cuda + mem_free_torch +# return mem_free_total +# else: +# return psutil.virtual_memory().available +# +# +# # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +# def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): +# h = self.heads +# +# q_in = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) +# del context, context_k, context_v, x +# +# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) +# del q_in, k_in, v_in +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v.float() +# +# with devices.without_autocast(disable=not shared.opts.upcast_attn): +# r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) +# for i in range(0, q.shape[0], 2): +# end = i + 2 +# s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) +# s1 *= self.scale +# +# s2 = s1.softmax(dim=-1) +# del s1 +# +# r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) +# del s2 +# del q, k, v +# +# r1 = r1.to(dtype) +# +# r2 = rearrange(r1, '(b h) 
n d -> b n (h d)', h=h) +# del r1 +# +# return self.to_out(r2) +# +# +# # taken from https://github.com/Doggettx/stable-diffusion and modified +# def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): +# h = self.heads +# +# q_in = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) +# +# dtype = q_in.dtype +# if shared.opts.upcast_attn: +# q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() +# +# with devices.without_autocast(disable=not shared.opts.upcast_attn): +# k_in = k_in * self.scale +# +# del context, x +# +# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) +# del q_in, k_in, v_in +# +# r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) +# +# mem_free_total = get_available_vram() +# +# gb = 1024 ** 3 +# tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() +# modifier = 3 if q.element_size() == 2 else 2.5 +# mem_required = tensor_size * modifier +# steps = 1 +# +# if mem_required > mem_free_total: +# steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) +# # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " +# # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") +# +# if steps > 64: +# max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 +# raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' +# f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') +# +# slice_size = q.shape[1] // steps +# for i in range(0, q.shape[1], slice_size): +# end = min(i + slice_size, q.shape[1]) +# s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) +# +# s2 = s1.softmax(dim=-1, dtype=q.dtype) +# del s1 +# +# r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) +# del s2 +# +# del q, k, v +# +# r1 = r1.to(dtype) +# +# r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) +# del r1 +# +# return self.to_out(r2) +# +# +# # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- +# mem_total_gb = psutil.virtual_memory().total // (1 << 30) +# +# +# def einsum_op_compvis(q, k, v): +# s = einsum('b i d, b j d -> b i j', q, k) +# s = s.softmax(dim=-1, dtype=s.dtype) +# return einsum('b i j, b j d -> b i d', s, v) +# +# +# def einsum_op_slice_0(q, k, v, slice_size): +# r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) +# for i in range(0, q.shape[0], slice_size): +# end = i + slice_size +# r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) +# return r +# +# +# def einsum_op_slice_1(q, k, v, slice_size): +# r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) +# for i in range(0, q.shape[1], slice_size): +# end = i + slice_size +# r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) +# return r +# +# +# def einsum_op_mps_v1(q, k, v): +# if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 +# return einsum_op_compvis(q, k, v) +# else: +# slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) +# if slice_size % 4096 == 0: +# slice_size -= 1 +# return einsum_op_slice_1(q, k, v, slice_size) +# +# +# def einsum_op_mps_v2(q, k, v): +# if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: +# return einsum_op_compvis(q, k, v) 
+# else: +# return einsum_op_slice_0(q, k, v, 1) +# +# +# def einsum_op_tensor_mem(q, k, v, max_tensor_mb): +# size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) +# if size_mb <= max_tensor_mb: +# return einsum_op_compvis(q, k, v) +# div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() +# if div <= q.shape[0]: +# return einsum_op_slice_0(q, k, v, q.shape[0] // div) +# return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) +# +# +# def einsum_op_cuda(q, k, v): +# stats = torch.cuda.memory_stats(q.device) +# mem_active = stats['active_bytes.all.current'] +# mem_reserved = stats['reserved_bytes.all.current'] +# mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) +# mem_free_torch = mem_reserved - mem_active +# mem_free_total = mem_free_cuda + mem_free_torch +# # Divide factor of safety as there's copying and fragmentation +# return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) +# +# +# def einsum_op(q, k, v): +# if q.device.type == 'cuda': +# return einsum_op_cuda(q, k, v) +# +# if q.device.type == 'mps': +# if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18: +# return einsum_op_mps_v1(q, k, v) +# return einsum_op_mps_v2(q, k, v) +# +# # Smaller slices are faster due to L2/L3/SLC caches. +# # Tested on i7 with 8MB L3 cache. +# return einsum_op_tensor_mem(q, k, v, 32) +# +# +# def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): +# h = self.heads +# +# q = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k = self.to_k(context_k) +# v = self.to_v(context_v) +# del context, context_k, context_v, x +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() +# +# with devices.without_autocast(disable=not shared.opts.upcast_attn): +# k = k * self.scale +# +# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) +# r = einsum_op(q, k, v) +# r = r.to(dtype) +# return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) +# +# # -- End of code from https://github.com/invoke-ai/InvokeAI -- +# +# +# # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface +# def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): +# assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
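The InvokeAI-derived einsum_op_* helpers and the Doggettx-style split attention commented out above all rest on the same idea: compute the attention score matrix in slices along the query-token dimension so the full q_tokens x k_tokens tensor is never materialised at once. A minimal, self-contained sketch of that slicing idea follows; the sizes and function names are illustrative only and are not part of this patch.

import torch
from torch import einsum


def attention_full(q, k, v):
    # Reference implementation: materialises the full (batch, q_tokens, k_tokens) score matrix.
    scale = q.shape[-1] ** -0.5
    s = einsum('b i d, b j d -> b i j', q, k) * scale
    return einsum('b i j, b j d -> b i d', s.softmax(dim=-1), v)


def attention_sliced(q, k, v, slice_size=256):
    # Same result, but only a (batch, slice_size, k_tokens) score block exists at any time.
    scale = q.shape[-1] ** -0.5
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = min(i + slice_size, q.shape[1])
        s = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        out[:, i:end] = einsum('b i j, b j d -> b i d', s.softmax(dim=-1), v)
    return out


if __name__ == '__main__':
    q, k, v = (torch.randn(2, 1024, 64) for _ in range(3))
    assert torch.allclose(attention_full(q, k, v), attention_sliced(q, k, v), atol=1e-4)

Unlike the fixed slice_size above, the removed code picks the number of slices dynamically from the free VRAM reported by get_available_vram.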
+# +# h = self.heads +# +# q = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k = self.to_k(context_k) +# v = self.to_v(context_v) +# del context, context_k, context_v, x +# +# q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) +# k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) +# v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) +# +# if q.device.type == 'mps': +# q, k, v = q.contiguous(), k.contiguous(), v.contiguous() +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k = q.float(), k.float() +# +# x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) +# +# x = x.to(dtype) +# +# x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) +# +# out_proj, dropout = self.to_out +# x = out_proj(x) +# x = dropout(x) +# +# return x +# +# +# def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): +# bytes_per_token = torch.finfo(q.dtype).bits//8 +# batch_x_heads, q_tokens, _ = q.shape +# _, k_tokens, _ = k.shape +# qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens +# +# if chunk_threshold is None: +# if q.device.type == 'mps': +# chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) +# else: +# chunk_threshold_bytes = int(get_available_vram() * 0.7) +# elif chunk_threshold == 0: +# chunk_threshold_bytes = None +# else: +# chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) +# +# if kv_chunk_size_min is None and chunk_threshold_bytes is not None: +# kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) +# elif kv_chunk_size_min == 0: +# kv_chunk_size_min = None +# +# if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: +# # the big matmul fits into our memory limit; do everything in 1 chunk, +# # i.e. 
send it down the unchunked fast-path +# kv_chunk_size = k_tokens +# +# with devices.without_autocast(disable=q.dtype == v.dtype): +# return sub_quadratic_attention.efficient_dot_product_attention( +# q, +# k, +# v, +# query_chunk_size=q_chunk_size, +# kv_chunk_size=kv_chunk_size, +# kv_chunk_size_min = kv_chunk_size_min, +# use_checkpoint=use_checkpoint, +# ) +# +# +# def get_xformers_flash_attention_op(q, k, v): +# if not shared.cmd_opts.xformers_flash_attention: +# return None +# +# try: +# flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp +# fw, bw = flash_attention_op +# if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): +# return flash_attention_op +# except Exception as e: +# errors.display_once(e, "enabling flash attention") +# +# return None +# +# +# def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): +# h = self.heads +# q_in = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) +# +# q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in)) +# +# del q_in, k_in, v_in +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v.float() +# +# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) +# +# out = out.to(dtype) +# +# b, n, h, d = out.shape +# out = out.reshape(b, n, h * d) +# return self.to_out(out) +# +# +# # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py +# # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface +# def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): +# batch_size, sequence_length, inner_dim = x.shape +# +# if mask is not None: +# mask = self.prepare_attention_mask(mask, sequence_length, batch_size) +# mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) +# +# h = self.heads +# q_in = self.to_q(x) +# context = default(context, x) +# +# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) +# +# head_dim = inner_dim // h +# q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) +# k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) +# v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) +# +# del q_in, k_in, v_in +# +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v.float() +# +# # the output of sdp = (batch, num_heads, seq_len, head_dim) +# hidden_states = torch.nn.functional.scaled_dot_product_attention( +# q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False +# ) +# +# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) +# hidden_states = hidden_states.to(dtype) +# +# # linear proj +# hidden_states = self.to_out[0](hidden_states) +# # dropout +# hidden_states = self.to_out[1](hidden_states) +# return hidden_states +# +# +# def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): +# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, 
enable_mem_efficient=False): +# return scaled_dot_product_attention_forward(self, x, context, mask) +# +# +# def cross_attention_attnblock_forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q1 = self.q(h_) +# k1 = self.k(h_) +# v = self.v(h_) +# +# # compute attention +# b, c, h, w = q1.shape +# +# q2 = q1.reshape(b, c, h*w) +# del q1 +# +# q = q2.permute(0, 2, 1) # b,hw,c +# del q2 +# +# k = k1.reshape(b, c, h*w) # b,c,hw +# del k1 +# +# h_ = torch.zeros_like(k, device=q.device) +# +# mem_free_total = get_available_vram() +# +# tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() +# mem_required = tensor_size * 2.5 +# steps = 1 +# +# if mem_required > mem_free_total: +# steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) +# +# slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] +# for i in range(0, q.shape[1], slice_size): +# end = i + slice_size +# +# w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] +# w2 = w1 * (int(c)**(-0.5)) +# del w1 +# w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) +# del w2 +# +# # attend to values +# v1 = v.reshape(b, c, h*w) +# w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) +# del w3 +# +# h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] +# del v1, w4 +# +# h2 = h_.reshape(b, c, h, w) +# del h_ +# +# h3 = self.proj_out(h2) +# del h2 +# +# h3 += x +# +# return h3 +# +# +# def xformers_attnblock_forward(self, x): +# try: +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) +# b, c, h, w = q.shape +# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k = q.float(), k.float() +# q = q.contiguous() +# k = k.contiguous() +# v = v.contiguous() +# out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) +# out = out.to(dtype) +# out = rearrange(out, 'b (h w) c -> b c h w', h=h) +# out = self.proj_out(out) +# return x + out +# except NotImplementedError: +# return cross_attention_attnblock_forward(self, x) +# +# +# def sdp_attnblock_forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) +# b, c, h, w = q.shape +# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) +# dtype = q.dtype +# if shared.opts.upcast_attn: +# q, k, v = q.float(), k.float(), v.float() +# q = q.contiguous() +# k = k.contiguous() +# v = v.contiguous() +# out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) +# out = out.to(dtype) +# out = rearrange(out, 'b (h w) c -> b c h w', h=h) +# out = self.proj_out(out) +# return x + out +# +# +# def sdp_no_mem_attnblock_forward(self, x): +# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): +# return sdp_attnblock_forward(self, x) +# +# +# def sub_quad_attnblock_forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) +# b, c, h, w = q.shape +# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) +# q = q.contiguous() +# k = k.contiguous() +# v = v.contiguous() +# out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) +# out = rearrange(out, 'b (h w) c -> b c h w', h=h) +# 
out = self.proj_out(out) +# return x + out diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index b4f03b13..eb4a0af4 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,154 +1,154 @@ -import torch -from packaging import version -from einops import repeat -import math - -from modules import devices -from modules.sd_hijack_utils import CondFunc - - -class TorchHijackForUnet: - """ - This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; - this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 - """ - - def __getattr__(self, item): - if item == 'cat': - return self.cat - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def cat(self, tensors, *args, **kwargs): - if len(tensors) == 2: - a, b = tensors - if a.shape[-2:] != b.shape[-2:]: - a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") - - tensors = (a, b) - - return torch.cat(tensors, *args, **kwargs) - - -th = TorchHijackForUnet() - - -# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling -def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - """Always make sure inputs to unet are in correct dtype.""" - if isinstance(cond, dict): - for y in cond.keys(): - if isinstance(cond[y], list): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] - else: - cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] - - with devices.autocast(): - result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) - if devices.unet_needs_upcast: - return result.float() - else: - return result - - -# Monkey patch to create timestep embed tensor on device, avoiding a block. -def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - else: - embedding = repeat(timesteps, 'b -> b d', d=dim) - return embedding - - -# Monkey patch to SpatialTransformer removing unnecessary contiguous calls. 
-# Prevents a lot of unnecessary aten::copy_ calls -def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): - # note: if no context is given, cross-attention defaults to self-attention - if not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i]) - if self.use_linear: - x = self.proj_out(x) - x = x.view(b, h, w, c).permute(0, 3, 1, 2) - if not self.use_linear: - x = self.proj_out(x) - return x + x_in - - -class GELUHijack(torch.nn.GELU, torch.nn.Module): - def __init__(self, *args, **kwargs): - torch.nn.GELU.__init__(self, *args, **kwargs) - def forward(self, x): - if devices.unet_needs_upcast: - return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) - else: - return torch.nn.GELU.forward(self, x) - - -ddpm_edit_hijack = None -def hijack_ddpm_edit(): - global ddpm_edit_hijack - if not ddpm_edit_hijack: - CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) - CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) - ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) - - -unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) -CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) - -if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): - CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) - CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) - CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) - -first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 -first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) - -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) 
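All of these registrations go through CondFunc from modules.sd_hijack_utils, which replaces a function named by its dotted import path with a wrapper that calls a substitute only while a condition holds and otherwise falls back to the original. As a rough, hypothetical analogue of that pattern (the helper below is invented for illustration; it is not the actual CondFunc implementation, which also resolves attributes nested inside classes), the mechanism looks roughly like this:

import importlib


def patch_conditionally(path, sub_func, cond_func=None):
    # Simplified: assumes the target is an attribute directly on a module.
    module_path, attr = path.rsplit('.', 1)
    owner = importlib.import_module(module_path)
    orig = getattr(owner, attr)

    def wrapper(*args, **kwargs):
        # Both callbacks receive the original function as their first argument,
        # matching the orig_func parameter seen in apply_model above.
        if cond_func is None or cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(owner, attr, wrapper)
    return wrapper


# Example: route negative inputs through abs() before calling the original math.sqrt.
patch_conditionally('math.sqrt', lambda orig, x: orig(abs(x)), lambda orig, x: x < 0)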
-CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) - - -def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): - if devices.unet_needs_upcast and timesteps.dtype == torch.int64: - dtype = torch.float32 - else: - dtype = devices.dtype_unet - return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) - - -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) -CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +# import torch +# from packaging import version +# from einops import repeat +# import math +# +# from modules import devices +# from modules.sd_hijack_utils import CondFunc +# +# +# class TorchHijackForUnet: +# """ +# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; +# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 +# """ +# +# def __getattr__(self, item): +# if item == 'cat': +# return self.cat +# +# if hasattr(torch, item): +# return getattr(torch, item) +# +# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") +# +# def cat(self, tensors, *args, **kwargs): +# if len(tensors) == 2: +# a, b = tensors +# if a.shape[-2:] != b.shape[-2:]: +# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") +# +# tensors = (a, b) +# +# return torch.cat(tensors, *args, **kwargs) +# +# +# th = TorchHijackForUnet() +# +# +# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): +# """Always make sure inputs to unet are in correct dtype.""" +# if isinstance(cond, dict): +# for y in cond.keys(): +# if isinstance(cond[y], list): +# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] +# else: +# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] +# +# with devices.autocast(): +# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) +# if devices.unet_needs_upcast: +# return result.float() +# else: +# return result +# +# +# # Monkey patch to create timestep embed tensor on device, avoiding a block. +# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): +# """ +# Create sinusoidal timestep embeddings. +# :param timesteps: a 1-D Tensor of N indices, one per batch element. +# These may be fractional. +# :param dim: the dimension of the output. +# :param max_period: controls the minimum frequency of the embeddings. +# :return: an [N x dim] Tensor of positional embeddings. +# """ +# if not repeat_only: +# half = dim // 2 +# freqs = torch.exp( +# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half +# ) +# args = timesteps[:, None].float() * freqs[None] +# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) +# if dim % 2: +# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) +# else: +# embedding = repeat(timesteps, 'b -> b d', d=dim) +# return embedding +# +# +# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls. 
+# # Prevents a lot of unnecessary aten::copy_ calls +# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): +# # note: if no context is given, cross-attention defaults to self-attention +# if not isinstance(context, list): +# context = [context] +# b, c, h, w = x.shape +# x_in = x +# x = self.norm(x) +# if not self.use_linear: +# x = self.proj_in(x) +# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) +# if self.use_linear: +# x = self.proj_in(x) +# for i, block in enumerate(self.transformer_blocks): +# x = block(x, context=context[i]) +# if self.use_linear: +# x = self.proj_out(x) +# x = x.view(b, h, w, c).permute(0, 3, 1, 2) +# if not self.use_linear: +# x = self.proj_out(x) +# return x + x_in +# +# +# class GELUHijack(torch.nn.GELU, torch.nn.Module): +# def __init__(self, *args, **kwargs): +# torch.nn.GELU.__init__(self, *args, **kwargs) +# def forward(self, x): +# if devices.unet_needs_upcast: +# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) +# else: +# return torch.nn.GELU.forward(self, x) +# +# +# ddpm_edit_hijack = None +# def hijack_ddpm_edit(): +# global ddpm_edit_hijack +# if not ddpm_edit_hijack: +# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) +# +# +# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) +# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) +# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +# +# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): +# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) +# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) +# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) +# +# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 +# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) +# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) +# +# 
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) +# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) +# +# +# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): +# if devices.unet_needs_upcast and timesteps.dtype == torch.int64: +# dtype = torch.float32 +# else: +# dtype = devices.dtype_unet +# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) +# +# +# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_models.py b/modules/sd_models.py index d89a8326..f0aaab00 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -10,7 +10,6 @@ import re import safetensors.torch from omegaconf import OmegaConf, ListConfig from urllib import request -import ldm.modules.midas as midas import gc from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches @@ -415,89 +414,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer def enable_midas_autodownload(): - """ - Gives the ldm.modules.midas.api.load_model function automatic downloading. - - When the 512-depth-ema model, and other future models like it, is loaded, - it calls midas.api.load_model to load the associated midas depth model. - This function applies a wrapper to download the model to the correct - location automatically. - """ - - midas_path = os.path.join(paths.models_path, 'midas') - - # stable-diffusion-stability-ai hard-codes the midas model path to - # a location that differs from where other scripts using this model look. - # HACK: Overriding the path here. 
- for k, v in midas.api.ISL_PATHS.items(): - file_name = os.path.basename(v) - midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name) - - midas_urls = { - "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", - "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt", - "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt", - "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt", - } - - midas.api.load_model_inner = midas.api.load_model - - def load_model_wrapper(model_type): - path = midas.api.ISL_PATHS[model_type] - if not os.path.exists(path): - if not os.path.exists(midas_path): - os.mkdir(midas_path) - - print(f"Downloading midas model weights for {model_type} to {path}") - request.urlretrieve(midas_urls[model_type], path) - print(f"{model_type} downloaded") - - return midas.api.load_model_inner(model_type) - - midas.api.load_model = load_model_wrapper + pass def patch_given_betas(): - import ldm.models.diffusion.ddpm - - def patched_register_schedule(*args, **kwargs): - """a modified version of register_schedule function that converts plain list from Omegaconf into numpy""" - - if isinstance(args[1], ListConfig): - args = (args[0], np.array(args[1]), *args[2:]) - - original_register_schedule(*args, **kwargs) - - original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule) + pass def repair_config(sd_config, state_dict=None): - if not hasattr(sd_config.model.params, "use_ema"): - sd_config.model.params.use_ema = False - - if hasattr(sd_config.model.params, 'unet_config'): - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half": - sd_config.model.params.unet_config.params.use_fp16 = True - - if hasattr(sd_config.model.params, 'first_stage_config'): - if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: - sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" - - # For UnCLIP-L, override the hardcoded karlo directory - if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"): - karlo_path = os.path.join(paths.models_path, 'karlo') - sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) - - # Do not use checkpoint for inference. - # This helps prevent extra performance overhead on checking parameters. - # The perf overhead is about 100ms/it on 4090 for SDXL. 
- if hasattr(sd_config.model.params, "network_config"): - sd_config.model.params.network_config.params.use_checkpoint = False - if hasattr(sd_config.model.params, "unet_config"): - sd_config.model.params.unet_config.params.use_checkpoint = False - + pass def rescale_zero_terminal_snr_abar(alphas_cumprod): diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 41e5087d..c8e1f9f5 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -1,137 +1,137 @@ -import os - -import torch - -from modules import shared, paths, sd_disable_initialization, devices - -sd_configs_path = shared.sd_configs_path -sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") -sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") - - -config_default = shared.sd_default_config -# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") -config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") -config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") -config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") -config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") -config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") -config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") -config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") -config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") -config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") -config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") -config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") -config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") -config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") - - -def is_using_v_parameterization_for_sd2(state_dict): - """ - Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. - """ - - import ldm.modules.diffusionmodules.openaimodel - - device = devices.device - - with sd_disable_initialization.DisableInitialization(): - unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( - use_checkpoint=False, - use_fp16=False, - image_size=32, - in_channels=4, - out_channels=4, - model_channels=320, - attention_resolutions=[4, 2, 1], - num_res_blocks=2, - channel_mult=[1, 2, 4, 4], - num_head_channels=64, - use_spatial_transformer=True, - use_linear_in_transformer=True, - transformer_depth=1, - context_dim=1024, - legacy=False - ) - unet.eval() - - with torch.no_grad(): - unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} - unet.load_state_dict(unet_sd, strict=True) - unet.to(device=device, dtype=devices.dtype_unet) - - test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 - x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 - - with devices.autocast(): - out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item() - - return out < -1 - - -def guess_model_config_from_state_dict(sd, filename): - sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) - diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) - sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) - - if "model.diffusion_model.x_embedder.proj.weight" in sd: - return config_sd3 - - if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: - if diffusion_model_input.shape[1] == 9: - return config_sdxl_inpainting - else: - return config_sdxl - - if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: - return config_sdxl_refiner - elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: - return config_depth_model - elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: - return config_unclip - elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024: - return config_unopenclip - - if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: - if diffusion_model_input.shape[1] == 9: - return config_sd2_inpainting - # elif is_using_v_parameterization_for_sd2(sd): - # return config_sd2v - else: - return config_sd2v - - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - return config_inpainting - if diffusion_model_input.shape[1] == 8: - return config_instruct_pix2pix - - if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: - if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: - return config_alt_diffusion_m18 - return config_alt_diffusion - - return config_default - - -def find_checkpoint_config(state_dict, info): - if info is None: - return guess_model_config_from_state_dict(state_dict, "") - - config = find_checkpoint_config_near_filename(info) - if config is not None: - return config - - return guess_model_config_from_state_dict(state_dict, info.filename) - - -def find_checkpoint_config_near_filename(info): - if info is None: - return None - - config = f"{os.path.splitext(info.filename)[0]}.yaml" - if os.path.exists(config): - return config - - return None - +# import os +# +# import torch +# +# from modules import shared, paths, sd_disable_initialization, devices +# +# sd_configs_path = shared.sd_configs_path +# # sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") +# # sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") +# +# +# config_default = shared.sd_default_config +# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") +# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") +# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") +# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") +# 
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") +# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") +# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") +# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") +# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") +# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") +# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") +# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") +# +# +# def is_using_v_parameterization_for_sd2(state_dict): +# """ +# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. +# """ +# +# import ldm.modules.diffusionmodules.openaimodel +# +# device = devices.device +# +# with sd_disable_initialization.DisableInitialization(): +# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( +# use_checkpoint=False, +# use_fp16=False, +# image_size=32, +# in_channels=4, +# out_channels=4, +# model_channels=320, +# attention_resolutions=[4, 2, 1], +# num_res_blocks=2, +# channel_mult=[1, 2, 4, 4], +# num_head_channels=64, +# use_spatial_transformer=True, +# use_linear_in_transformer=True, +# transformer_depth=1, +# context_dim=1024, +# legacy=False +# ) +# unet.eval() +# +# with torch.no_grad(): +# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} +# unet.load_state_dict(unet_sd, strict=True) +# unet.to(device=device, dtype=devices.dtype_unet) +# +# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 +# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 +# +# with devices.autocast(): +# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item() +# +# return out < -1 +# +# +# def guess_model_config_from_state_dict(sd, filename): +# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) +# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) +# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) +# +# if "model.diffusion_model.x_embedder.proj.weight" in sd: +# return config_sd3 +# +# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: +# if diffusion_model_input.shape[1] == 9: +# return config_sdxl_inpainting +# else: +# return config_sdxl +# +# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: +# return config_sdxl_refiner +# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: +# return config_depth_model +# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: +# return config_unclip +# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024: +# return config_unopenclip +# +# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: +# if diffusion_model_input.shape[1] == 9: +# return config_sd2_inpainting +# # elif is_using_v_parameterization_for_sd2(sd): +# # return config_sd2v +# else: +# return config_sd2v +# +# if diffusion_model_input is not None: +# if diffusion_model_input.shape[1] == 9: +# return config_inpainting +# if diffusion_model_input.shape[1] == 8: +# return config_instruct_pix2pix +# +# if 
sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: +# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: +# return config_alt_diffusion_m18 +# return config_alt_diffusion +# +# return config_default +# +# +# def find_checkpoint_config(state_dict, info): +# if info is None: +# return guess_model_config_from_state_dict(state_dict, "") +# +# config = find_checkpoint_config_near_filename(info) +# if config is not None: +# return config +# +# return guess_model_config_from_state_dict(state_dict, info.filename) +# +# +# def find_checkpoint_config_near_filename(info): +# if info is None: +# return None +# +# config = f"{os.path.splitext(info.filename)[0]}.yaml" +# if os.path.exists(config): +# return config +# +# return None +# diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py index 2fce2777..afc075da 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -1,4 +1,3 @@ -from ldm.models.diffusion.ddpm import LatentDiffusion from typing import TYPE_CHECKING @@ -6,7 +5,7 @@ if TYPE_CHECKING: from modules.sd_models import CheckpointInfo -class WebuiSdModel(LatentDiffusion): +class WebuiSdModel: """This class is not actually instantinated, but its fields are created and fieeld by webui""" lowvram: bool diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 3f1bab96..0b84f2fc 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,115 +1,115 @@ -from __future__ import annotations - -import torch - -import sgm.models.diffusion -import sgm.modules.diffusionmodules.denoiser_scaling -import sgm.modules.diffusionmodules.discretizer -from modules import devices, shared, prompt_parser -from modules import torch_utils - -from backend import memory_management - - -def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): - - for embedder in self.conditioner.embedders: - embedder.ucg_rate = 0.0 - - width = getattr(batch, 'width', 1024) or 1024 - height = getattr(batch, 'height', 1024) or 1024 - is_negative_prompt = getattr(batch, 'is_negative_prompt', False) - aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score - - devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype()) - - sdxl_conds = { - "txt": batch, - "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), - "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), - "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), - "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), - } - - force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) - c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) - - return c - - -def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs): - if self.model.diffusion_model.in_channels == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) - - return self.model(x, t, cond, *args, **kwargs) - - -def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility - return x - - -sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = 
get_learned_conditioning -sgm.models.diffusion.DiffusionEngine.apply_model = apply_model -sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding - - -def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): - res = [] - - for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: - encoded = embedder.encode_embedding_init_text(init_text, nvpt) - res.append(encoded) - - return torch.cat(res, dim=1) - - -def tokenize(self: sgm.modules.GeneralConditioner, texts): - for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: - return embedder.tokenize(texts) - - raise AssertionError('no tokenizer available') - - - -def process_texts(self, texts): - for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: - return embedder.process_texts(texts) - - -def get_target_prompt_token_count(self, token_count): - for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: - return embedder.get_target_prompt_token_count(token_count) - - -# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist -sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text -sgm.modules.GeneralConditioner.tokenize = tokenize -sgm.modules.GeneralConditioner.process_texts = process_texts -sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count - - -def extend_sdxl(model): - """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - - dtype = torch_utils.get_param(model.model.diffusion_model).dtype - model.model.diffusion_model.dtype = dtype - model.model.conditioning_key = 'crossattn' - model.cond_stage_key = 'txt' - # model.cond_stage_model will be set in sd_hijack - - model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" - - discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() - model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) - - model.conditioner.wrapped = torch.nn.Module() - - -sgm.modules.attention.print = shared.ldm_print -sgm.modules.diffusionmodules.model.print = shared.ldm_print -sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print -sgm.modules.encoders.modules.print = shared.ldm_print - -# this gets the code to load the vanilla attention that we override -sgm.modules.attention.SDP_IS_AVAILABLE = True -sgm.modules.attention.XFORMERS_IS_AVAILABLE = False +# from __future__ import annotations +# +# import torch +# +# import sgm.models.diffusion +# import sgm.modules.diffusionmodules.denoiser_scaling +# import sgm.modules.diffusionmodules.discretizer +# from modules import devices, shared, prompt_parser +# from modules import torch_utils +# +# from backend import memory_management +# +# +# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): +# +# for embedder in self.conditioner.embedders: +# embedder.ucg_rate = 0.0 +# +# width = getattr(batch, 'width', 1024) or 1024 +# height = getattr(batch, 'height', 1024) or 1024 +# is_negative_prompt = getattr(batch, 'is_negative_prompt', False) +# aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if 
is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score +# +# devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype()) +# +# sdxl_conds = { +# "txt": batch, +# "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), +# "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), +# "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), +# "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), +# } +# +# force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) +# c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) +# +# return c +# +# +# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs): +# if self.model.diffusion_model.in_channels == 9: +# x = torch.cat([x] + cond['c_concat'], dim=1) +# +# return self.model(x, t, cond, *args, **kwargs) +# +# +# def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility +# return x +# +# +# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding +# +# +# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): +# res = [] +# +# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: +# encoded = embedder.encode_embedding_init_text(init_text, nvpt) +# res.append(encoded) +# +# return torch.cat(res, dim=1) +# +# +# def tokenize(self: sgm.modules.GeneralConditioner, texts): +# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: +# return embedder.tokenize(texts) +# +# raise AssertionError('no tokenizer available') +# +# +# +# def process_texts(self, texts): +# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: +# return embedder.process_texts(texts) +# +# +# def get_target_prompt_token_count(self, token_count): +# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: +# return embedder.get_target_prompt_token_count(token_count) +# +# +# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist +# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +# sgm.modules.GeneralConditioner.tokenize = tokenize +# sgm.modules.GeneralConditioner.process_texts = process_texts +# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count +# +# +# def extend_sdxl(model): +# """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" +# +# dtype = torch_utils.get_param(model.model.diffusion_model).dtype +# model.model.diffusion_model.dtype = dtype +# model.model.conditioning_key = 'crossattn' +# model.cond_stage_key = 'txt' +# # model.cond_stage_model will be set in sd_hijack +# +# model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" +# +# discretization = 
sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() +# model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) +# +# model.conditioner.wrapped = torch.nn.Module() +# +# +# sgm.modules.attention.print = shared.ldm_print +# sgm.modules.diffusionmodules.model.print = shared.ldm_print +# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print +# sgm.modules.encoders.modules.print = shared.ldm_print +# +# # this gets the code to load the vanilla attention that we override +# sgm.modules.attention.SDP_IS_AVAILABLE = True +# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False diff --git a/modules/shared_items.py b/modules/shared_items.py index 11f10b3f..1568ba36 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -35,9 +35,7 @@ def refresh_vae_list(): def cross_attention_optimizations(): - import modules.sd_hijack - - return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] + return ["Automatic"] def sd_unet_items(): diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 71c032df..512bf724 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -1,245 +1,243 @@ -import os -import numpy as np -import PIL -import torch -from torch.utils.data import Dataset, DataLoader, Sampler -from torchvision import transforms -from collections import defaultdict -from random import shuffle, choices - -import random -import tqdm -from modules import devices, shared, images -import re - -from ldm.modules.distributions.distributions import DiagonalGaussianDistribution - -re_numbers_at_start = re.compile(r"^[-\d]+\s*") - - -class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None): - self.filename = filename - self.filename_text = filename_text - self.weight = weight - self.latent_dist = latent_dist - self.latent_sample = latent_sample - self.cond = cond - self.cond_text = cond_text - self.pixel_values = pixel_values - - -class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): - re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None - - self.placeholder_token = placeholder_token - - self.flip = transforms.RandomHorizontalFlip(p=flip_p) - - self.dataset = [] - - with open(template_file, "r") as file: - lines = [x.strip() for x in file.readlines()] - - self.lines = lines - - assert data_root, 'dataset directory not specified' - assert os.path.isdir(data_root), "Dataset directory doesn't exist" - assert os.listdir(data_root), "Dataset directory is empty" - - self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] - - self.shuffle_tags = shuffle_tags - self.tag_drop_out = tag_drop_out - groups = defaultdict(list) - - print("Preparing dataset...") - for path in tqdm.tqdm(self.image_paths): - alpha_channel = None - if shared.state.interrupted: - raise Exception("interrupted") - try: - image = images.read(path) - #Currently does not work for single color transparency - #We would need to read image.info['transparency'] for that - if 
use_weight and 'A' in image.getbands(): - alpha_channel = image.getchannel('A') - image = image.convert('RGB') - if not varsize: - image = image.resize((width, height), PIL.Image.BICUBIC) - except Exception: - continue - - text_filename = f"{os.path.splitext(path)[0]}.txt" - filename = os.path.basename(path) - - if os.path.exists(text_filename): - with open(text_filename, "r", encoding="utf8") as file: - filename_text = file.read() - else: - filename_text = os.path.splitext(filename)[0] - filename_text = re.sub(re_numbers_at_start, '', filename_text) - if re_word: - tokens = re_word.findall(filename_text) - filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) - - npimage = np.array(image).astype(np.uint8) - npimage = (npimage / 127.5 - 1.0).astype(np.float32) - - torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) - latent_sample = None - - with devices.autocast(): - latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) - - #Perform latent sampling, even for random sampling. - #We need the sample dimensions for the weights - if latent_sampling_method == "deterministic": - if isinstance(latent_dist, DiagonalGaussianDistribution): - # Works only for DiagonalGaussianDistribution - latent_dist.std = 0 - else: - latent_sampling_method = "once" - latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - - if use_weight and alpha_channel is not None: - channels, *latent_size = latent_sample.shape - weight_img = alpha_channel.resize(latent_size) - npweight = np.array(weight_img).astype(np.float32) - #Repeat for every channel in the latent sample - weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size) - #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default. - weight -= weight.min() - weight /= weight.mean() - elif use_weight: - #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later - weight = torch.ones(latent_sample.shape) - else: - weight = None - - if latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) - else: - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight) - - if not (self.tag_drop_out != 0 or self.shuffle_tags): - entry.cond_text = self.create_text(filename_text) - - if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): - with devices.autocast(): - entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - groups[image.size].append(len(self.dataset)) - self.dataset.append(entry) - del torchdata - del latent_dist - del latent_sample - del weight - - self.length = len(self.dataset) - self.groups = list(groups.values()) - assert self.length > 0, "No images have been found in the dataset." 
- self.batch_size = min(batch_size, self.length) - self.gradient_step = min(gradient_step, self.length // self.batch_size) - self.latent_sampling_method = latent_sampling_method - - if len(groups) > 1: - print("Buckets:") - for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): - print(f" {w}x{h}: {len(ids)}") - print() - - def create_text(self, filename_text): - text = random.choice(self.lines) - tags = filename_text.split(',') - if self.tag_drop_out != 0: - tags = [t for t in tags if random.random() > self.tag_drop_out] - if self.shuffle_tags: - random.shuffle(tags) - text = text.replace("[filewords]", ','.join(tags)) - text = text.replace("[name]", self.placeholder_token) - return text - - def __len__(self): - return self.length - - def __getitem__(self, i): - entry = self.dataset[i] - if self.tag_drop_out != 0 or self.shuffle_tags: - entry.cond_text = self.create_text(entry.filename_text) - if self.latent_sampling_method == "random": - entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) - return entry - - -class GroupedBatchSampler(Sampler): - def __init__(self, data_source: PersonalizedBase, batch_size: int): - super().__init__(data_source) - - n = len(data_source) - self.groups = data_source.groups - self.len = n_batch = n // batch_size - expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] - self.base = [int(e) // batch_size for e in expected] - self.n_rand_batches = nrb = n_batch - sum(self.base) - self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] - self.batch_size = batch_size - - def __len__(self): - return self.len - - def __iter__(self): - b = self.batch_size - - for g in self.groups: - shuffle(g) - - batches = [] - for g in self.groups: - batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) - for _ in range(self.n_rand_batches): - rand_group = choices(self.groups, self.probs)[0] - batches.append(choices(rand_group, k=b)) - - shuffle(batches) - - yield from batches - - -class PersonalizedDataLoader(DataLoader): - def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): - super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) - if latent_sampling_method == "random": - self.collate_fn = collate_wrapper_random - else: - self.collate_fn = collate_wrapper - - -class BatchLoader: - def __init__(self, data): - self.cond_text = [entry.cond_text for entry in data] - self.cond = [entry.cond for entry in data] - self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) - if all(entry.weight is not None for entry in data): - self.weight = torch.stack([entry.weight for entry in data]).squeeze(1) - else: - self.weight = None - #self.emb_index = [entry.emb_index for entry in data] - #print(self.latent_sample.device) - - def pin_memory(self): - self.latent_sample = self.latent_sample.pin_memory() - return self - -def collate_wrapper(batch): - return BatchLoader(batch) - -class BatchLoaderRandom(BatchLoader): - def __init__(self, data): - super().__init__(data) - - def pin_memory(self): - return self - -def collate_wrapper_random(batch): - return BatchLoaderRandom(batch) +# import os +# import numpy as np +# import PIL +# import torch +# from torch.utils.data import Dataset, DataLoader, Sampler +# from torchvision import transforms +# from collections import defaultdict +# from random import shuffle, choices +# +# import random +# import tqdm +# from 
modules import devices, shared, images +# import re +# +# re_numbers_at_start = re.compile(r"^[-\d]+\s*") +# +# +# class DatasetEntry: +# def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None): +# self.filename = filename +# self.filename_text = filename_text +# self.weight = weight +# self.latent_dist = latent_dist +# self.latent_sample = latent_sample +# self.cond = cond +# self.cond_text = cond_text +# self.pixel_values = pixel_values +# +# +# class PersonalizedBase(Dataset): +# def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): +# re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None +# +# self.placeholder_token = placeholder_token +# +# self.flip = transforms.RandomHorizontalFlip(p=flip_p) +# +# self.dataset = [] +# +# with open(template_file, "r") as file: +# lines = [x.strip() for x in file.readlines()] +# +# self.lines = lines +# +# assert data_root, 'dataset directory not specified' +# assert os.path.isdir(data_root), "Dataset directory doesn't exist" +# assert os.listdir(data_root), "Dataset directory is empty" +# +# self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] +# +# self.shuffle_tags = shuffle_tags +# self.tag_drop_out = tag_drop_out +# groups = defaultdict(list) +# +# print("Preparing dataset...") +# for path in tqdm.tqdm(self.image_paths): +# alpha_channel = None +# if shared.state.interrupted: +# raise Exception("interrupted") +# try: +# image = images.read(path) +# #Currently does not work for single color transparency +# #We would need to read image.info['transparency'] for that +# if use_weight and 'A' in image.getbands(): +# alpha_channel = image.getchannel('A') +# image = image.convert('RGB') +# if not varsize: +# image = image.resize((width, height), PIL.Image.BICUBIC) +# except Exception: +# continue +# +# text_filename = f"{os.path.splitext(path)[0]}.txt" +# filename = os.path.basename(path) +# +# if os.path.exists(text_filename): +# with open(text_filename, "r", encoding="utf8") as file: +# filename_text = file.read() +# else: +# filename_text = os.path.splitext(filename)[0] +# filename_text = re.sub(re_numbers_at_start, '', filename_text) +# if re_word: +# tokens = re_word.findall(filename_text) +# filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) +# +# npimage = np.array(image).astype(np.uint8) +# npimage = (npimage / 127.5 - 1.0).astype(np.float32) +# +# torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) +# latent_sample = None +# +# with devices.autocast(): +# latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) +# +# #Perform latent sampling, even for random sampling. 
+# #We need the sample dimensions for the weights +# if latent_sampling_method == "deterministic": +# if isinstance(latent_dist, DiagonalGaussianDistribution): +# # Works only for DiagonalGaussianDistribution +# latent_dist.std = 0 +# else: +# latent_sampling_method = "once" +# latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) +# +# if use_weight and alpha_channel is not None: +# channels, *latent_size = latent_sample.shape +# weight_img = alpha_channel.resize(latent_size) +# npweight = np.array(weight_img).astype(np.float32) +# #Repeat for every channel in the latent sample +# weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size) +# #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default. +# weight -= weight.min() +# weight /= weight.mean() +# elif use_weight: +# #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later +# weight = torch.ones(latent_sample.shape) +# else: +# weight = None +# +# if latent_sampling_method == "random": +# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) +# else: +# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight) +# +# if not (self.tag_drop_out != 0 or self.shuffle_tags): +# entry.cond_text = self.create_text(filename_text) +# +# if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): +# with devices.autocast(): +# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) +# groups[image.size].append(len(self.dataset)) +# self.dataset.append(entry) +# del torchdata +# del latent_dist +# del latent_sample +# del weight +# +# self.length = len(self.dataset) +# self.groups = list(groups.values()) +# assert self.length > 0, "No images have been found in the dataset." 
+# self.batch_size = min(batch_size, self.length) +# self.gradient_step = min(gradient_step, self.length // self.batch_size) +# self.latent_sampling_method = latent_sampling_method +# +# if len(groups) > 1: +# print("Buckets:") +# for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): +# print(f" {w}x{h}: {len(ids)}") +# print() +# +# def create_text(self, filename_text): +# text = random.choice(self.lines) +# tags = filename_text.split(',') +# if self.tag_drop_out != 0: +# tags = [t for t in tags if random.random() > self.tag_drop_out] +# if self.shuffle_tags: +# random.shuffle(tags) +# text = text.replace("[filewords]", ','.join(tags)) +# text = text.replace("[name]", self.placeholder_token) +# return text +# +# def __len__(self): +# return self.length +# +# def __getitem__(self, i): +# entry = self.dataset[i] +# if self.tag_drop_out != 0 or self.shuffle_tags: +# entry.cond_text = self.create_text(entry.filename_text) +# if self.latent_sampling_method == "random": +# entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) +# return entry +# +# +# class GroupedBatchSampler(Sampler): +# def __init__(self, data_source: PersonalizedBase, batch_size: int): +# super().__init__(data_source) +# +# n = len(data_source) +# self.groups = data_source.groups +# self.len = n_batch = n // batch_size +# expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] +# self.base = [int(e) // batch_size for e in expected] +# self.n_rand_batches = nrb = n_batch - sum(self.base) +# self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] +# self.batch_size = batch_size +# +# def __len__(self): +# return self.len +# +# def __iter__(self): +# b = self.batch_size +# +# for g in self.groups: +# shuffle(g) +# +# batches = [] +# for g in self.groups: +# batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) +# for _ in range(self.n_rand_batches): +# rand_group = choices(self.groups, self.probs)[0] +# batches.append(choices(rand_group, k=b)) +# +# shuffle(batches) +# +# yield from batches +# +# +# class PersonalizedDataLoader(DataLoader): +# def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): +# super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) +# if latent_sampling_method == "random": +# self.collate_fn = collate_wrapper_random +# else: +# self.collate_fn = collate_wrapper +# +# +# class BatchLoader: +# def __init__(self, data): +# self.cond_text = [entry.cond_text for entry in data] +# self.cond = [entry.cond for entry in data] +# self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) +# if all(entry.weight is not None for entry in data): +# self.weight = torch.stack([entry.weight for entry in data]).squeeze(1) +# else: +# self.weight = None +# #self.emb_index = [entry.emb_index for entry in data] +# #print(self.latent_sample.device) +# +# def pin_memory(self): +# self.latent_sample = self.latent_sample.pin_memory() +# return self +# +# def collate_wrapper(batch): +# return BatchLoader(batch) +# +# class BatchLoaderRandom(BatchLoader): +# def __init__(self, data): +# super().__init__(data) +# +# def pin_memory(self): +# return self +# +# def collate_wrapper_random(batch): +# return BatchLoaderRandom(batch)