diff --git a/backend/loader.py b/backend/loader.py index 2894b097..32d1f8b0 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -6,8 +6,10 @@ from diffusers import DiffusionPipeline from transformers import modeling_utils from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace from backend.operations import using_forge_operations -from backend.nn.autoencoder_kl import IntegratedAutoencoderKL +from backend.nn.vae import IntegratedAutoencoderKL from backend.nn.clip import IntegratedCLIP, CLIPTextConfig +from backend.nn.unet import IntegratedUNet2DConditionModel + dir_path = os.path.dirname(__file__) @@ -54,6 +56,15 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict): load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale', 'transformer.text_model.embeddings.position_ids']) return model + if cls_name == 'UNet2DConditionModel': + sd = try_filter_state_dict(state_dict, ['model.diffusion_model.']) + config = IntegratedUNet2DConditionModel.load_config(config_path) + + with using_forge_operations(): + model = IntegratedUNet2DConditionModel.from_config(config) + + load_state_dict(model, sd) + return model print(f'Skipped: {component_name} = {lib_name}.{cls_name}') return None diff --git a/backend/memory_management.py b/backend/memory_management.py new file mode 100644 index 00000000..21989979 --- /dev/null +++ b/backend/memory_management.py @@ -0,0 +1,2 @@ +# will rework soon +from ldm_patched.modules.model_management import * diff --git a/backend/clip.py b/backend/modules/clip.py similarity index 100% rename from backend/clip.py rename to backend/modules/clip.py diff --git a/backend/modules/k_model.py b/backend/modules/k_model.py new file mode 100644 index 00000000..f8d57130 --- /dev/null +++ b/backend/modules/k_model.py @@ -0,0 +1,54 @@ +import torch + +from backend import memory_management +from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler + + +class KModel(torch.nn.Module): + def __init__(self, huggingface_components, storage_dtype, computation_dtype): + super().__init__() + + self.storage_dtype = storage_dtype + self.computation_dtype = computation_dtype + + self.diffusion_model = huggingface_components['unet'] + self.prediction = k_prediction_from_diffusers_scheduler(huggingface_components['scheduler']) + + def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + sigma = t + xc = self.prediction.calculate_input(sigma, x) + if c_concat is not None: + xc = torch.cat([xc] + [c_concat], dim=1) + + context = c_crossattn + dtype = self.computation_dtype + + xc = xc.to(dtype) + t = self.prediction.timestep(t).float() + context = context.to(dtype) + extra_conds = {} + for o in kwargs: + extra = kwargs[o] + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) + extra_conds[o] = extra + + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + return self.prediction.calculate_denoised(sigma, model_output, x) + + def memory_required(self, input_shape): + area = input_shape[0] * input_shape[2] * input_shape[3] + dtype_size = memory_management.dtype_size(self.computation_dtype) + + scaler = 1.28 + + # TODO: Consider these again + # if ldm_patched.modules.model_management.xformers_enabled() or ldm_patched.modules.model_management.pytorch_attention_flash_attention(): 
+ # scaler = 1.28 + # else: + # scaler = 1.65 + # if ldm_patched.ldm.modules.attention._ATTN_PRECISION == "fp32": + # dtype_size = 4 + + return scaler * area * dtype_size * 16384 diff --git a/backend/modules/k_prediction.py b/backend/modules/k_prediction.py new file mode 100644 index 00000000..f49f5d50 --- /dev/null +++ b/backend/modules/k_prediction.py @@ -0,0 +1,266 @@ +import math +import torch +import numpy as np + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = torch.clamp(betas, min=0, max=0.999) + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas + + +def time_snr_shift(alpha, t): + if alpha == 1.0: + return t + return alpha * t / (1 + (alpha - 1) * t) + + +def flux_time_shift(mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +class AbstractPrediction(torch.nn.Module): + def __init__(self, sigma_data=1.0, prediction_type='epsilon'): + super().__init__() + self.sigma_data = sigma_data + self.prediction_type = prediction_type + assert self.prediction_type in ['epsilon', 'const', 'v_prediction', 'edm'] + + def calculate_input(self, sigma, noise): + if self.prediction_type == 'const': + return noise + else: + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + if self.prediction_type == 'v_prediction': + return model_input * self.sigma_data ** 2 / ( + sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / ( + sigma ** 2 + self.sigma_data ** 2) ** 0.5 + elif self.prediction_type == 'edm': + return model_input * self.sigma_data ** 2 / ( + sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / ( + sigma ** 2 + self.sigma_data ** 2) ** 0.5 + else: + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + if self.prediction_type == 'const': + return sigma * noise + (1.0 - sigma) * latent_image + else: + if max_denoise: + noise = noise * torch.sqrt(1.0 + sigma ** 2.0) + else: + noise = noise * sigma + + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + if self.prediction_type == 'const': + return latent / (1.0 - sigma) + else: + return latent + + +class Prediction(AbstractPrediction): + def __init__(self, sigma_data=1.0, prediction_type='eps', beta_schedule='linear', 
linear_start=0.00085, + linear_end=0.012, timesteps=1000): + super().__init__(sigma_data=sigma_data, prediction_type=prediction_type) + self.register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + self.set_sigmas(sigmas) + + def set_sigmas(self, sigmas): + self.register_buffer('sigmas', sigmas.float()) + self.register_buffer('log_sigmas', sigmas.log().float()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device) + + def sigma(self, timestep): + t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp().to(timestep.device) + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + percent = 1.0 - percent + return self.sigma(torch.tensor(percent * 999.0)).item() + + +class PredictionEDM(Prediction): + def timestep(self, sigma): + return 0.25 * sigma.log() + + def sigma(self, timestep): + return (timestep / 0.25).exp() + + +class PredictionContinuousEDM(AbstractPrediction): + def __init__(self, sigma_data=1.0, prediction_type='eps', sigma_min=0.002, sigma_max=120.0): + super().__init__(sigma_data=sigma_data, prediction_type=prediction_type) + self.set_parameters(sigma_min, sigma_max, sigma_data) + + def set_parameters(self, sigma_min, sigma_max, sigma_data): + self.sigma_data = sigma_data + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() + + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return 0.25 * sigma.log() + + def sigma(self, timestep): + return (timestep / 0.25).exp() + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + percent = 1.0 - percent + + log_sigma_min = math.log(self.sigma_min) + return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) + + +class PredictionContinuousV(PredictionContinuousEDM): + def timestep(self, sigma): + return sigma.atan() / math.pi * 2 + + def sigma(self, timestep): + return (timestep * math.pi / 2).tan() + + +class PredictionFlow(AbstractPrediction): + def __init__(self, sigma_data=1.0, prediction_type='eps', shift=1.0, multiplier=1000, timesteps=1000): + super().__init__(sigma_data=sigma_data, prediction_type=prediction_type) + self.shift = shift + self.multiplier = multiplier + ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * 
multiplier) + self.register_buffer('sigmas', ts) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * self.multiplier + + def sigma(self, timestep): + return time_snr_shift(self.shift, timestep / self.multiplier) + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 1.0 + if percent >= 1.0: + return 0.0 + return 1.0 - percent + + +class PredictionFlux(AbstractPrediction): + def __init__(self, sigma_data=1.0, prediction_type='eps', shift=1.0, timesteps=10000): + super().__init__(sigma_data=sigma_data, prediction_type=prediction_type) + self.shift = shift + ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps)) + self.register_buffer('sigmas', ts) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma + + def sigma(self, timestep): + return flux_time_shift(self.shift, 1.0, timestep) + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 1.0 + if percent >= 1.0: + return 0.0 + return 1.0 - percent + + +def k_prediction_from_diffusers_scheduler(scheduler): + if hasattr(scheduler.config, 'prediction_type') and scheduler.config.prediction_type in ["epsilon", "v_prediction"]: + if scheduler.config.beta_schedule == "scaled_linear": + return Prediction(sigma_data=1.0, prediction_type=scheduler.config.prediction_type, beta_schedule='linear', + linear_start=scheduler.config.beta_start, linear_end=scheduler.config.beta_end, + timesteps=scheduler.config.num_train_timesteps) + + raise NotImplementedError(f'Failed to recognize {scheduler}') diff --git a/backend/nn/unet.py b/backend/nn/unet.py new file mode 100644 index 00000000..c57af1ed --- /dev/null +++ b/backend/nn/unet.py @@ -0,0 +1,1008 @@ +import math +import torch +import torch as th +import torch.nn.functional as F + +from typing import Optional, Tuple, Union +from diffusers.configuration_utils import ConfigMixin, register_to_config +from torch import nn +from einops import rearrange, repeat +from backend.attention import attention_function + +unet_initial_dtype = torch.float16 +unet_initial_device = None + + +def checkpoint(f, args, parameters, enable=False): + if enable: + raise NotImplementedError('Gradient Checkpointing is not implemented.') + return f(*args) + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d + + +def conv_nd(dims, *args, **kwargs): + if dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def avg_pool_nd(dims, *args, **kwargs): + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def apply_control(h, control, name): + if control is not None and name in control and len(control[name]) > 0: + ctrl = control[name].pop() + if ctrl is not None: + try: + h += ctrl + except: + print("warning control could not be applied", h.shape, ctrl.shape) + return h + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + # Consistent with Kohya to reduce differences between model training and inference. + # Will be 0.005% slower than ComfyUI but Forge outweigh image quality than speed. 
+ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +class TimestepBlock(nn.Module): + pass + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + def forward(self, x, emb, context=None, transformer_options={}, output_shape=None): + block_inner_modifiers = transformer_options.get("block_inner_modifiers", []) + + for layer_index, layer in enumerate(self): + for modifier in block_inner_modifiers: + x = modifier(x, 'before', layer, layer_index, self, transformer_options) + + if isinstance(layer, TimestepBlock): + x = layer(x, emb, transformer_options) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context, transformer_options) + if "transformer_index" in transformer_options: + transformer_options["transformer_index"] += 1 + elif isinstance(layer, Upsample): + x = layer(x, output_shape=output_shape) + else: + x = layer(x) + + for modifier in block_inner_modifiers: + x = modifier(x, 'after', layer, layer_index, self, transformer_options) + return x + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out, dtype=None, device=None): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim, dtype=dtype, device=device), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out, dtype=dtype, device=device) + ) + + def forward(self, x): + return self.net(x) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + + def forward(self, x, context=None, value=None, mask=None, transformer_options={}): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) + + out = attention_function(q, k, v, self.heads, mask) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, 
gated_ff=True, checkpoint=True, ff_in=False, + inner_dim=None, + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, + dtype=None, device=None): + super().__init__() + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device) + + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, + device=device) + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + context_dim_attn2 = None + if not switch_temporal_ca_to_sa: + context_dim_attn2 = context_dim + + self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, + heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) + self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + + self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.checkpoint = checkpoint + self.n_heads = n_heads + self.d_head = d_head + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa + + def forward(self, x, context=None, transformer_options={}): + return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None, transformer_options={}): + # Stolen from ComfyUI with some modifications + + extra_options = {} + block = transformer_options.get("block", None) + block_index = transformer_options.get("block_index", 0) + transformer_patches = {} + transformer_patches_replace = {} + + for k in transformer_options: + if k == "patches": + transformer_patches = transformer_options[k] + elif k == "patches_replace": + transformer_patches_replace = transformer_options[k] + else: + extra_options[k] = transformer_options[k] + + extra_options["n_heads"] = self.n_heads + extra_options["dim_head"] = self.d_head + + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + + n = self.norm1(x) + if self.disable_self_attn: + context_attn1 = context + else: + context_attn1 = None + value_attn1 = None + + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + if context_attn1 is None: + context_attn1 = n + value_attn1 = context_attn1 + for p in patch: + n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) + + if block is not None: + transformer_block = (block[0], block[1], block_index) + else: + transformer_block = None + attn1_replace_patch = transformer_patches_replace.get("attn1", {}) + block_attn1 = transformer_block + if block_attn1 not in attn1_replace_patch: + block_attn1 = block + + if block_attn1 in attn1_replace_patch: + if context_attn1 is None: + context_attn1 = n + value_attn1 = n + n = self.attn1.to_q(n) + context_attn1 = self.attn1.to_k(context_attn1) + value_attn1 = self.attn1.to_v(value_attn1) + n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) + n = self.attn1.to_out(n) + else: + n = self.attn1(n, 
context=context_attn1, value=value_attn1, transformer_options=extra_options) + + if "attn1_output_patch" in transformer_patches: + patch = transformer_patches["attn1_output_patch"] + for p in patch: + n = p(n, extra_options) + + x += n + if "middle_patch" in transformer_patches: + patch = transformer_patches["middle_patch"] + for p in patch: + x = p(x, extra_options) + + if self.attn2 is not None: + n = self.norm2(x) + if self.switch_temporal_ca_to_sa: + context_attn2 = n + else: + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] + value_attn2 = context_attn2 + for p in patch: + n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) + + attn2_replace_patch = transformer_patches_replace.get("attn2", {}) + block_attn2 = transformer_block + if block_attn2 not in attn2_replace_patch: + block_attn2 = block + + if block_attn2 in attn2_replace_patch: + if value_attn2 is None: + value_attn2 = context_attn2 + n = self.attn2.to_q(n) + context_attn2 = self.attn2.to_k(context_attn2) + value_attn2 = self.attn2.to_v(value_attn2) + n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) + n = self.attn2.to_out(n) + else: + n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=extra_options) + + if "attn2_output_patch" in transformer_patches: + patch = transformer_patches["attn2_output_patch"] + for p in patch: + n = p(n, extra_options) + + x += n + x_skip = 0 + + if self.is_res: + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + + return x + + +class SpatialTransformer(nn.Module): + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, dtype=None, device=None): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] * depth + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, + device=device) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0, dtype=dtype, device=device) + else: + self.proj_in = nn.Linear(in_channels, inner_dim, dtype=dtype, device=device) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, + device=device) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = nn.Conv2d(inner_dim, in_channels, + kernel_size=1, + stride=1, + padding=0, dtype=dtype, device=device) + else: + self.proj_out = nn.Linear(in_channels, inner_dim, dtype=dtype, device=device) + self.use_linear = use_linear + + def forward(self, x, context=None, transformer_options={}): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] * len(self.transformer_blocks) + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i + x = block(x, context=context[i], transformer_options=transformer_options) + if 
self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class Upsample(nn.Module): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device) + + def forward(self, x, output_shape=None): + assert x.shape[1] == self.channels + if self.dims == 3: + shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2] + if output_shape is not None: + shape[1] = output_shape[3] + shape[2] = output_shape[4] + else: + shape = [x.shape[2] * 2, x.shape[3] * 2] + if output_shape is not None: + shape[0] = output_shape[2] + shape[1] = output_shape[3] + + x = F.interpolate(x, size=shape, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, + dtype=None, + device=None, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, list): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels, dtype=dtype, device=device), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device) + self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device) + elif down: + self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device) + self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.skip_t_emb = skip_t_emb + if self.skip_t_emb: + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), + nn.SiLU(), + 
nn.Dropout(p=dropout), + conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, + device=device) + , + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) + + def forward(self, x, emb, transformer_options={}): + return checkpoint( + self._forward, (x, emb, transformer_options), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb, transformer_options={}): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + if "groupnorm_wrapper" in transformer_options: + in_norm, in_rest = in_rest[0], in_rest[1:] + h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options) + h = in_rest(h) + else: + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + if "groupnorm_wrapper" in transformer_options: + in_norm = self.in_layers[0] + h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options) + h = self.in_layers[1:](h) + else: + h = self.in_layers(x) + + emb_out = None + if not self.skip_t_emb: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + if "groupnorm_wrapper" in transformer_options: + h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options) + else: + h = out_norm(h) + if emb_out is not None: + scale, shift = torch.chunk(emb_out, 2, dim=1) + h *= (1 + scale) + h += shift + h = out_rest(h) + else: + if emb_out is not None: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... 
-> b c t ...") + h = h + emb_out + if "groupnorm_wrapper" in transformer_options: + h = transformer_options["groupnorm_wrapper"](self.out_layers[0], h, transformer_options) + h = self.out_layers[1:](h) + else: + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin): + config_name = 'config.json' + + @register_to_config + def __init__(self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, + center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D",), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, mid_block_scale_factor: float = 1, dropout: float = 0.0, + act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, + use_linear_projection: bool = False, class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", + class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, *args, **kwargs): + super().__init__() + + in_channels = in_channels + out_channels = out_channels + model_channels = block_out_channels[0] + num_res_blocks = [layers_per_block] * len(block_out_channels) + dropout = dropout + channel_mult = [x // model_channels for x in block_out_channels] + conv_resample = True + dims = 2 + num_classes = None + use_checkpoint = False + adm_in_channels = None + num_heads = -1 + num_head_channels = -1 + num_heads_upsample = -1 + use_scale_shift_norm = False + resblock_updown = False + use_spatial_transformer = True + transformer_depth = [] + transformer_depth_output = [] + transformer_depth_middle = 1 + context_dim = cross_attention_dim + disable_self_attentions: list = None + num_attention_blocks: list = None + disable_middle_self_attn = False + use_linear_in_transformer = use_linear_projection + + for i, d in enumerate(down_block_types): + if 'attn' in d.lower(): + current_transformer_depth = 1 + if 
isinstance(transformer_layers_per_block, list) and len(transformer_layers_per_block) > i: + current_transformer_depth = transformer_layers_per_block[i] + transformer_depth += [current_transformer_depth] * 2 + transformer_depth_output += [current_transformer_depth] * 3 + else: + transformer_depth += [0] * 2 + transformer_depth_output += [0] * 3 + + if transformer_depth_output[-1] > 1: + transformer_depth_middle = transformer_depth_output[-1] + + if isinstance(attention_head_dim, int): + num_heads = attention_head_dim + elif isinstance(attention_head_dim, list): + num_head_channels = model_channels // attention_head_dim[0] + else: + raise ValueError('Wrong attention heads!') + + if isinstance(projection_class_embeddings_input_dim, int) and projection_class_embeddings_input_dim > 0: + num_classes = 'sequential' + adm_in_channels = projection_class_embeddings_input_dim + + dtype = unet_initial_dtype + device = unet_initial_device + + if context_dim is not None: + assert use_spatial_transformer + + if num_heads == -1: + assert num_head_channels != -1 + + if num_head_channels == -1: + assert num_heads != -1 + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("Bad num_res_blocks") + self.num_res_blocks = num_res_blocks + + if disable_self_attentions is not None: + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + + transformer_depth = transformer_depth[:] + transformer_depth_output = transformer_depth_output[:] + + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=dtype, device=device) + elif self.num_classes == "continuous": + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + nn.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + ) + else: + raise ValueError('Bad ADM') + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + ] + ) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + dtype=dtype, + device=device, + ) + ] + ch = mult * model_channels + num_transformers = transformer_depth.pop(0) + 
if num_transformers > 0: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append(SpatialTransformer( + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, dtype=dtype, device=device) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + dtype=dtype, + device=device, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch, dtype=dtype, device=device, + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + mid_block = [ + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=None, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + dtype=dtype, + device=device, + )] + if transformer_depth_middle >= 0: + mid_block += [ + SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, dtype=dtype, device=device + ), + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=None, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + dtype=dtype, + device=device, + )] + self.middle_block = TimestepEmbedSequential(*mid_block) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + channels=ch + ich, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + dtype=dtype, + device=device, + ) + ] + ch = model_channels * mult + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + SpatialTransformer( + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, dtype=dtype, device=device + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + channels=ch, + 
emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + dtype=dtype, + device=device, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=dtype, + device=device) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + nn.GroupNorm(32, ch, dtype=dtype, device=device), + nn.SiLU(), + conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=dtype, device=device), + ) + + def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): + transformer_options["original_shape"] = list(x.shape) + transformer_options["transformer_index"] = 0 + transformer_patches = transformer_options.get("patches", {}) + block_modifiers = transformer_options.get("block_modifiers", []) + + assert (y is not None) == (self.num_classes is not None) + + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for id, module in enumerate(self.input_blocks): + transformer_options["block"] = ("input", id) + + for block_modifier in block_modifiers: + h = block_modifier(h, 'before', transformer_options) + + h = module(h, emb, context, transformer_options) + h = apply_control(h, control, 'input') + + for block_modifier in block_modifiers: + h = block_modifier(h, 'after', transformer_options) + + if "input_block_patch" in transformer_patches: + patch = transformer_patches["input_block_patch"] + for p in patch: + h = p(h, transformer_options) + + hs.append(h) + if "input_block_patch_after_skip" in transformer_patches: + patch = transformer_patches["input_block_patch_after_skip"] + for p in patch: + h = p(h, transformer_options) + + transformer_options["block"] = ("middle", 0) + + for block_modifier in block_modifiers: + h = block_modifier(h, 'before', transformer_options) + + h = self.middle_block(h, emb, context, transformer_options) + h = apply_control(h, control, 'middle') + + for block_modifier in block_modifiers: + h = block_modifier(h, 'after', transformer_options) + + for id, module in enumerate(self.output_blocks): + transformer_options["block"] = ("output", id) + hsp = hs.pop() + hsp = apply_control(hsp, control, 'output') + + if "output_block_patch" in transformer_patches: + patch = transformer_patches["output_block_patch"] + for p in patch: + h, hsp = p(h, hsp, transformer_options) + + h = th.cat([h, hsp], dim=1) + del hsp + if len(hs) > 0: + output_shape = hs[-1].shape + else: + output_shape = None + + for block_modifier in block_modifiers: + h = block_modifier(h, 'before', transformer_options) + + h = module(h, emb, context, transformer_options, output_shape) + + for block_modifier in block_modifiers: + h = block_modifier(h, 'after', transformer_options) + + transformer_options["block"] = ("last", 0) + + for block_modifier in block_modifiers: + h = block_modifier(h, 'before', transformer_options) + + if "groupnorm_wrapper" in transformer_options: + out_norm, out_rest = self.out[0], self.out[1:] + h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options) + h = out_rest(h) + else: + h = self.out(h) + + for block_modifier in block_modifiers: + h = block_modifier(h, 'after', transformer_options) + + return h.type(x.dtype) 
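For reference, a minimal usage sketch of the prediction objects introduced above (assuming the repository root is on the Python path; the values mirror the SD1.5-style epsilon/linear defaults that k_prediction_from_diffusers_scheduler builds). This is the sigma-to-timestep mapping that KModel.apply_model wraps around the UNet call:

import torch
from backend.modules.k_prediction import Prediction

# Epsilon prediction with the linear beta schedule (beta_start=0.00085, beta_end=0.012),
# i.e. what k_prediction_from_diffusers_scheduler returns for a scaled_linear diffusers scheduler.
pred = Prediction(prediction_type='epsilon', beta_schedule='linear',
                  linear_start=0.00085, linear_end=0.012, timesteps=1000)

sigma = pred.sigma_max * torch.ones(1)     # highest k-diffusion noise level
t = pred.timestep(sigma)                   # nearest discrete timestep (~999)
mid_sigma = pred.percent_to_sigma(0.5)     # used by the percent-based ControlNet/IPAdapter ranges

x = torch.randn(1, 4, 8, 8)                # noisy latent
xc = pred.calculate_input(sigma, x)        # what the diffusion model actually receives
eps = torch.zeros_like(x)                  # stand-in for the UNet's epsilon output
denoised = pred.calculate_denoised(sigma, eps, x)   # with eps == 0 this returns x unchanged
print(int(t), float(mid_sigma), bool(torch.equal(denoised, x)))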
diff --git a/backend/nn/autoencoder_kl.py b/backend/nn/vae.py similarity index 97% rename from backend/nn/autoencoder_kl.py rename to backend/nn/vae.py index d2f9db6d..d98644c1 100644 --- a/backend/nn/autoencoder_kl.py +++ b/backend/nn/vae.py @@ -381,7 +381,7 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin): norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, - shift_factor: Optional[float] = None, + shift_factor: Optional[float] = 0.0, latents_mean: Optional[Tuple[float]] = None, latents_std: Optional[Tuple[float]] = None, force_upcast: float = True, @@ -403,6 +403,9 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin): self.scaling_factor = scaling_factor self.shift_factor = shift_factor + if not isinstance(self.shift_factor, float): + self.shift_factor = 0.0 + def encode(self, x, regulation=None): z = self.encoder(x) z = self.quant_conv(z) @@ -416,3 +419,9 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin): z = self.post_quant_conv(z) x = self.decoder(z) return x + + def process_in(self, latent): + return (latent - self.shift_factor) * self.scaling_factor + + def process_out(self, latent): + return (latent / self.scaling_factor) + self.shift_factor diff --git a/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py b/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py index 1ccb65fa..3e07d1f3 100644 --- a/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py +++ b/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py @@ -46,7 +46,7 @@ class PreprocessorInpaintOnly(PreprocessorInpaint): # This is a powerful VAE with integrated memory management, bf16, and tiled fallback. latent_image = vae.encode(self.image.movedim(1, -1)) - latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image) + latent_image = process.sd_model.forge_objects.vae.first_stage_model.process_in(latent_image) B, C, H, W = latent_image.shape @@ -154,7 +154,7 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly): def process_before_every_sampling(self, process, cond, mask, *args, **kwargs): cond, mask = super().process_before_every_sampling(process, cond, mask, *args, **kwargs) - sigma_max = process.sd_model.forge_objects.unet.model.model_sampling.sigma_max + sigma_max = process.sd_model.forge_objects.unet.model.prediction.sigma_max original_noise = kwargs['noise'] process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise) return cond, mask diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index 69550665..51f48012 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -61,14 +61,14 @@ class PreprocessorReference(Preprocessor): # This is a powerful VAE with integrated memory management, bf16, and tiled fallback. 
latent_image = vae.encode(cond.movedim(1, -1)) - latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image) + latent_image = process.sd_model.forge_objects.vae.first_stage_model.process_in(latent_image) gen_seed = process.seeds[0] + 1 gen_cpu = torch.Generator().manual_seed(gen_seed) unet = process.sd_model.forge_objects.unet.clone() - sigma_max = unet.model.model_sampling.percent_to_sigma(start_percent) - sigma_min = unet.model.model_sampling.percent_to_sigma(end_percent) + sigma_max = unet.model.prediction.percent_to_sigma(start_percent) + sigma_min = unet.model.prediction.percent_to_sigma(end_percent) self.recorded_attn1 = {} self.recorded_h = {} diff --git a/extensions-builtin/forge_preprocessor_tile/scripts/preprocessor_tile.py b/extensions-builtin/forge_preprocessor_tile/scripts/preprocessor_tile.py index 083a3e8f..40e4c39a 100644 --- a/extensions-builtin/forge_preprocessor_tile/scripts/preprocessor_tile.py +++ b/extensions-builtin/forge_preprocessor_tile/scripts/preprocessor_tile.py @@ -24,7 +24,7 @@ class PreprocessorTile(Preprocessor): # This is a powerful VAE with integrated memory management, bf16, and tiled fallback. latent_image = vae.encode(cond.movedim(1, -1)) - latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image) + latent_image = process.sd_model.forge_objects.vae.first_stage_model.process_in(latent_image) self.latent = latent_image return self.latent @@ -43,7 +43,7 @@ class PreprocessorTileColorFix(PreprocessorTile): latent = self.register_latent(process, cond) unet = process.sd_model.forge_objects.unet.clone() - sigma_data = process.sd_model.forge_objects.unet.model.model_sampling.sigma_data + sigma_data = process.sd_model.forge_objects.unet.model.prediction.sigma_data if getattr(process, 'is_hr_pass', False): k = int(self.variation * 2) diff --git a/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres.py b/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres.py index 9c55240f..fa286cc0 100644 --- a/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres.py +++ b/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres.py @@ -38,7 +38,7 @@ class DynamicThresholdingNode: cond = input - args["cond"] uncond = input - args["uncond"] cond_scale = args["cond_scale"] - time_step = model.model.model_sampling.timestep(args["sigma"]) + time_step = model.model.prediction.timestep(args["sigma"]) time_step = time_step[0].item() dynamic_thresh.step = 999 - time_step diff --git a/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py b/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py index 15d8f809..88066585 100644 --- a/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py +++ b/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py @@ -76,7 +76,7 @@ class FooocusInpaintPatcher(ControlModelPatcher): vae = process.sd_model.forge_objects.vae latent_image = vae.encode(cond_original.movedim(1, -1)) - latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image) + latent_image = process.sd_model.forge_objects.vae.first_stage_model.process_in(latent_image) latent_mask = torch.nn.functional.max_pool2d(mask_original, (8, 8)).round().to(cond) feed = torch.cat([ latent_mask.to(device=torch.device('cpu'), dtype=torch.float32), @@ -102,8 +102,8 @@ class 
FooocusInpaintPatcher(ControlModelPatcher): if not_patched_count > 0: print(f"[Fooocus Patch Loader] Failed to load {not_patched_count} keys") - sigma_start = unet.model.model_sampling.percent_to_sigma(self.start_percent) - sigma_end = unet.model.model_sampling.percent_to_sigma(self.end_percent) + sigma_start = unet.model.prediction.percent_to_sigma(self.start_percent) + sigma_end = unet.model.prediction.percent_to_sigma(self.end_percent) def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed): if timestep > sigma_start or timestep < sigma_end: diff --git a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py index e35b8d60..3fe54d3a 100644 --- a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py +++ b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py @@ -760,8 +760,8 @@ class IPAdapterApply: if attn_mask is not None: attn_mask = attn_mask.to(self.device) - sigma_start = model.model.model_sampling.percent_to_sigma(start_at) - sigma_end = model.model.model_sampling.percent_to_sigma(end_at) + sigma_start = model.model.prediction.percent_to_sigma(start_at) + sigma_end = model.model.prediction.percent_to_sigma(end_at) patch_kwargs = { "number": 0, diff --git a/extensions-builtin/sd_forge_latent_modifier/lib_latent_modifier/sampler_mega_modifier.py b/extensions-builtin/sd_forge_latent_modifier/lib_latent_modifier/sampler_mega_modifier.py index f8b494e2..a5cb522f 100644 --- a/extensions-builtin/sd_forge_latent_modifier/lib_latent_modifier/sampler_mega_modifier.py +++ b/extensions-builtin/sd_forge_latent_modifier/lib_latent_modifier/sampler_mega_modifier.py @@ -919,10 +919,10 @@ class ModelSamplerLatentMegaModifier: cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] - timestep = model.model.model_sampling.timestep(args["timestep"]) + timestep = model.model.prediction.timestep(args["timestep"]) sigma = args["sigma"] sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) - #print(model.model.model_sampling.timestep(timestep)) + #print(model.model.prediction.timestep(timestep)) x = x_input / (sigma * sigma + 1.0) cond = ((x - (x_input - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 6192d7ae..46848f8c 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -285,7 +285,7 @@ class ControlNet(ControlBase): def pre_run(self, model, percent_to_timestep_function): super().pre_run(model, percent_to_timestep_function) - self.model_sampling_current = model.model_sampling + self.model_sampling_current = model.prediction def cleanup(self): self.model_sampling_current = None diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index a82c4542..9f97ee16 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -97,7 +97,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen return model, clip -from backend.clip import JointCLIP, JointTokenizer +from backend.modules.clip import JointCLIP, JointTokenizer class CLIP: diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index fb44c5a8..41e5087d 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -10,7 +10,7 @@ sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "conf config_default = shared.sd_default_config -config_sd2 = 
os.path.join(sd_repo_configs_path, "v2-inference.yaml") +# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") @@ -95,10 +95,10 @@ def guess_model_config_from_state_dict(sd, filename): if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if diffusion_model_input.shape[1] == 9: return config_sd2_inpainting - elif is_using_v_parameterization_for_sd2(sd): - return config_sd2v + # elif is_using_v_parameterization_for_sd2(sd): + # return config_sd2v else: - return config_sd2 + return config_sd2v if diffusion_model_input is not None: if diffusion_model_input.shape[1] == 9: diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 38eddf08..e60f2c78 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -8,6 +8,7 @@ from ldm_patched.modules.sd import VAE, CLIP, load_model_weights import ldm_patched.modules.model_patcher import ldm_patched.modules.utils import ldm_patched.modules.clip_vision +import backend.nn.unet from omegaconf import OmegaConf from modules.sd_models_config import find_checkpoint_config @@ -19,6 +20,7 @@ from modules_forge import forge_clip from modules_forge.unet_patcher import UnetPatcher from ldm_patched.modules.model_base import model_sampling, ModelType from backend.loader import load_huggingface_components +from backend.modules.k_model import KModel import open_clip from transformers import CLIPTextModel, CLIPTokenizer @@ -85,27 +87,20 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype - class WeightsLoader(torch.nn.Module): - pass - - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) - model_config.set_manual_cast(manual_cast_dtype) - - if model_config is None: - raise RuntimeError("ERROR: Could not detect model type") - - if model_config.clip_vision_prefix is not None: - if output_clipvision: - clipvision = ldm_patched.modules.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) + initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) + backend.nn.unet.unet_initial_device = initial_load_device + backend.nn.unet.unet_initial_dtype = unet_dtype huggingface_components = load_huggingface_components(sd) if output_model: - inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) - offload_device = model_management.unet_offload_device() - model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) - model.load_model_weights(sd, "model.diffusion_model.") + k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype) + k_model.to(device=initial_load_device, dtype=unet_dtype) + model_patcher = UnetPatcher(k_model, load_device=load_device, + offload_device=model_management.unet_offload_device(), + current_device=initial_load_device) if output_vae: vae = huggingface_components['vae'] @@ -118,12 +113,6 @@ def load_checkpoint_guess_config(sd, 
output_vae=True, output_clip=True, output_c if len(left_over) > 0: print("left over keys:", left_over) - if output_model: - model_patcher = UnetPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) - if inital_load_device != torch.device("cpu"): - print("loaded straight to GPU") - model_management.load_model_gpu(model_patcher) - return ForgeSD(model_patcher, clip, vae, clipvision) @@ -161,7 +150,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): timer.record("forge load real models") sd_model.first_stage_model = forge_objects.vae.first_stage_model - sd_model.model.diffusion_model = forge_objects.unet.model.diffusion_model + sd_model.model.diffusion_model = forge_objects.unet.model conditioner = getattr(sd_model, 'conditioner', None) if conditioner: @@ -202,8 +191,8 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): model_embeddings.token_embedding, sd_hijack.model_hijack) sd_model.cond_stage_model = forge_clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack) elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': # SD21 Clip - sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h - sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_h.transformer + sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l + sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes( model_embeddings.token_embedding, sd_hijack.model_hijack) @@ -216,9 +205,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if getattr(sd_model, 'parameterization', None) == 'v': - sd_model.forge_objects.unet.model.model_sampling = model_sampling(sd_model.forge_objects.unet.model.model_config, ModelType.V_PREDICTION) - sd_model.is_sd3 = False sd_model.latent_channels = 4 sd_model.is_sdxl = conditioner is not None @@ -234,14 +220,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): @torch.inference_mode() def patched_decode_first_stage(x): - sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_out(x) + sample = sd_model.forge_objects.vae.first_stage_model.process_out(x) sample = sd_model.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0 return sample.to(x) @torch.inference_mode() def patched_encode_first_stage(x): sample = sd_model.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5) - sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_in(sample) + sample = sd_model.forge_objects.vae.first_stage_model.process_in(sample) return sample.to(x) sd_model.ema_scope = lambda *args, **kwargs: contextlib.nullcontext() diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py index 1b950bdf..fe3bd46a 100644 --- a/modules_forge/forge_sampler.py +++ b/modules_forge/forge_sampler.py @@ -108,7 +108,7 @@ def sampling_prepare(unet, x): real_model = unet.model - percent_to_timestep_function = lambda p: real_model.model_sampling.percent_to_sigma(p) + percent_to_timestep_function = lambda p: real_model.prediction.percent_to_sigma(p) for cnet in unet.list_controlnets(): cnet.pre_run(real_model, percent_to_timestep_function) diff --git 
a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index 4b1d14f8..af62b0a8 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -4,11 +4,12 @@ import torch from ldm_patched.modules.model_patcher import ModelPatcher from ldm_patched.modules.sample import convert_cond from ldm_patched.modules.samplers import encode_model_conds +from ldm_patched.modules import model_management class UnetPatcher(ModelPatcher): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, model, *args, **kwargs): + super().__init__(model, *args, **kwargs) self.controlnet_linked_list = None self.extra_preserved_memory_during_sampling = 0 self.extra_model_patchers_during_sampling = []
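As a worked illustration of the latent scaling that the extension hunks above now take from vae.first_stage_model instead of unet.model.latent_format, here is the round trip implied by the new process_in/process_out methods. The constants mirror the SD1.5 defaults visible in the vae.py hunk (scaling_factor=0.18215, shift_factor falling back to 0.0); the helpers below are a standalone re-statement of the two one-liners, not the class itself:

import torch

scaling_factor, shift_factor = 0.18215, 0.0

def process_in(latent):
    # VAE-encoded latent -> scaled latent handed to the UNet/sampler
    return (latent - shift_factor) * scaling_factor

def process_out(latent):
    # sampler-space latent -> VAE latent ready for decoding
    return (latent / scaling_factor) + shift_factor

latent = torch.randn(1, 4, 64, 64)
assert torch.allclose(process_out(process_in(latent)), latent, atol=1e-6)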